
Commit f854a81

full CG fix
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent fa6e1b9 commit f854a81


vllm/v1/attention/backends/mla/flashmla_sparse.py

Lines changed: 92 additions & 86 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional
 
 import numpy as np
@@ -127,60 +127,62 @@ class FlashMLASparseMetadataBF16:
 
 
 @dataclass
-class FlashMLASparseMetadataFP8(FlashMLASparseMetadataBF16):
-    num_prefills: int = 0
-    num_decodes: int = 0
-    num_prefill_tokens: int = 0
-    num_decode_tokens: int = 0
-
-    @dataclass
-    class DecodeMetadata:
-        scheduler_metadata: torch.Tensor | None
-        num_splits: torch.Tensor
-        dummy_block_table: torch.Tensor
-        cache_lens: torch.Tensor
-        decode_query_len: int  # needed for reshape in spec decode
-
+class FlashMLASparseMetadata(FlashMLASparseMetadataBF16):
     @dataclass
-    class PrefillMetadata:
-        # Sequence lengths (context + query) for prefill requests
-        # Shape: [num_prefill_reqs]
-        seq_lens: torch.Tensor
-
-        # Request ID for each token: -1 for decode tokens, request index
-        # (0, 1, 2, ...) for prefill tokens.
-        # Shape: [num_actual_tokens]
-        request_ids: torch.Tensor
-
-        # Workspace start offsets for all prefill requests
-        # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
-        # 0-indexed within each chunk. Used to map prefill tokens to workspace
-        # offsets in convert_logical_index_to_physical_index
-        workspace_starts: torch.Tensor
+    class FP8KernelMetadata:
+        @dataclass
+        class DecodeMetadata:
+            scheduler_metadata: torch.Tensor | None
+            num_splits: torch.Tensor
+            dummy_block_table: torch.Tensor
+            cache_lens: torch.Tensor
+            decode_query_len: int  # needed for reshape in spec decode
 
         @dataclass
-        class ChunkMetadata:
-            """Metadata for a chunk of prefill requests.
+        class PrefillMetadata:
+            # Sequence lengths (context + query) for prefill requests
+            # Shape: [num_prefill_reqs]
+            seq_lens: torch.Tensor
 
-            Prefill requests may be chunked to fit within the fixed workspace size.
-            """
+            # Request ID for each token: -1 for decode tokens, request index
+            # (0, 1, 2, ...) for prefill tokens.
+            # Shape: [num_actual_tokens]
+            request_ids: torch.Tensor
 
-            seq_lens: torch.Tensor
-            tokens_slice: slice
-            block_table: torch.Tensor
-            req_start_idx: int
+            # Workspace start offsets for all prefill requests
+            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
+            # 0-indexed within each chunk. Used to map prefill tokens to workspace
+            # offsets in convert_logical_index_to_physical_index
             workspace_starts: torch.Tensor
-            chunk_tot_seqlen: int
 
-        chunks: list[ChunkMetadata]
+            @dataclass
+            class ChunkMetadata:
+                """Metadata for a chunk of prefill requests.
 
-    decode: DecodeMetadata | None = None
-    prefill: PrefillMetadata | None = None
+                Prefill requests may be chunked to fit within the fixed workspace size.
+                """
 
+                seq_lens: torch.Tensor
+                tokens_slice: slice
+                block_table: torch.Tensor
+                req_start_idx: int
+                workspace_starts: torch.Tensor
+                chunk_tot_seqlen: int
 
-FlashMLASparseMetadata = FlashMLASparseMetadataBF16 | FlashMLASparseMetadataFP8
+            chunks: list[ChunkMetadata]
 
+        num_prefills: int = 0
+        num_decodes: int = 0
+        num_prefill_tokens: int = 0
+        num_decode_tokens: int = 0
 
+        decode: DecodeMetadata | None = None
+        prefill: PrefillMetadata | None = None
+
+    fp8_extra_metadata: FP8KernelMetadata | None = None
+
+
+# Kernel with prefill workspace support
 @triton.jit
 def _convert_req_index_to_global_index_kernel(
     req_id_ptr,  # int32 [num_tokens]
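
Note on the restructuring above: the old module-level union alias (FlashMLASparseMetadataBF16 | FlashMLASparseMetadataFP8) is gone; there is now a single FlashMLASparseMetadata class whose FP8-only state hangs off the optional fp8_extra_metadata field. A minimal consumer-side sketch, using only names visible in this diff (branch bodies are placeholders):

    def run_attention(attn_metadata: "FlashMLASparseMetadata") -> None:
        fp8 = attn_metadata.fp8_extra_metadata
        if fp8 is None:
            ...  # BF16 KV-cache path
        elif fp8.prefill is None:
            ...  # pure-decode FP8 path, driven by fp8.decode
        else:
            ...  # mixed / prefill FP8 path, iterates fp8.prefill.chunks

Keeping one static metadata type lets downstream code branch on an attribute instead of isinstance, so the set of live Python types no longer varies between BF16 and FP8 runs.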
@@ -380,14 +382,19 @@ def __init__(
 
         self.topk_tokens = vllm_config.model_config.hf_config.index_topk
         self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
-        self.topk_tokens_tensor = torch.tensor(
-            [self.topk_tokens], device=device, dtype=torch.int32
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
+        self.topk_tokens_tensor = torch.full(
+            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
         )
-        self.max_model_len_tensor = torch.tensor(
-            [self.model_config.max_model_len], device=device, dtype=torch.int32
+        # Shape: [max_num_seqs], all elements = max_model_len
+        self.max_model_len_tensor = torch.full(
+            (max_num_seqs,),
+            self.model_config.max_model_len,
+            device=device,
+            dtype=torch.int32,
         )
         # this is ignored by `flash_mla_with_kvcache` if indices not None
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
         self.dummy_block_table = torch.empty(
             (max_num_seqs, 1), dtype=torch.int32, device=self.device
         )
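
The switch from small per-build torch.tensor allocations to persistent torch.full buffers sized to max_num_seqs is the core of the full-CG (full CUDA graph) fix: a captured graph replays against the device pointers it saw at capture time, so tensors the kernels read must keep a stable address and a fixed worst-case shape. A standalone sketch of the pattern, with hypothetical names and sizes (not vLLM's actual builder):

    import torch

    class PersistentBuffers:
        def __init__(self, max_num_seqs: int, topk: int) -> None:
            # Allocate once at worst-case size; later steps reuse this
            # storage instead of creating a fresh tensor per build.
            self.topk_tokens_tensor = torch.full(
                (max_num_seqs,), topk, dtype=torch.int32
            )

        def for_batch(self, num_decodes: int) -> torch.Tensor:
            # A slice is a view into the same storage: no allocation and
            # the same base pointer, so it stays valid across replays.
            return self.topk_tokens_tensor[:num_decodes]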
@@ -407,7 +414,7 @@ def __init__(
             dtype=torch.int32,
             device=device,
         )
-        # Sized for per-request semantics (one entry per decode + 1)
+        # Sized for per-request batching (num_decodes + 1)
         self.num_splits_buffer = torch.empty(
             (max_num_seqs + 1,),
             dtype=torch.int32,
@@ -422,8 +429,7 @@ def __init__(
     def _build_fp8_extra_metadata(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        bf16_metadata: FlashMLASparseMetadataBF16,
-    ):
+    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
         num_tokens = common_attn_metadata.num_actual_tokens
 
         (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
@@ -434,8 +440,8 @@ def _build_fp8_extra_metadata(
             )
         )
 
-        fp8_metadata = FlashMLASparseMetadataFP8(
-            **asdict(bf16_metadata),
+        FP8Meta = FlashMLASparseMetadata.FP8KernelMetadata
+        fp8_metadata = FP8Meta(
             num_decodes=num_decodes,
             num_prefills=num_prefills,
             num_decode_tokens=num_decode_tokens,
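
Dropping the **asdict(bf16_metadata) re-packing also removes a subtle hazard: dataclasses.asdict recurses into nested dataclasses and deep-copies field values, so tensor fields get duplicated rather than shared. A standalone illustration (not vLLM code):

    import dataclasses

    import torch

    @dataclasses.dataclass
    class Meta:
        seq_lens: torch.Tensor

    m = Meta(seq_lens=torch.ones(4, dtype=torch.int32))
    d = dataclasses.asdict(m)
    # asdict deep-copies values: the dict holds a new tensor with its own
    # storage, not a view of m.seq_lens.
    assert d["seq_lens"].data_ptr() != m.seq_lens.data_ptr()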
@@ -518,7 +524,7 @@ def _build_fp8_extra_metadata(
                 ]
 
                 prefill_chunks.append(
-                    FlashMLASparseMetadataFP8.PrefillMetadata.ChunkMetadata(
+                    FP8Meta.PrefillMetadata.ChunkMetadata(
                         seq_lens=chunk_seq_lens,
                         tokens_slice=tokens_slice,
                         block_table=chunk_block_table,
@@ -532,22 +538,17 @@ def _build_fp8_extra_metadata(
                 prefill_workspace_starts_cpu, non_blocking=True
             )
 
-            fp8_metadata.prefill = FlashMLASparseMetadataFP8.PrefillMetadata(
+            fp8_metadata.prefill = FP8Meta.PrefillMetadata(
                 seq_lens=prefill_seq_lens,
                 request_ids=prefill_request_id,
                 workspace_starts=prefill_workspace_starts,
                 chunks=prefill_chunks,
             )
 
         if num_decodes > 0:
-            # Compute decode_query_len (uniform due to require_uniform=True)
-            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
-
-            decode_cache_seqlens = common_attn_metadata.seq_lens[:num_decodes]
             tile_scheduler_metadata, num_splits = get_mla_metadata(
-                cache_seqlens=decode_cache_seqlens,
-                num_q_tokens_per_head_k=decode_query_len * self.num_heads,
+                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
+                num_q_tokens_per_head_k=num_decode_tokens * self.num_heads,
                 topk=self.topk_tokens,
                 num_heads_q=self.num_heads,
                 num_heads_k=1,
@@ -560,14 +561,19 @@ def _build_fp8_extra_metadata(
                 :num_sm_parts
             ]
             tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-
+            # num_splits has size [num_decodes + 1]
             num_splits_view = self.num_splits_buffer[: num_decodes + 1]
             num_splits_view.copy_(num_splits)
 
-            fp8_metadata.decode = FlashMLASparseMetadataFP8.DecodeMetadata(
+            # Compute decode_query_len for spec decode (uniform due to require_uniform)
+            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
+
+            fp8_metadata.decode = FP8Meta.DecodeMetadata(
                 scheduler_metadata=tile_scheduler_metadata_buffer,
                 num_splits=num_splits_view,
-                cache_lens=decode_cache_seqlens,
+                # Per-request buffers with constant values for cudagraph compatibility
+                cache_lens=self.max_model_len_tensor[:num_decodes],
                 dummy_block_table=self.dummy_block_table[:num_decodes],
                 decode_query_len=decode_query_len,
             )
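
Two details in the rewritten decode branch: the per-step kernel metadata is still computed fresh, but it is published by copying into the preallocated tile_scheduler_metadata_buffer and num_splits_buffer rather than returned as new tensors, and cache_lens now slices the constant max_model_len_tensor instead of the true per-request lengths (per the in-diff comment, constant values for cudagraph compatibility). A standalone sketch of the copy-into-buffer pattern, with a made-up buffer size:

    import torch

    # Persistent buffer sized for the worst case (max_num_seqs + 1 = 9 here).
    num_splits_buffer = torch.empty(9, dtype=torch.int32)

    def publish_num_splits(num_splits: torch.Tensor, num_decodes: int) -> torch.Tensor:
        # View into the persistent buffer; copy_ writes in place, so a
        # captured CUDA graph keeps reading the same address on replay.
        view = num_splits_buffer[: num_decodes + 1]
        view.copy_(num_splits)
        return view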
@@ -593,7 +599,11 @@ def build(
         )
         req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
 
-        bf16_metadata = FlashMLASparseMetadataBF16(
+        fp8_metadata = None
+        if self.use_fp8_kv_cache:
+            fp8_metadata = self._build_fp8_extra_metadata(common_attn_metadata)
+
+        metadata = FlashMLASparseMetadata(
             num_reqs=common_attn_metadata.num_reqs,
             max_query_len=common_attn_metadata.max_query_len,
             max_seq_len=common_attn_metadata.max_seq_len,
@@ -604,12 +614,10 @@ def build(
             req_id_per_token=req_id_per_token,
             block_size=self.kv_cache_spec.block_size,
             topk_tokens=self.topk_tokens,
+            fp8_extra_metadata=fp8_metadata,
         )
 
-        if self.use_fp8_kv_cache:
-            return self._build_fp8_extra_metadata(common_attn_metadata, bf16_metadata)
-        else:
-            return bf16_metadata
+        return metadata
 
 
 class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
@@ -698,7 +706,7 @@ def _forward_fp8_kv_decode(
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         topk_indices: torch.Tensor,
-        metadata: FlashMLASparseMetadataFP8.DecodeMetadata,
+        metadata: FlashMLASparseMetadata.FP8KernelMetadata.DecodeMetadata,
     ) -> torch.Tensor:
         num_decodes = metadata.cache_lens.size(0)
 
@@ -777,15 +785,13 @@ def forward(
 
         use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
 
+        fp8_metadata = attn_metadata.fp8_extra_metadata
         prefill_request_ids = None
         prefill_workspace_starts = None
         has_prefill_workspace = False
-        if (
-            isinstance(attn_metadata, FlashMLASparseMetadataFP8)
-            and attn_metadata.prefill is not None
-        ):
-            prefill_request_ids = attn_metadata.prefill.request_ids
-            prefill_workspace_starts = attn_metadata.prefill.workspace_starts
+        if fp8_metadata is not None and fp8_metadata.prefill is not None:
+            prefill_request_ids = fp8_metadata.prefill.request_ids
+            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
             has_prefill_workspace = True
 
         # Convert per-request indices to global slots (decode) or workspace
@@ -824,17 +830,17 @@ def forward(
                 q, kv_cache, topk_indices_global, attn_metadata
             )
         else:
-            assert isinstance(attn_metadata, FlashMLASparseMetadataFP8)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            assert fp8_metadata is not None
+            num_prefill_tokens = fp8_metadata.num_prefill_tokens
             # Pure decode case: direct call without allocation
             if num_prefill_tokens == 0:
-                assert attn_metadata.decode is not None
+                assert fp8_metadata.decode is not None
                 attn_out = self._forward_fp8_kv_decode(
-                    q, kv_cache, topk_indices_global, attn_metadata.decode
+                    q, kv_cache, topk_indices_global, fp8_metadata.decode
                 )
             else:
-                assert attn_metadata.prefill is not None
-                num_decode_tokens = attn_metadata.num_decode_tokens
+                assert fp8_metadata.prefill is not None
+                num_decode_tokens = fp8_metadata.num_decode_tokens
 
                 # Mixed or pure prefill: allocate output tensor
                 attn_out = q.new_empty(
@@ -845,15 +851,15 @@ def forward(
 
                 # Fill decode portion if present
                 if num_decode_tokens > 0:
-                    assert attn_metadata.decode is not None
+                    assert fp8_metadata.decode is not None
                     attn_out[:num_decode_tokens] = self._forward_fp8_kv_decode(
                         q[:num_decode_tokens],
                         kv_cache,
                         topk_indices_global[:num_decode_tokens],
-                        attn_metadata.decode,
+                        fp8_metadata.decode,
                     )
 
-                for chunk in attn_metadata.prefill.chunks:
+                for chunk in fp8_metadata.prefill.chunks:
                     chunk_workspace = self.prefill_bf16_workspace[
                         : chunk.chunk_tot_seqlen
                     ]
