
Commit c2e4de3

cg fix
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent d9a580a commit c2e4de3

2 files changed: +8 -10 lines changed


vllm/v1/attention/backends/mla/flashmla_sparse.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -413,10 +413,6 @@ def __init__(
             dtype=torch.int32,
             device=device,
         )
-        # Per-request cache_seqlens buffer (all set to topk_tokens)
-        self.decode_cache_seqlens_buffer = torch.full(
-            (max_num_seqs,), self.topk_tokens, dtype=torch.int32, device=device
-        )
         self.req_id_per_token_buffer = torch.empty(
             (vllm_config.scheduler_config.max_num_batched_tokens,),
             dtype=torch.int32,
@@ -548,9 +544,7 @@ def _build_fp8_extra_metadata(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
 
-        # Per-request cache_seqlens: [topk_tokens] * num_decodes
-        decode_cache_seqlens = self.decode_cache_seqlens_buffer[:num_decodes]
-
+        decode_cache_seqlens = common_attn_metadata.seq_lens[:num_decodes]
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             cache_seqlens=decode_cache_seqlens,
             num_q_tokens_per_head_k=decode_query_len * self.num_heads,
```
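As context for the change above, here is a minimal sketch contrasting the `cache_seqlens` handed to `get_mla_metadata` before and after this commit. The tensor values and the standalone `topk_tokens`/`seq_lens`/`num_decodes` names are hypothetical, not the vLLM API:

```python
import torch

topk_tokens = 2048  # hypothetical indexer top-k
num_decodes = 3
# hypothetical per-request sequence lengths
seq_lens = torch.tensor([17, 900, 4096], dtype=torch.int32)

# before this commit: every decode request reported a constant topk_tokens
old_cache_seqlens = torch.full((num_decodes,), topk_tokens, dtype=torch.int32)

# after this commit: the actual per-request sequence lengths are used
new_cache_seqlens = seq_lens[:num_decodes]

print(old_cache_seqlens)  # tensor([2048, 2048, 2048], dtype=torch.int32)
print(new_cache_seqlens)  # tensor([  17,  900, 4096], dtype=torch.int32)
```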

vllm/v1/attention/backends/utils.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -856,11 +856,15 @@ def split_decodes_and_prefills(
         return 0, num_reqs, 0, num_tokens
 
     if require_uniform:
+        # check if we are in a padded uniform batch; this is used for full CGs:
+        # some requests may have a query length of 0, but since they are padding
+        # it's fine to treat them as decodes (ensures num_decodes matches the captured size)
+        if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
+            assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly"
+            return num_reqs, 0, num_tokens, 0  # all decodes
         is_prefill = query_lens != query_lens[0]
     else:
-        # 0-query len indicates a padded request; leave this at the back
-        # of the batch with the prefills
-        is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
+        is_prefill = query_lens > decode_threshold
 
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
```
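To see why the old classification broke padded full-cudagraph batches, here is a self-contained sketch of the two checks; the tensor values and `decode_threshold` are hypothetical:

```python
import torch

# hypothetical padded uniform batch captured for a full cudagraph:
# 3 real decode requests with query length 1 plus one zero-length padding slot
query_lens = torch.tensor([1, 1, 1, 0], dtype=torch.int32)
decode_threshold = 1  # hypothetical

# new check: the batch is uniform up to zero-length padding slots,
# so the whole batch (padding included) is treated as decodes
padded_uniform = torch.all((query_lens == query_lens[0]) | (query_lens == 0))
print(padded_uniform)  # tensor(True)

# old check (removed by this commit): zero-length padding requests were
# classified as prefills, so num_decodes no longer matched the captured size
old_is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(old_is_prefill)  # tensor([False, False, False,  True])
```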
