Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions tests/lora/test_whisper_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for Whisper Multi-LoRA support.

This module tests:
1. WhisperForConditionalGeneration LoRA interface compliance
2. MergedColumnParallelLinearWithLoRA support for KV (2-slice) configuration
3. WorkerLoRAManager compatibility with Whisper's max_target_positions
"""

import pytest
import torch

from vllm.lora.layers import (
MergedColumnParallelLinearWithLoRA,
)
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.models.whisper import WhisperForConditionalGeneration
from vllm.platforms import current_platform

pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
reason="Backend not supported",
)


class TestWhisperLoRAInterface:
"""Test that WhisperForConditionalGeneration has proper LoRA support."""

def test_supports_lora_attribute(self):
"""Verify that WhisperForConditionalGeneration has SupportsLoRA interface."""
from vllm.model_executor.models.interfaces import SupportsLoRA

assert issubclass(WhisperForConditionalGeneration, SupportsLoRA), (
"WhisperForConditionalGeneration should inherit from SupportsLoRA"
)

def test_packed_modules_mapping_format(self):
"""Verify packed_modules_mapping has correct format for LoRA."""
mapping = WhisperForConditionalGeneration.packed_modules_mapping

# Should have qkv_proj and kv_proj mappings
assert "qkv_proj" in mapping, "Missing qkv_proj in packed_modules_mapping"
assert "kv_proj" in mapping, "Missing kv_proj in packed_modules_mapping"

# qkv_proj should map to [q_proj, k_proj, v_proj]
assert mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]

# kv_proj should map to [k_proj, v_proj] (for cross-attention)
assert mapping["kv_proj"] == ["k_proj", "v_proj"]


class TestMergedColumnParallelLinearWithLoRAKVOnly:
"""Test MergedColumnParallelLinearWithLoRA with KV (2-slice) configuration."""

def test_can_replace_layer_accepts_2_modules(self):
"""Verify can_replace_layer accepts 2-module (KV) configurations."""
from vllm.config.lora import LoRAConfig

# Create a MergedColumnParallelLinear layer
# This simulates a KV projection (like Whisper's encoder_attn.kv_proj)
linear = MergedColumnParallelLinear(
input_size=512,
output_sizes=[512, 512], # K and V projections
bias=False,
params_dtype=torch.float16,
)

lora_config = LoRAConfig(
max_lora_rank=32,
max_loras=4,
max_cpu_loras=4,
lora_extra_vocab_size=0,
)

# Test with 2 modules (KV, like encoder_attn.kv_proj)
packed_modules_2 = ["k_proj", "v_proj"]
result_2 = MergedColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=linear,
lora_config=lora_config,
packed_modules_list=packed_modules_2,
model_config=None,
)
assert result_2 is True, "Should accept 2-module (KV) configuration"

# Test with 1 module (should be rejected for MergedColumnParallelLinear)
packed_modules_1 = ["k_proj"]
result_1 = MergedColumnParallelLinearWithLoRA.can_replace_layer(
source_layer=linear,
lora_config=lora_config,
packed_modules_list=packed_modules_1,
model_config=None,
)
assert result_1 is False, "Should reject 1-module configuration"


class TestWorkerLoRAManagerWhisperCompat:
"""Test WorkerLoRAManager compatibility with Whisper config."""

def test_max_position_embeddings_fallback(self):
"""Test that max_target_positions is used when missing."""

# Create a mock config similar to Whisper's
class MockWhisperConfig:
def __init__(self):
self.max_target_positions = 448
# Note: no max_position_embeddings attribute

def get_text_config(self):
return self

config = MockWhisperConfig()

# Simulate the logic from WorkerLoRAManager
max_pos = getattr(
config,
"max_position_embeddings",
getattr(config, "max_target_positions", None),
)

assert max_pos == 448, "Should fall back to max_target_positions"

def test_max_position_embeddings_priority(self):
"""Test that max_position_embeddings takes priority when present."""

class MockLLMConfig:
def __init__(self):
self.max_position_embeddings = 4096
self.max_target_positions = 448

def get_text_config(self):
return self

config = MockLLMConfig()

# Simulate the logic from WorkerLoRAManager
max_pos = getattr(
config,
"max_position_embeddings",
getattr(config, "max_target_positions", None),
)

assert max_pos == 4096, "Should use max_position_embeddings when present"
7 changes: 6 additions & 1 deletion vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def __init__(
# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()

self.max_position_embeddings = text_config.max_position_embeddings
# Whisper uses max_target_positions instead of max_position_embeddings
self.max_position_embeddings = getattr(
text_config,
"max_position_embeddings",
getattr(text_config, "max_target_positions", None),
)
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
Expand Down
35 changes: 20 additions & 15 deletions vllm/model_executor/models/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
Expand All @@ -53,7 +54,12 @@
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsTranscription,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
Expand Down Expand Up @@ -318,11 +324,12 @@ def _init_qkv(
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_dim,
total_num_heads=0,
total_num_kv_heads=self.total_num_heads,
# Use MergedColumnParallelLinear for K and V projections.
# This enables LoRA support via MergedColumnParallelLinearWithLoRA
# which handles 2-slice configurations.
self.kv_proj = MergedColumnParallelLinear(
input_size=embed_dim,
output_sizes=[embed_dim, embed_dim],
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.kv_proj",
Expand Down Expand Up @@ -626,8 +633,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
# MergedColumnParallelLinear uses integer indices (0, 1)
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
Expand Down Expand Up @@ -773,15 +781,12 @@ def _get_prompt_updates(
dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
nn.Module, SupportsTranscription, SupportsMultiModal, SupportsLoRA
):
# LoRA-specific attributes
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"kv_proj": ["k_proj", "v_proj"],
}

hf_to_vllm_mapper = WeightsMapper(
Expand Down