[Bug][ROCm]: vision_embeddings in transformers inaccurate without math SDP #30167

@AndreasKaratzas

Description

==============================
        System Info
==============================
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version                : 20.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.1.0 25425 1b0eada6b0ee93e2e694c8c146d23fca90bc11c5)
CMake version                : version 3.31.10
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.9.0a0+git1c57644
Is debug build               : False
CUDA used to build PyTorch   : N/A
ROCM used to build PyTorch   : 7.1.25424-4179531dcd

==============================
      Python Environment
==============================
Python version               : 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] (64-bit runtime)
Python platform              : Linux-5.15.0-161-generic-x86_64-with-glibc2.35

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : Could not collect
CUDA_MODULE_LOADING set to   :
GPU models and configuration :  (gfx942:sramecc+:xnack-)
Nvidia driver version        : Could not collect
cuDNN version                : Could not collect
HIP runtime version          : 7.1.25424
MIOpen runtime version       : 3.5.1
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           52 bits physical, 57 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  256
On-line CPU(s) list:                     0-255
Vendor ID:                               AuthenticAMD
Model name:                              AMD EPYC 9575F 64-Core Processor
CPU family:                              26
Model:                                   2
Thread(s) per core:                      2
Core(s) per socket:                      64
Socket(s):                               2
Stepping:                                1
Frequency boost:                         enabled
CPU max MHz:                             5008.0068
CPU min MHz:                             1500.0000
BogoMIPS:                                6600.02
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d
Virtualization:                          AMD-V
L1d cache:                               6 MiB (128 instances)
L1i cache:                               4 MiB (128 instances)
L2 cache:                                128 MiB (128 instances)
L3 cache:                                512 MiB (16 instances)
NUMA node(s):                            2
NUMA node0 CPU(s):                       0-63,128-191
NUMA node1 CPU(s):                       64-127,192-255
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

==============================
Versions of relevant libraries
==============================
[pip3] conch-triton-kernels==1.2.1
[pip3] efficientnet_pytorch==0.7.1
[pip3] numpy==2.2.6
[pip3] open_clip_torch==2.32.0
[pip3] pytorch-lightning==2.6.0
[pip3] pyzmq==27.1.0
[pip3] segmentation_models_pytorch==0.4.0
[pip3] sentence-transformers==5.1.2
[pip3] terratorch==1.0.2
[pip3] torch==2.9.0a0+git1c57644
[pip3] torchaudio==2.9.0+eaa9e4e
[pip3] torchgeo==0.7.0
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.23.0a0+824e8c8
[pip3] transformers==4.57.3
[pip3] triton==3.4.0
[pip3] triton_kernels==1.0.0
[conda] Could not collect
==============================
         vLLM Info
==============================
ROCM Version                 : 7.1.25424-4179531dcd
vLLM Version                 : 0.11.2.dev599+ge8dc42f0a.d20251206 (git sha: e8dc42f0a, date: 20251206)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
  ============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            15           15           15           15           15           15           15
GPU1   15           0            15           15           15           15           15           15
GPU2   15           15           0            15           15           15           15           15
GPU3   15           15           15           0            15           15           15           15
GPU4   15           15           15           15           0            15           15           15
GPU5   15           15           15           15           15           0            15           15
GPU6   15           15           15           15           15           15           0            15
GPU7   15           15           15           15           15           15           15           0

================================= Hops between two GPUs ==================================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            1            1            1            1            1            1            1
GPU1   1            0            1            1            1            1            1            1
GPU2   1            1            0            1            1            1            1            1
GPU3   1            1            1            0            1            1            1            1
GPU4   1            1            1            1            0            1            1            1
GPU5   1            1            1            1            1            0            1            1
GPU6   1            1            1            1            1            1            0            1
GPU7   1            1            1            1            1            1            1            0

=============================== Link Type between two GPUs ===============================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI
GPU1   XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI
GPU2   XGMI         XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI
GPU3   XGMI         XGMI         XGMI         0            XGMI         XGMI         XGMI         XGMI
GPU4   XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI         XGMI
GPU5   XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI
GPU6   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI
GPU7   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0

======================================= Numa Nodes =======================================
GPU[0]          : (Topology) Numa Node: 0
GPU[0]          : (Topology) Numa Affinity: 0
GPU[1]          : (Topology) Numa Node: 0
GPU[1]          : (Topology) Numa Affinity: 0
GPU[2]          : (Topology) Numa Node: 0
GPU[2]          : (Topology) Numa Affinity: 0
GPU[3]          : (Topology) Numa Node: 0
GPU[3]          : (Topology) Numa Affinity: 0
GPU[4]          : (Topology) Numa Node: 1
GPU[4]          : (Topology) Numa Affinity: 1
GPU[5]          : (Topology) Numa Node: 1
GPU[5]          : (Topology) Numa Affinity: 1
GPU[6]          : (Topology) Numa Node: 1
GPU[6]          : (Topology) Numa Affinity: 1
GPU[7]          : (Topology) Numa Node: 1
GPU[7]          : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================

==============================
     Environment Variables
==============================
PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1

🐛 Describe the bug

On ROCm platforms, multimodal models using the vLLM Transformers backend produce incorrect/garbage outputs due to accuracy issues with the flash_sdp and mem_efficient_sdp backends in PyTorch's scaled_dot_product_attention when used in vision encoders.

Affected Components

  • File: vllm/model_executor/models/transformers/multimodal.py
  • Method: MultiModalMixin.embed_multimodal()
  • Affected models: Multimodal models using the Transformers backend (e.g., Qwen2.5-VL, potentially others)

Root Cause

When self.model.get_image_features() is called, the vision encoder (e.g., Qwen2_5_VLVisionAttention in HuggingFace Transformers) uses torch.nn.functional.scaled_dot_product_attention. On ROCm, this call dispatches to the flash_sdp or mem_efficient_sdp backend by default. Both have numerical accuracy issues on this platform, so the vision encoder produces incorrect embeddings.

The LLM text decoder is unaffected because vLLM sets _attn_implementation = "vllm" for the text config, but the vision encoder config is not modified and uses the default SDPA backends.
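
For reference, every SDPA backend is meant to compute the same quantity, softmax(Q K^T / sqrt(d)) V, so outputs from different backends should agree up to floating-point tolerance. The following is a minimal pure-Python sketch of that computation for a single head, plus a helper for measuring how far two outputs diverge. The names sdpa_reference and max_abs_diff are hypothetical and this is not PyTorch's implementation; it only illustrates what "numerically accurate" is measured against.

```python
import math


def sdpa_reference(q, k, v):
    """Single-head softmax(Q K^T / sqrt(d)) V on lists of rows (no masking)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax over the key dimension.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equally shaped outputs."""
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
```

In practice the same comparison would be made on tensors, for example by running the vision attention once under the MATH backend and once under the default backend and checking the maximum absolute difference against a dtype-appropriate tolerance.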

Reproduction

Test file: tests/models/multimodal/generation/test_qwen25_vl_transformers.py

import math
from collections import defaultdict
from pathlib import PosixPath

import pytest
from transformers import (
    AutoModelForImageTextToText,
)

from ....conftest import (
    HfRunner,
    ImageTestAssets,
    VllmRunner,
)
from ....utils import large_gpu_mark
from .vlm_utils import model_utils, runners
from .vlm_utils.case_filtering import get_parametrized_options
from .vlm_utils.types import (
    ExpandableVLMTestArgs,
    VLMTestInfo,
    VLMTestType,
)

COMMON_BROADCAST_SETTINGS = {
    "test_type": VLMTestType.IMAGE,
    "dtype": "half",
    "max_tokens": 5,
    "tensor_parallel_size": 2,
    "hf_model_kwargs": {"device_map": "auto"},
    "image_size_factors": [(0.25, 0.5, 1.0)],
    "distributed_executor_backend": (
        "ray",
        "mp",
    ),
}

VLM_TEST_SETTINGS = {
    "qwen2_5_vl-transformers": VLMTestInfo(
        models=["Qwen/Qwen2.5-VL-3B-Instruct"],
        test_type=VLMTestType.IMAGE,
        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForImageTextToText,
        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
        image_size_factors=[(0.25, 0.2, 0.15)],
        vllm_runner_kwargs={
            "model_impl": "transformers",
        },
        marks=[large_gpu_mark(32)],
    ),
}


def _mark_splits(
    test_settings: dict[str, VLMTestInfo],
    *,
    num_groups: int,
) -> dict[str, VLMTestInfo]:
    name_by_test_info_id = {id(v): k for k, v in test_settings.items()}
    test_infos_by_model = defaultdict[str, list[VLMTestInfo]](list)

    for info in test_settings.values():
        for model in info.models:
            test_infos_by_model[model].append(info)

    models = sorted(test_infos_by_model.keys())
    split_size = math.ceil(len(models) / num_groups)

    new_test_settings = dict[str, VLMTestInfo]()

    for i in range(num_groups):
        models_in_group = models[i * split_size : (i + 1) * split_size]

        for model in models_in_group:
            for info in test_infos_by_model[model]:
                new_marks = (info.marks or []) + [pytest.mark.split(group=i)]
                new_info = info._replace(marks=new_marks)
                new_test_settings[name_by_test_info_id[id(info)]] = new_info

    missing_keys = test_settings.keys() - new_test_settings.keys()
    assert not missing_keys, f"Missing keys: {missing_keys}"

    return new_test_settings


VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)

@pytest.mark.parametrize(
    "model_type,test_case",
    get_parametrized_options(
        VLM_TEST_SETTINGS,
        test_type=VLMTestType.IMAGE,
        create_new_process_for_each_test=False,
    ),
)
def test_single_image_models(
    tmp_path: PosixPath,
    model_type: str,
    test_case: ExpandableVLMTestArgs,
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
):
    model_test_info = VLM_TEST_SETTINGS[model_type]
    runners.run_single_image_test(
        tmp_path=tmp_path,
        model_test_info=model_test_info,
        test_case=test_case,
        hf_runner=hf_runner,
        vllm_runner=vllm_runner,
        image_assets=image_assets,
    )


@pytest.mark.parametrize(
    "model_type,test_case",
    get_parametrized_options(
        VLM_TEST_SETTINGS,
        test_type=VLMTestType.MULTI_IMAGE,
        create_new_process_for_each_test=False,
    ),
)
def test_multi_image_models(
    tmp_path: PosixPath,
    model_type: str,
    test_case: ExpandableVLMTestArgs,
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
):
    model_test_info = VLM_TEST_SETTINGS[model_type]
    runners.run_multi_image_test(
        tmp_path=tmp_path,
        model_test_info=model_test_info,
        test_case=test_case,
        hf_runner=hf_runner,
        vllm_runner=vllm_runner,
        image_assets=image_assets,
    )

Temporary Fix

In embed_multimodal() under vllm/model_executor/models/transformers/multimodal.py, wrap the get_image_features() call with torch.nn.attention.sdpa_kernel to force the MATH backend on ROCm:

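# Fragment from inside embed_multimodal(); assumes these names are in scope:
#   import torch.nn.attention
#   from vllm.platforms import current_platform
if pixel_values is not None:
    if current_platform.is_rocm():
        # On ROCm, restrict SDPA to the MATH backend for the vision encoder only.
        with torch.nn.attention.sdpa_kernel(
            backends=[torch.nn.attention.SDPBackend.MATH]
        ):
            vision_embeddings = self.model.get_image_features(
                pixel_values, **kwargs
            )
    else:
        # Other platforms keep the default SDPA backend selection.
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)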
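
The two branches above differ only in the surrounding context manager, so one possible way to avoid duplicating the get_image_features() call is to select the context manager first. The sketch below is an assumption about how that could look, not the fix adopted in vLLM; vision_sdpa_context is a hypothetical helper name, and torch is imported lazily only so the non-ROCm path has no dependency on it here.

```python
import contextlib


def vision_sdpa_context(is_rocm: bool):
    """Return a context manager to wrap the vision encoder call in.

    On ROCm this restricts scaled_dot_product_attention to the MATH backend;
    elsewhere it is a no-op so the default backend selection is kept.
    """
    if not is_rocm:
        return contextlib.nullcontext()
    # Lazy import: only needed when the MATH backend is actually forced.
    from torch.nn.attention import SDPBackend, sdpa_kernel

    return sdpa_kernel(backends=[SDPBackend.MATH])


# Hypothetical usage inside embed_multimodal():
#
#     with vision_sdpa_context(current_platform.is_rocm()):
#         vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
```

This keeps the workaround scoped to the vision encoder call, so attention elsewhere in the process is not affected.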

Notes

  • The MATH backend is slower but numerically accurate
  • Setting _attn_implementation = "eager" on vision_config causes OOM if max_num_encoder_tokens is large enough due to full attention matrix materialization
  • Passing mm_processor_kwargs with a reduced max_pixels is needed to avoid OOM with the MATH backend during profiling
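
The OOM risk in the last two notes follows from simple arithmetic: when the full attention score matrix is materialized, its size grows with the square of the number of encoder tokens. A rough estimate, assuming one score per (head, query, key) triple and ignoring all other activations, can be sketched as follows. The function name and the example sizes (16 heads, 16384 tokens, fp16) are illustrative assumptions, not values taken from this issue.

```python
def attention_matrix_bytes(seq_len, num_heads, bytes_per_element=2, batch_size=1):
    """Bytes needed to hold one full attention score matrix per head."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


# Illustrative sizes only: 16 heads, fp16 scores (2 bytes each).
for tokens in (4096, 16384):
    gib = attention_matrix_bytes(tokens, num_heads=16) / 2**30
    print(f"{tokens} tokens -> {gib:.1f} GiB per layer for the score matrix")
```

Quadrupling the token count multiplies this estimate by sixteen, which is why capping max_pixels (and with it the number of vision tokens) can bring memory use back under control.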


Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), rocm (Related to AMD ROCm)

    Type

    No type

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
