
Conversation


@acisseJZhong commented Dec 5, 2025

Add support for indexing the RoPE cache using position_ids. This may be needed for:

  1. inference, where we pass position_ids into the transformer forward
  2. CP load balancing, where we need to index the RoPE cache given position_ids (see the sketch after this list)
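
For illustration, a minimal sketch of the kind of cache indexing this enables; `index_rope_cache`, `freqs_cis`, and the shapes used here are assumptions for the example, not the actual torchtitan implementation:

```python
# Hypothetical sketch (not torchtitan code): index a precomputed RoPE cache
# with explicit position ids.
import torch

def index_rope_cache(freqs_cis: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # freqs_cis:    (max_seq_len, head_dim // 2) complex rotary frequencies
    # position_ids: (1, seq_len) or (batch_size, seq_len) integer positions
    # returns:      (1 or batch_size, seq_len, head_dim // 2)
    # Advanced indexing gathers one cache row per position id and keeps the
    # leading batch dimension of position_ids.
    return freqs_cis[position_ids]

max_seq_len, half_dim, bsz, seqlen = 4096, 64, 4, 128
freqs_cis = torch.polar(torch.ones(max_seq_len, half_dim),
                        torch.rand(max_seq_len, half_dim))

# Trivial 0 .. seq_len - 1 ids in both shapes; both select the same cache rows.
pos_1xs = torch.arange(seqlen).unsqueeze(0)   # (1, seq_len)
pos_bxs = pos_1xs.expand(bsz, -1)             # (batch_size, seq_len)
assert torch.equal(index_rope_cache(freqs_cis, pos_1xs)[0],
                   index_rope_cache(freqs_cis, pos_bxs)[2])
```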

Test:
running DeepSeek-V3 16B base
[screenshot omitted]

Also tested in https://github.com/wwwjn/torchtitan/pull/1/files when passing position_ids.
[screenshot omitted]

@meta-cla bot added the CLA Signed label on Dec 5, 2025
@acisseJZhong reopened this on Dec 5, 2025
@tianyu-l (Contributor) left a comment


Thanks! May I ask: for the dsv3 16B test on dp4tp2 before vs. after, did you explicitly pass positions into the model? We should try both (1, seq_len) and (batch_size, seq_len) inputs (they could be the trivial 0 -> seq_len - 1 ids).

I also left an inline comment, shown below.

"attention": prepare_module_input(
input_layouts=(Shard(1), Replicate(), None),
desired_input_layouts=(Replicate(), Replicate(), None),
input_layouts=(Shard(1), Replicate(), None, None),

Note that when positions is not None, this makes the implicit assumption that positions already has the expected sharding when it's used, namely: sharded on the batch dim by DP, replicated on the TP mesh, and sharded on the seq dim by CP.

I don't have a good solution right now -- making it Replicate by default will fail here when positions is None: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L521
But clearly this is leaving a footgun. I'd suggest we add a comment for now.
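
As a rough sketch of what the suggested comment might look like next to the plan entry (illustrative only: the input ordering `(x, freqs_cis, attention_mask, positions)` and the completed desired_input_layouts line are assumptions, and torchtitan's own prepare_module_input helper is swapped for PrepareModuleInput from torch.distributed.tensor.parallel):

```python
# Hypothetical sketch of the attention module-input plan entry with the
# suggested documenting comment; not the actual torchtitan code.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

attention_input_style = PrepareModuleInput(
    # Assumed input order: (x, freqs_cis, attention_mask, positions).
    # A `None` layout means that input is left as-is rather than redistributed.
    # NOTE: when `positions` is not None, we implicitly assume it already has
    # the expected sharding -- batch dim sharded by DP, replicated on the TP
    # mesh, seq dim sharded by CP. Declaring Replicate() here instead would
    # fail whenever `positions` is None.
    input_layouts=(Shard(1), Replicate(), None, None),
    desired_input_layouts=(Replicate(), Replicate(), None, None),
)
```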
