[WIP][perf] replace all_reduce for kv_consumer and support different num_tokens among all ranks #4736
base: main
Conversation
Code Review
This pull request introduces performance optimizations for MoE models in a distributed setting, primarily for kv_consumer nodes. It achieves this by replacing a costly all_reduce operation, allowing different ranks to handle varying numbers of tokens, which should improve throughput. This is supported by pre-calculating a global batch size for MoE operators. While the changes are promising for performance, I've identified a couple of areas that need attention to ensure correctness and robustness. Specifically, there's a behavioral change in how random experts are selected for load balancing that might be a bug, and a critical assumption about the MoE communication method that could lead to runtime failures if not enforced.
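Before the inline comments below, a minimal self-contained sketch of what "different num_tokens among ranks" buys. The per-rank counts, the dp_size value, and the rank index are invented for illustration, and the collective itself is only described in comments; this is not the PR's actual code.

import torch

# Toy illustration (no real process group). The synchronized path this PR bypasses
# would use a collective (e.g. an all_reduce over a dp_size-long tensor) so that
# every rank learns all ranks' token counts:
dp_size = 4
per_rank_num_tokens = [3, 7, 5, 6]   # hypothetical decode batch sizes per DP rank
synced = torch.tensor(per_rank_num_tokens, dtype=torch.int32)

# The kv_consumer shortcut in this PR: skip the collective and simply replicate
# the local count, relying on MC2 MoE communication to tolerate mismatched sizes.
rank = 1
local_only = torch.tensor([per_rank_num_tokens[rank]] * dp_size,
                          device="cpu", dtype=torch.int32)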
if self.is_kv_consumer and not self.in_profile_run:
    num_tokens_after_padding = torch.tensor([num_tokens] *
                                            self.dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return num_tokens, num_tokens_after_padding, with_prefill
This optimization to skip all_reduce for num_tokens is a great performance improvement. However, it relies on the assumption that the MoE communication method will be MC2, as noted in the comment. If a different communication method like AllGather is used (e.g., if num_tokens exceeds mc2_tokens_capacity), it will lead to a runtime failure because AllGather requires tensors of the same size across ranks.
To make this more robust, I suggest adding an assertion to ensure that MC2 is indeed the selected communication method when this optimization is active. This will prevent silent failures in unexpected scenarios.
Suggested change:

if self.is_kv_consumer and not self.in_profile_run:
    assert self._select_moe_comm_method(num_tokens) == MoECommType.MC2, \
        "Skipping all_reduce for num_tokens is only supported with MC2 MoE communication."
    num_tokens_after_padding = torch.tensor([num_tokens] *
                                            self.dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return num_tokens, num_tokens_after_padding, with_prefill
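If the non-MC2 case is expected to occur at runtime, an alternative to a hard assert would be to fall back to the existing synchronized path. The following is a rough standalone sketch of that decision only, not the model runner's actual code: the parameter names (mc2_tokens_capacity, dp_size) are stand-ins for attributes of the real class, and the MC2 check is reduced to a simple capacity comparison for illustration.

from typing import Optional
import torch

def consumer_padding_or_none(num_tokens: int, dp_size: int,
                             mc2_tokens_capacity: int) -> Optional[torch.Tensor]:
    if num_tokens > mc2_tokens_capacity:
        # AllGather-style MoE communication needs equal sizes across ranks,
        # so the caller should take the existing all_reduce-based sync path.
        return None
    # MC2 tolerates per-rank differences: replicate the local count only.
    return torch.tensor([num_tokens] * dp_size, device="cpu", dtype=torch.int32)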
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
topk_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
The logic for enable_force_load_balance has been changed to potentially include redundant experts. The previous implementation explicitly excluded them by using global_num_experts - global_redundant_expert_num as the upper bound for random integers. The new implementation uses global_num_experts for the torch.rand call, which means redundant experts can be selected.
If redundant experts should not receive tokens, this is a bug; in that case, please consider the following suggestion to restrict the selection to the correct expert range.
Suggested change:

random_matrix = torch.rand(topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device)
topk_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
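To make the behavioural difference concrete, here is a small self-contained comparison of the two routing strategies. The expert counts are invented, and the torch.randint form is only one reading of the previous implementation, inferred from the review's description of "random integers" bounded by global_num_experts - global_redundant_expert_num.

import torch

num_tokens, top_k = 4, 2
global_num_experts, global_redundant_expert_num = 8, 2  # invented sizes

# Previous behaviour as described above: random integers below the
# non-redundant expert count (duplicates within a token are possible).
old_topk_ids = torch.randint(
    0, global_num_experts - global_redundant_expert_num,
    (num_tokens, top_k), dtype=torch.int32)

# New behaviour with the reviewer's bound applied: rand + argsort draws
# top_k distinct experts per token while still excluding redundant experts.
random_matrix = torch.rand(num_tokens,
                           global_num_experts - global_redundant_expert_num)
new_topk_ids = torch.argsort(random_matrix, dim=1)[:, :top_k].to(torch.int32)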
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
f45b53e to 87129a9 (Compare)
This pull request has conflicts, please resolve those before we can evaluate the pull request.
…mong all ranks Signed-off-by: linfeng-yuan <1102311262@qq.com>
87129a9 to 0381e07 (Compare)
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?