Support rope cache indexing using positions #2112
Conversation
tianyu-l left a comment:
Thanks! May I ask: for the dsv3 16B test on dp4tp2 before vs. after, did you explicitly pass positions into the model? We should try both (1, seq_len) and (batch_size, seq_len) inputs (they could be the trivial 0 -> seq_len - 1 ids).
Also had an inline comment.
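For concreteness, a minimal sketch of the two trivial positions inputs described above (the tensor names, sizes, and the forward-call signature are assumptions for illustration, not the PR's actual API):

```python
import torch

batch_size, seq_len = 4, 2048  # hypothetical sizes for the dsv3 16B run

# Trivial 0 -> seq_len - 1 position ids with shape (1, seq_len)
positions_1xs = torch.arange(seq_len).unsqueeze(0)

# The same ids expanded to shape (batch_size, seq_len)
positions_bxs = positions_1xs.expand(batch_size, seq_len)

# Hypothetical forward calls; the real signature depends on the model:
# out = model(tokens, positions=positions_1xs)
# out = model(tokens, positions=positions_bxs)
```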
| "attention": prepare_module_input( | ||
| input_layouts=(Shard(1), Replicate(), None), | ||
| desired_input_layouts=(Replicate(), Replicate(), None), | ||
| input_layouts=(Shard(1), Replicate(), None, None), |
Note that when positions is not None, this makes the implicit assumption that positions has the expected sharding when it's used, namely: sharded on the batch dim by DP, replicated on the TP mesh, and sharded on the seq dim by CP.
I don't have a good solution right now -- making it Replicate by default will fail when positions is None (see https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L521), but clearly this leaves a footgun. I'd suggest we add a comment for now.
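A sketch of the kind of comment this suggests, attached to the layouts from the diff above (the four-element desired_input_layouts is an assumption matching the input_layouts change, and PrepareModuleInput is assumed to be what the prepare_module_input alias in the diff resolves to):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

attention_input_plan = PrepareModuleInput(
    # NOTE: when `positions` is not None, the trailing None layout implicitly
    # assumes it already arrives sharded on the batch dim by DP, replicated
    # on the TP mesh, and sharded on the seq dim by CP; nothing here checks
    # or redistributes it.
    input_layouts=(Shard(1), Replicate(), None, None),
    desired_input_layouts=(Replicate(), Replicate(), None, None),
)
```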
Add support for indexing the rope cache using position_ids; this might be needed when passing position_ids into the transformer forward.
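As a rough illustration of the behavior the description refers to, a minimal sketch of rope-cache indexing with and without explicit position ids (the helper name and signature are hypothetical, not the PR's actual code):

```python
import torch

def select_rope_cache(freqs_cis: torch.Tensor, seq_len: int,
                      positions: torch.Tensor | None = None) -> torch.Tensor:
    """Pick rows of a precomputed rope cache of shape (max_seq_len, dim)."""
    if positions is None:
        # Default path: contiguous positions 0 .. seq_len - 1
        return freqs_cis[:seq_len]
    # Explicit path: gather rows by position ids of shape (1, seq_len)
    # or (batch_size, seq_len); broadcasting handles both downstream.
    return freqs_cis[positions]
```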

running dpskv3 16b base
Also tested in https://github.com/wwwjn/torchtitan/pull/1/files when passing position_ids (screenshot omitted).