Commit 22c6415

refactor: Use MergedColumnParallelLinear for Whisper cross-attention kv_proj
Address maintainer feedback:

- Replace QKVParallelLinear with MergedColumnParallelLinear for kv_proj in WhisperCrossAttention, enabling LoRA support via the existing MergedColumnParallelLinearWithLoRA infrastructure
- Update weight loading to use integer shard indices (0, 1) instead of string identifiers ("k", "v") for MergedColumnParallelLinear
- Remove the redundant embedding_modules and embedding_padding_modules attributes from WhisperForConditionalGeneration
- Remove the example file (similar to the existing multilora_inference.py)
- Roll back the LoRA layer changes, as they are no longer needed
- Update tests to reflect the new architecture

Signed-off-by: daje0601 <englishmt4118@gmail.com>
1 parent ba3826b commit 22c6415
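In short, the cross-attention kv_proj becomes one merged column-parallel projection whose output is split back into K and V. Below is a minimal sketch of the equivalent computation at tp_size=1 in plain PyTorch; `embed_dim`, the tensor shapes, and the `nn.Linear` stand-in are illustrative assumptions, not vLLM's layer, which adds tensor-parallel sharding, quantization, and the LoRA hooks on top.

```python
import torch
import torch.nn as nn

embed_dim = 512

# output_sizes=[embed_dim, embed_dim] corresponds to one weight of shape
# (2 * embed_dim, embed_dim): K rows first, then V rows.
kv_proj = nn.Linear(embed_dim, 2 * embed_dim, bias=False)

encoder_hidden = torch.randn(4, 100, embed_dim)  # (batch, seq, hidden)
kv = kv_proj(encoder_hidden)                     # (batch, seq, 2 * hidden)
k, v = kv.split([embed_dim, embed_dim], dim=-1)  # slice 0 = K, slice 1 = V
assert k.shape == v.shape == encoder_hidden.shape
```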

File tree

examples/offline_inference/whisper_multilora_inference.py
tests/lora/test_whisper_lora.py
vllm/lora/layers/column_parallel_linear.py
vllm/model_executor/models/whisper.py

4 files changed: +55 −222 lines changed


examples/offline_inference/whisper_multilora_inference.py

Lines changed: 0 additions & 136 deletions
This file was deleted.

tests/lora/test_whisper_lora.py

Lines changed: 17 additions & 41 deletions
```diff
@@ -5,17 +5,17 @@
 
 This module tests:
 1. WhisperForConditionalGeneration LoRA interface compliance
-2. MergedQKVParallelLinearWithLoRA support for KV-only (2-slice) configuration
+2. MergedColumnParallelLinearWithLoRA support for KV (2-slice) configuration
 3. WorkerLoRAManager compatibility with Whisper's max_target_positions
 """
 
 import pytest
 import torch
 
 from vllm.lora.layers import (
-    MergedQKVParallelLinearWithLoRA,
+    MergedColumnParallelLinearWithLoRA,
 )
-from vllm.model_executor.layers.linear import QKVParallelLinear
+from vllm.model_executor.layers.linear import MergedColumnParallelLinear
 from vllm.model_executor.models.whisper import WhisperForConditionalGeneration
 from vllm.platforms import current_platform
 
@@ -36,18 +36,6 @@ def test_supports_lora_attribute(self):
             "WhisperForConditionalGeneration should inherit from SupportsLoRA"
         )
 
-    def test_embedding_modules_defined(self):
-        """Verify embedding_modules attribute is defined."""
-        assert hasattr(WhisperForConditionalGeneration, "embedding_modules")
-        assert isinstance(WhisperForConditionalGeneration.embedding_modules, dict)
-
-    def test_embedding_padding_modules_defined(self):
-        """Verify embedding_padding_modules attribute is defined."""
-        assert hasattr(WhisperForConditionalGeneration, "embedding_padding_modules")
-        assert isinstance(
-            WhisperForConditionalGeneration.embedding_padding_modules, list
-        )
-
     def test_packed_modules_mapping_format(self):
         """Verify packed_modules_mapping has correct format for LoRA."""
         mapping = WhisperForConditionalGeneration.packed_modules_mapping
@@ -63,20 +51,18 @@ def test_packed_modules_mapping_format(self):
         assert mapping["kv_proj"] == ["k_proj", "v_proj"]
 
 
-class TestMergedQKVParallelLinearWithLoRAKVOnly:
-    """Test MergedQKVParallelLinearWithLoRA with KV-only (2-slice) configuration."""
+class TestMergedColumnParallelLinearWithLoRAKVOnly:
+    """Test MergedColumnParallelLinearWithLoRA with KV (2-slice) configuration."""
 
     def test_can_replace_layer_accepts_2_modules(self):
-        """Verify can_replace_layer accepts 2-module (KV-only) configurations."""
+        """Verify can_replace_layer accepts 2-module (KV) configurations."""
         from vllm.config.lora import LoRAConfig
 
-        # Create a mock QKVParallelLinear layer
-        # This simulates a KV-only projection (like Whisper's encoder_attn.kv_proj)
-        linear = QKVParallelLinear(
-            hidden_size=512,
-            head_size=64,
-            total_num_heads=8,
-            total_num_kv_heads=8,
+        # Create a MergedColumnParallelLinear layer
+        # This simulates a KV projection (like Whisper's encoder_attn.kv_proj)
+        linear = MergedColumnParallelLinear(
+            input_size=512,
+            output_sizes=[512, 512],  # K and V projections
             bias=False,
            params_dtype=torch.float16,
         )
@@ -88,29 +74,19 @@ def test_can_replace_layer_accepts_2_modules(self):
            lora_extra_vocab_size=0,
        )
 
-        # Test with 2 modules (KV-only, like encoder_attn.kv_proj)
+        # Test with 2 modules (KV, like encoder_attn.kv_proj)
         packed_modules_2 = ["k_proj", "v_proj"]
-        result_2 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
+        result_2 = MergedColumnParallelLinearWithLoRA.can_replace_layer(
             source_layer=linear,
             lora_config=lora_config,
             packed_modules_list=packed_modules_2,
             model_config=None,
         )
-        assert result_2 is True, "Should accept 2-module (KV-only) configuration"
-
-        # Test with 3 modules (QKV, like self_attn.qkv_proj)
-        packed_modules_3 = ["q_proj", "k_proj", "v_proj"]
-        result_3 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
-            source_layer=linear,
-            lora_config=lora_config,
-            packed_modules_list=packed_modules_3,
-            model_config=None,
-        )
-        assert result_3 is True, "Should accept 3-module (QKV) configuration"
+        assert result_2 is True, "Should accept 2-module (KV) configuration"
 
-        # Test with 1 module (should be rejected)
-        packed_modules_1 = ["q_proj"]
-        result_1 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
+        # Test with 1 module (should be rejected for MergedColumnParallelLinear)
+        packed_modules_1 = ["k_proj"]
+        result_1 = MergedColumnParallelLinearWithLoRA.can_replace_layer(
             source_layer=linear,
             lora_config=lora_config,
             packed_modules_list=packed_modules_1,
```
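The `packed_modules_mapping` entry exercised above, `"kv_proj": ["k_proj", "v_proj"]`, is what lets an adapter trained against the separate k_proj/v_proj modules target the merged layer: sub-module i of the mapping feeds slice i of the merged LoRA. Here is a hedged sketch of that grouping step; the adapter dict and its key format are invented for illustration, and vLLM's own packing machinery does the real bookkeeping.

```python
import torch

packed_modules_mapping = {"kv_proj": ["k_proj", "v_proj"]}

# Pretend adapter tensors keyed by the original (unmerged) module names.
rank, hidden = 8, 512
adapter = {
    "encoder_attn.k_proj": torch.randn(rank, hidden),  # lora_a for K
    "encoder_attn.v_proj": torch.randn(rank, hidden),  # lora_a for V
}

# Slice i of the merged kv_proj LoRA comes from sub-module i of the mapping.
lora_a_slices = [
    adapter[f"encoder_attn.{sub}"] for sub in packed_modules_mapping["kv_proj"]
]
assert len(lora_a_slices) == 2  # matches the layer's two output slices
```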

vllm/lora/layers/column_parallel_linear.py

Lines changed: 28 additions & 34 deletions
```diff
@@ -356,6 +356,8 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
 
     def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
+        # There are three LoRA layer.
+        self.n_slices = len(self.base_layer.output_sizes)
 
         self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
         self.kv_proj_shard_size = (
@@ -364,23 +366,16 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
         self.q_shard_id = self.tp_rank
         self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
 
-        # Build output_slices and output_ids dynamically to support both
-        # QKV (3 slices) and KV-only (2 slices) configurations.
-        # KV-only is used in cross-attention layers (e.g., Whisper encoder_attn).
-        slices = []
-        ids = []
-        if self.q_proj_shard_size > 0:
-            slices.append(self.q_proj_shard_size)
-            ids.append(self.q_shard_id)
-        if self.kv_proj_shard_size > 0:
-            slices.append(self.kv_proj_shard_size)
-            ids.append(self.kv_shard_id)
-            slices.append(self.kv_proj_shard_size)
-            ids.append(self.kv_shard_id)
-
-        self.output_slices = tuple(slices)
-        self.output_ids = tuple(ids)
-        self.n_slices = len(self.output_slices)
+        self.output_slices = (
+            self.q_proj_shard_size,
+            self.kv_proj_shard_size,
+            self.kv_proj_shard_size,
+        )
+        self.output_ids = (
+            self.q_shard_id,
+            self.kv_shard_id,
+            self.kv_shard_id,
+        )
 
     def create_lora_weights(
         self,
@@ -403,11 +398,7 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        # Support both QKV (3 modules) and KV-only (2 modules) configurations
-        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) in (
-            2,
-            3,
-        )
+        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
 
 
 # These following layers are based on the tensor parallelism strategy given in
@@ -548,18 +539,21 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
     def slice_lora_a(
         self, lora_a: list[torch.Tensor | None]
     ) -> list[torch.Tensor | None]:
-        # NOTE: lora_a contains n_slices subloras, and each sublora could be None.
-        # n_slices is 3 for QKV and 2 for KV-only configurations.
-        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(self.n_slices)]
-        start_idx = [self.tp_rank * shard_size[i] for i in range(self.n_slices)]
-        result: list[torch.Tensor | None] = []
-        for i in range(self.n_slices):
-            lora_a_i = lora_a[i]
-            if lora_a_i is not None:
-                result.append(lora_a_i[start_idx[i] : start_idx[i] + shard_size[i], :])
-            else:
-                result.append(None)
-        return result
+        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
+        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
+        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
+        lora_a = [
+            lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
+            if lora_a[0] is not None
+            else None,
+            lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
+            if lora_a[1] is not None
+            else None,
+            lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
+            if lora_a[2] is not None
+            else None,
+        ]
+        return lora_a
 
     def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
         return _mcp_apply(x, bias, self)
```
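The restored `__init__` above hard-codes three slices (Q, K, V). To make the shard arithmetic concrete, here is a self-contained sketch with illustrative numbers (8 query heads, 2 KV heads, head_size 64, tp_size 4); the variable names mirror the fields in the diff, but this is standalone arithmetic, not vLLM code.

```python
total_num_heads, total_num_kv_heads, head_size, tp_size = 8, 2, 64, 4

for tp_rank in range(tp_size):
    num_heads = total_num_heads // tp_size            # query heads per rank
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)

    q_proj_shard_size = num_heads * head_size
    kv_proj_shard_size = num_kv_heads * head_size
    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas     # ranks sharing a KV head

    output_slices = (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)
    output_ids = (q_shard_id, kv_shard_id, kv_shard_id)
    print(tp_rank, output_slices, output_ids)
# rank 0 -> (128, 64, 64) (0, 0, 0); rank 2 -> (128, 64, 64) (2, 1, 1); ...
```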

vllm/model_executor/models/whisper.py

Lines changed: 10 additions & 11 deletions
```diff
@@ -27,6 +27,7 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
@@ -323,11 +324,12 @@ def _init_qkv(
             quant_config=quant_config,
             prefix=f"{prefix}.q_proj",
         )
-        self.kv_proj = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.head_dim,
-            total_num_heads=0,
-            total_num_kv_heads=self.total_num_heads,
+        # Use MergedColumnParallelLinear for K and V projections.
+        # This enables LoRA support via MergedColumnParallelLinearWithLoRA
+        # which handles 2-slice configurations.
+        self.kv_proj = MergedColumnParallelLinear(
+            input_size=embed_dim,
+            output_sizes=[embed_dim, embed_dim],
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_proj",
@@ -631,8 +633,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
             (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
             (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
-            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
-            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
+            # MergedColumnParallelLinear uses integer indices (0, 1)
+            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
+            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -781,10 +784,6 @@ class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal, SupportsLoRA
 ):
     # LoRA-specific attributes
-    embedding_modules = {}
-    embedding_padding_modules: list[str] = []
-
-    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "kv_proj": ["k_proj", "v_proj"],
```
