5 changes: 4 additions & 1 deletion tests/basic_correctness/test_basic_correctness.py
@@ -13,12 +13,15 @@
import torch

from vllm import LLM
+from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine

from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test

+ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
+
MODELS = [
"hmellor/tiny-random-Gemma2ForCausalLM",
"meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("backend", ATTN_BACKEND)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("async_scheduling", [True, False])
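The change above switches the parametrized backend from a hard-coded "FLASH_ATTN" to a module-level list chosen from the current platform. A minimal, self-contained sketch of that pattern (the test body here is a placeholder, not the real correctness test):

```python
import pytest

from vllm.platforms import current_platform

# Choose the attention backend at collection time, mirroring the
# ATTN_BACKEND constant added to test_basic_correctness.py.
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]


@pytest.mark.parametrize("backend", ATTN_BACKEND)
def test_backend_matches_platform(backend: str) -> None:
    # Placeholder assertion: the parametrized value matches the host platform.
    if current_platform.is_rocm():
        assert backend == "ROCM_ATTN"
    else:
        assert backend == "FLASH_ATTN"
```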
4 changes: 2 additions & 2 deletions tests/utils.py
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
try:
import aiter # noqa: F401

attn_backend_list.append("FLASH_ATTN")
attn_backend_list.append("ROCM_AITER_FA")
except Exception:
print("Skip FLASH_ATTN on ROCm as aiter is not installed")
print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")

return attn_backend_list
elif current_platform.is_xpu():
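The helper above only advertises the AITER flash-attention backend when `aiter` imports cleanly, hence the rename to ROCM_AITER_FA. A simplified sketch of that gating pattern, with a hypothetical helper name and an assumed "TRITON_ATTN" baseline outside the lines shown in the diff:

```python
from vllm.platforms import current_platform


def attn_backends_for_host() -> list[str]:
    # Hypothetical, simplified variant of tests/utils.py's
    # get_attn_backend_list_based_on_platform(); only the ROCm/aiter
    # gating is taken from the source shown above.
    if not current_platform.is_rocm():
        return ["FLASH_ATTN"]

    backends = ["TRITON_ATTN"]  # assumed baseline backend on ROCm
    try:
        import aiter  # noqa: F401

        backends.append("ROCM_AITER_FA")
    except Exception:
        print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
    return backends
```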
4 changes: 2 additions & 2 deletions tests/v1/e2e/test_spec_decode.py
@@ -417,9 +417,9 @@ def test_eagle_correctness(
"multi-token eagle spec decode on current platform"
)

if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower():
pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
else:
m.setenv("VLLM_ROCM_USE_AITER", "1")

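The guard above recurs in the spec-decode tests that follow: when the parametrized backend is ROCM_AITER_FA on a ROCm host, the test opts into the AITER kernels through VLLM_ROCM_USE_AITER. A small sketch of that guard as a stand-alone helper (the helper itself is hypothetical; the env var and backend name come from the diff):

```python
import pytest

from vllm.platforms import current_platform


def enable_aiter_if_needed(attn_backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
    # Hypothetical helper: ROCM_AITER_FA only applies on ROCm and requires
    # the AITER kernels, which vLLM enables via VLLM_ROCM_USE_AITER.
    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
```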
8 changes: 6 additions & 2 deletions tests/v1/spec_decode/test_eagle.py
@@ -339,7 +339,7 @@ def test_load_model(
"multi-token eagle spec decode on current platform"
)

if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
"because it requires special input mocking."
)

if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Use GPU device
@@ -541,6 +541,10 @@ def create_deterministic_logits(token_ids):
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
elif attn_backend == "ROCM_AITER_FA":
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.ROCM_AITER_FA
)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

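The dispatch above gains one more `elif` for the string-to-enum mapping. An equivalent table-driven sketch of the same mapping (import paths and the helper name are assumptions; only the two enum members visible in the diff are included):

```python
# Import paths are assumptions; in the test these names come from vLLM's own helpers.
from tests.v1.attention.utils import try_get_attention_backend  # assumed path
from vllm.attention.backends.registry import AttentionBackendEnum  # assumed path

# Only the two enum members visible in the diff are mapped here; the real test
# dispatches over more backends via if/elif.
_BACKEND_ENUMS = {
    "TREE_ATTN": AttentionBackendEnum.TREE_ATTN,
    "ROCM_AITER_FA": AttentionBackendEnum.ROCM_AITER_FA,
}


def resolve_metadata_builder(attn_backend: str):
    # Hypothetical helper mirroring the if/elif chain in test_propose.
    try:
        backend_enum = _BACKEND_ENUMS[attn_backend]
    except KeyError:
        raise ValueError(f"Unsupported attention backend: {attn_backend}") from None
    attn_metadata_builder_cls, _ = try_get_attention_backend(backend_enum)
    return attn_metadata_builder_cls
```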
2 changes: 1 addition & 1 deletion tests/v1/spec_decode/test_max_len.py
@@ -47,7 +47,7 @@ def test_eagle_max_len(
"multi-token eagle spec decode on current platform"
)

if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

llm = LLM(