@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional
 
 import numpy as np
@@ -127,60 +127,62 @@ class FlashMLASparseMetadataBF16:
 
 
 @dataclass
-class FlashMLASparseMetadataFP8(FlashMLASparseMetadataBF16):
-    num_prefills: int = 0
-    num_decodes: int = 0
-    num_prefill_tokens: int = 0
-    num_decode_tokens: int = 0
-
-    @dataclass
-    class DecodeMetadata:
-        scheduler_metadata: torch.Tensor | None
-        num_splits: torch.Tensor
-        dummy_block_table: torch.Tensor
-        cache_lens: torch.Tensor
-        decode_query_len: int  # needed for reshape in spec decode
-
+class FlashMLASparseMetadata(FlashMLASparseMetadataBF16):
     @dataclass
-    class PrefillMetadata:
-        # Sequence lengths (context + query) for prefill requests
-        # Shape: [num_prefill_reqs]
-        seq_lens: torch.Tensor
-
-        # Request ID for each token: -1 for decode tokens, request index
-        # (0, 1, 2, ...) for prefill tokens.
-        # Shape: [num_actual_tokens]
-        request_ids: torch.Tensor
-
-        # Workspace start offsets for all prefill requests
-        # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
-        # 0-indexed within each chunk. Used to map prefill tokens to workspace
-        # offsets in convert_logical_index_to_physical_index
-        workspace_starts: torch.Tensor
+    class FP8KernelMetadata:
+        @dataclass
+        class DecodeMetadata:
+            scheduler_metadata: torch.Tensor | None
+            num_splits: torch.Tensor
+            dummy_block_table: torch.Tensor
+            cache_lens: torch.Tensor
+            decode_query_len: int  # needed for reshape in spec decode
 
         @dataclass
-        class ChunkMetadata:
-            """Metadata for a chunk of prefill requests.
+        class PrefillMetadata:
+            # Sequence lengths (context + query) for prefill requests
+            # Shape: [num_prefill_reqs]
+            seq_lens: torch.Tensor
 
-            Prefill requests may be chunked to fit within the fixed workspace size.
-            """
+            # Request ID for each token: -1 for decode tokens, request index
+            # (0, 1, 2, ...) for prefill tokens.
+            # Shape: [num_actual_tokens]
+            request_ids: torch.Tensor
 
-            seq_lens: torch.Tensor
-            tokens_slice: slice
-            block_table: torch.Tensor
-            req_start_idx: int
+            # Workspace start offsets for all prefill requests
+            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
+            # 0-indexed within each chunk. Used to map prefill tokens to workspace
+            # offsets in convert_logical_index_to_physical_index
             workspace_starts: torch.Tensor
-            chunk_tot_seqlen: int
 
-        chunks: list[ChunkMetadata]
+            @dataclass
+            class ChunkMetadata:
+                """Metadata for a chunk of prefill requests.
 
-    decode: DecodeMetadata | None = None
-    prefill: PrefillMetadata | None = None
+                Prefill requests may be chunked to fit within the fixed workspace size.
+                """
 
+                seq_lens: torch.Tensor
+                tokens_slice: slice
+                block_table: torch.Tensor
+                req_start_idx: int
+                workspace_starts: torch.Tensor
+                chunk_tot_seqlen: int
 
-FlashMLASparseMetadata = FlashMLASparseMetadataBF16 | FlashMLASparseMetadataFP8
+            chunks: list[ChunkMetadata]
 
+        num_prefills: int = 0
+        num_decodes: int = 0
+        num_prefill_tokens: int = 0
+        num_decode_tokens: int = 0
 
+        decode: DecodeMetadata | None = None
+        prefill: PrefillMetadata | None = None
+
+    fp8_extra_metadata: FP8KernelMetadata | None = None
+
+
+# Kernel with prefill workspace support
 @triton.jit
 def _convert_req_index_to_global_index_kernel(
     req_id_ptr,  # int32 [num_tokens]
@@ -380,14 +382,19 @@ def __init__(
 
         self.topk_tokens = vllm_config.model_config.hf_config.index_topk
         self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
-        self.topk_tokens_tensor = torch.tensor(
-            [self.topk_tokens], device=device, dtype=torch.int32
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
+        self.topk_tokens_tensor = torch.full(
+            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
         )
-        self.max_model_len_tensor = torch.tensor(
-            [self.model_config.max_model_len], device=device, dtype=torch.int32
+        # Shape: [max_num_seqs], all elements = max_model_len
+        self.max_model_len_tensor = torch.full(
+            (max_num_seqs,),
+            self.model_config.max_model_len,
+            device=device,
+            dtype=torch.int32,
         )
         # this is ignored by `flash_mla_with_kvcache` if indices not None
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
         self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
         )
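Not part of the patch, but a minimal sketch of the buffer pattern the two `torch.full` allocations above adopt: persistent, constant-valued tensors sized for `max_num_seqs` that per-step code only slices. Names mirror the diff; the example sizes and the `per_step_views` helper are hypothetical.

    import torch

    # Example sizes; in the patch these come from vllm_config and the model config.
    max_num_seqs, topk_tokens, max_model_len = 8, 2048, 32768
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Written once at init; contents never change afterwards.
    topk_tokens_tensor = torch.full(
        (max_num_seqs,), topk_tokens, device=device, dtype=torch.int32
    )
    max_model_len_tensor = torch.full(
        (max_num_seqs,), max_model_len, device=device, dtype=torch.int32
    )

    def per_step_views(num_decodes: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Slicing returns views of the persistent buffers, so no allocation or
        # host-to-device copy is needed once a CUDA graph has been captured.
        return topk_tokens_tensor[:num_decodes], max_model_len_tensor[:num_decodes]

Because the values never change, the sliced views stay valid across graph replays, which is presumably what the "constant for full-CG" comment in the diff refers to.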
@@ -407,7 +414,7 @@ def __init__(
             dtype=torch.int32,
             device=device,
         )
-        # Sized for per-request semantics (one entry per decode + 1)
+        # Sized for per-request batching (num_decodes + 1)
         self.num_splits_buffer = torch.empty(
             (max_num_seqs + 1,),
             dtype=torch.int32,
@@ -422,8 +429,7 @@ def __init__(
     def _build_fp8_extra_metadata(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        bf16_metadata: FlashMLASparseMetadataBF16,
-    ):
+    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
         num_tokens = common_attn_metadata.num_actual_tokens
 
         (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
@@ -434,8 +440,8 @@ def _build_fp8_extra_metadata(
             )
         )
 
-        fp8_metadata = FlashMLASparseMetadataFP8(
-            **asdict(bf16_metadata),
+        FP8Meta = FlashMLASparseMetadata.FP8KernelMetadata
+        fp8_metadata = FP8Meta(
             num_decodes=num_decodes,
             num_prefills=num_prefills,
             num_decode_tokens=num_decode_tokens,
@@ -518,7 +524,7 @@ def _build_fp8_extra_metadata(
             ]
 
             prefill_chunks.append(
-                FlashMLASparseMetadataFP8.PrefillMetadata.ChunkMetadata(
+                FP8Meta.PrefillMetadata.ChunkMetadata(
                     seq_lens=chunk_seq_lens,
                     tokens_slice=tokens_slice,
                     block_table=chunk_block_table,
@@ -532,22 +538,17 @@ def _build_fp8_extra_metadata(
                 prefill_workspace_starts_cpu, non_blocking=True
             )
 
-            fp8_metadata.prefill = FlashMLASparseMetadataFP8.PrefillMetadata(
+            fp8_metadata.prefill = FP8Meta.PrefillMetadata(
                 seq_lens=prefill_seq_lens,
                 request_ids=prefill_request_id,
                 workspace_starts=prefill_workspace_starts,
                 chunks=prefill_chunks,
             )
 
         if num_decodes > 0:
-            # Compute decode_query_len (uniform due to require_uniform=True)
-            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
-
-            decode_cache_seqlens = common_attn_metadata.seq_lens[:num_decodes]
             tile_scheduler_metadata, num_splits = get_mla_metadata(
-                cache_seqlens=decode_cache_seqlens,
-                num_q_tokens_per_head_k=decode_query_len * self.num_heads,
+                cache_seqlens=self.topk_tokens_tensor[:num_decodes],
+                num_q_tokens_per_head_k=num_decode_tokens * self.num_heads,
                 topk=self.topk_tokens,
                 num_heads_q=self.num_heads,
                 num_heads_k=1,
@@ -560,14 +561,19 @@ def _build_fp8_extra_metadata(
                 :num_sm_parts
             ]
             tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-
+            # num_splits has size [num_decodes + 1]
             num_splits_view = self.num_splits_buffer[: num_decodes + 1]
             num_splits_view.copy_(num_splits)
 
-            fp8_metadata.decode = FlashMLASparseMetadataFP8.DecodeMetadata(
+            # Compute decode_query_len for spec decode (uniform due to require_uniform)
+            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
+
+            fp8_metadata.decode = FP8Meta.DecodeMetadata(
                 scheduler_metadata=tile_scheduler_metadata_buffer,
                 num_splits=num_splits_view,
-                cache_lens=decode_cache_seqlens,
+                # Per-request buffers with constant values for cudagraph compatibility
+                cache_lens=self.max_model_len_tensor[:num_decodes],
                 dummy_block_table=self.dummy_block_table[:num_decodes],
                 decode_query_len=decode_query_len,
             )
@@ -593,7 +599,11 @@ def build(
         )
         req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
 
-        bf16_metadata = FlashMLASparseMetadataBF16(
+        fp8_metadata = None
+        if self.use_fp8_kv_cache:
+            fp8_metadata = self._build_fp8_extra_metadata(common_attn_metadata)
+
+        metadata = FlashMLASparseMetadata(
             num_reqs=common_attn_metadata.num_reqs,
             max_query_len=common_attn_metadata.max_query_len,
             max_seq_len=common_attn_metadata.max_seq_len,
@@ -604,12 +614,10 @@ def build(
             req_id_per_token=req_id_per_token,
             block_size=self.kv_cache_spec.block_size,
             topk_tokens=self.topk_tokens,
+            fp8_extra_metadata=fp8_metadata,
         )
 
-        if self.use_fp8_kv_cache:
-            return self._build_fp8_extra_metadata(common_attn_metadata, bf16_metadata)
-        else:
-            return bf16_metadata
+        return metadata
 
 
 class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
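Not part of the patch: a schematic consumer, mirroring the `forward()` changes further below, that shows how the single `FlashMLASparseMetadata` type with an optional `fp8_extra_metadata` field replaces the old `isinstance` check against `FlashMLASparseMetadataFP8`. Field names come from the diff; the `dispatch` function itself is hypothetical.

    def dispatch(attn_metadata) -> str:
        # Hypothetical illustration only; attribute names follow the diff above.
        fp8 = attn_metadata.fp8_extra_metadata
        if fp8 is None:
            return "bf16 kv-cache path"     # only the shared BF16 fields are used
        if fp8.num_prefill_tokens == 0:
            assert fp8.decode is not None
            return "fp8 pure-decode path"   # uses fp8.decode scheduler metadata
        assert fp8.prefill is not None
        return "fp8 prefill/mixed path"     # walks fp8.prefill.chunks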
@@ -698,7 +706,7 @@ def _forward_fp8_kv_decode(
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         topk_indices: torch.Tensor,
-        metadata: FlashMLASparseMetadataFP8.DecodeMetadata,
+        metadata: FlashMLASparseMetadata.FP8KernelMetadata.DecodeMetadata,
     ) -> torch.Tensor:
         num_decodes = metadata.cache_lens.size(0)
 
@@ -777,15 +785,13 @@ def forward(
 
         use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
 
+        fp8_metadata = attn_metadata.fp8_extra_metadata
         prefill_request_ids = None
         prefill_workspace_starts = None
         has_prefill_workspace = False
-        if (
-            isinstance(attn_metadata, FlashMLASparseMetadataFP8)
-            and attn_metadata.prefill is not None
-        ):
-            prefill_request_ids = attn_metadata.prefill.request_ids
-            prefill_workspace_starts = attn_metadata.prefill.workspace_starts
+        if fp8_metadata is not None and fp8_metadata.prefill is not None:
+            prefill_request_ids = fp8_metadata.prefill.request_ids
+            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
             has_prefill_workspace = True
 
         # Convert per-request indices to global slots (decode) or workspace
@@ -824,17 +830,17 @@ def forward(
                 q, kv_cache, topk_indices_global, attn_metadata
             )
         else:
-            assert isinstance(attn_metadata, FlashMLASparseMetadataFP8)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            assert fp8_metadata is not None
+            num_prefill_tokens = fp8_metadata.num_prefill_tokens
             # Pure decode case: direct call without allocation
             if num_prefill_tokens == 0:
-                assert attn_metadata.decode is not None
+                assert fp8_metadata.decode is not None
                 attn_out = self._forward_fp8_kv_decode(
-                    q, kv_cache, topk_indices_global, attn_metadata.decode
+                    q, kv_cache, topk_indices_global, fp8_metadata.decode
                 )
             else:
-                assert attn_metadata.prefill is not None
-                num_decode_tokens = attn_metadata.num_decode_tokens
+                assert fp8_metadata.prefill is not None
+                num_decode_tokens = fp8_metadata.num_decode_tokens
 
                 # Mixed or pure prefill: allocate output tensor
                 attn_out = q.new_empty(
@@ -845,15 +851,15 @@ def forward(
 
                 # Fill decode portion if present
                 if num_decode_tokens > 0:
-                    assert attn_metadata.decode is not None
+                    assert fp8_metadata.decode is not None
                     attn_out[:num_decode_tokens] = self._forward_fp8_kv_decode(
                         q[:num_decode_tokens],
                         kv_cache,
                         topk_indices_global[:num_decode_tokens],
-                        attn_metadata.decode,
+                        fp8_metadata.decode,
                     )
 
-                for chunk in attn_metadata.prefill.chunks:
+                for chunk in fp8_metadata.prefill.chunks:
                     chunk_workspace = self.prefill_bf16_workspace[
                         : chunk.chunk_tot_seqlen
                     ]
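The diff is truncated inside the chunk loop above. Not part of the patch: a self-contained sketch of the chunking idea that loop and `ChunkMetadata` encode, namely mapping slices of prefill tokens onto a fixed-size workspace. The `Chunk` dataclass and `run_prefill_chunks` function are simplified stand-ins, not vLLM APIs.

    from dataclasses import dataclass

    import torch

    @dataclass
    class Chunk:  # simplified stand-in for ChunkMetadata
        tokens_slice: slice
        chunk_tot_seqlen: int

    def run_prefill_chunks(
        q: torch.Tensor, chunks: list[Chunk], workspace: torch.Tensor
    ) -> torch.Tensor:
        attn_out = torch.empty_like(q)
        for chunk in chunks:
            # Only the first chunk_tot_seqlen rows of the shared workspace are
            # needed for this chunk's KV entries.
            chunk_ws = workspace[: chunk.chunk_tot_seqlen]
            assert chunk_ws.size(0) == chunk.chunk_tot_seqlen
            # Placeholder for the real BF16 prefill kernel: copy the query slice
            # so the example runs end to end.
            attn_out[chunk.tokens_slice] = q[chunk.tokens_slice]
        return attn_out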