Skip to content

Commit 36aae33

Browse files
committed
[Model Runner V2] Support num NaNs in logits
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 671427e commit 36aae33

File tree

7 files changed

+78
-7
lines changed

7 files changed

+78
-7
lines changed

vllm/v1/worker/gpu/async_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
AsyncModelRunnerOutput,
99
LogprobsTensors,
1010
ModelRunnerOutput,
11-
SamplerOutput,
1211
)
12+
from vllm.v1.worker.gpu.sample.output import SamplerOutput
1313

1414

1515
class AsyncOutput(AsyncModelRunnerOutput):
@@ -54,6 +54,10 @@ def __init__(
5454
)
5555
else:
5656
self.logprobs_tensors = None
57+
if sampler_output.num_nans is not None:
58+
self.num_nans = sampler_output.num_nans.to("cpu", non_blocking=True)
59+
else:
60+
self.num_nans = None
5761
self.num_sampled_tokens_cpu = num_sampled_tokens.to(
5862
"cpu", non_blocking=True
5963
)
@@ -80,6 +84,13 @@ def get_output(self) -> ModelRunnerOutput:
8084
del sampled_token_ids[i][num_sampled_tokens_np[i] :]
8185
self.model_runner_output.sampled_token_ids = sampled_token_ids
8286

87+
if self.num_nans is not None:
88+
num_nans_np = self.num_nans.numpy()
89+
self.model_runner_output.num_nans_in_logits = {
90+
req_id: int(num_nans_np[i])
91+
for i, req_id in enumerate(self.model_runner_output.req_ids)
92+
}
93+
8394
if self.logprobs_tensors is not None:
8495
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
8596
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict

vllm/v1/worker/gpu/metrics/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
from torch._inductor.runtime.triton_helpers import libdevice
5+
6+
from vllm.triton_utils import tl, triton
7+
8+
9+
@triton.jit
def _num_nans_kernel(
    logits_ptr,  # pointer to [num_reqs, vocab_size] logits (any float dtype)
    logits_stride,  # row stride of logits_ptr, in elements
    num_nans_ptr,  # pointer to [num_reqs] int32 output: per-row NaN count
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request row: scan the row in BLOCK_SIZE chunks and
    # accumulate how many logits are NaN.
    req_idx = tl.program_id(0)
    num_nans = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        # other=0 fills masked-out (past-end) lanes with a non-NaN value,
        # so padding never inflates the count.
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
        )
        # Upcast so isnan is evaluated uniformly regardless of input dtype.
        logits = logits.to(tl.float32)
        is_nan = libdevice.isnan(logits).to(tl.int1)
        num_nans += tl.sum(is_nan).to(tl.int32)
    tl.store(num_nans_ptr + req_idx, num_nans)
29+
30+
31+
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    """Count the NaN values in each row of ``logits``.

    Args:
        logits: A [num_reqs, vocab_size] float tensor on a GPU device.

    Returns:
        An int32 tensor of shape [num_reqs] on the same device, where
        entry ``i`` is the number of NaNs in row ``i`` of ``logits``.
    """
    num_reqs, vocab_size = logits.shape
    out = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
    # One kernel instance per request row; each row is scanned in
    # 8192-element chunks inside the kernel.
    grid = (num_reqs,)
    _num_nans_kernel[grid](
        logits,
        logits.stride(0),
        out,
        vocab_size,
        BLOCK_SIZE=8192,
    )
    return out

vllm/v1/worker/gpu/model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
LogprobsTensors,
2626
ModelRunnerOutput,
2727
)
28-
from vllm.v1.sample.sampler import SamplerOutput
2928
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
3029
from vllm.v1.worker.gpu.attn_utils import (
3130
build_attn_metadata,
@@ -53,6 +52,7 @@
5352
SamplingMetadata,
5453
expand_sampling_metadata,
5554
)
55+
from vllm.v1.worker.gpu.sample.output import SamplerOutput
5656
from vllm.v1.worker.gpu.sample.sampler import Sampler
5757
from vllm.v1.worker.gpu.spec_decode import init_speculator
5858
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample

vllm/v1/worker/gpu/sample/min_p.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ def _min_p_kernel(
3939
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
4040

4141

42-
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor | None) -> None:
43-
if min_p is None:
44-
return
42+
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
4543
num_reqs, vocab_size = logits.shape
4644
BLOCK_SIZE = 1024
4745
_min_p_kernel[(num_reqs,)](
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from dataclasses import dataclass
4+
5+
import torch
6+
7+
from vllm.v1.outputs import LogprobsTensors
8+
9+
10+
@dataclass
class SamplerOutput:
    # Container for per-step sampler results; fields are GPU tensors.
    # [num_reqs, 1]: one sampled token id per request.
    sampled_token_ids: torch.Tensor
    # Top-k logprobs for the step, or None when logprobs were not requested.
    logprobs_tensors: LogprobsTensors | None
    # [num_reqs] int32 per-request NaN counts in the logits, or None when
    # NaN counting is disabled (VLLM_COMPUTE_NANS_IN_LOGITS is off).
    num_nans: torch.Tensor | None

vllm/v1/worker/gpu/sample/sampler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
import torch
55

6+
import vllm.envs as envs
67
from vllm.config.model import LogprobsMode
7-
from vllm.v1.outputs import SamplerOutput
88
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
9+
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
910
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
1011
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
1112
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
1213
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
14+
from vllm.v1.worker.gpu.sample.output import SamplerOutput
1315
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
1416

1517

@@ -21,6 +23,7 @@ def __init__(
2123
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
2224
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
2325
self.logprobs_mode = logprobs_mode
26+
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
2427

2528
def __call__(
2629
self,
@@ -42,13 +45,15 @@ def __call__(
4245
else:
4346
logprobs_tensors = None
4447

48+
num_nans = get_num_nans(processed_logits) if self.compute_nans else None
4549
# These are GPU tensors.
4650
sampler_output = SamplerOutput(
4751
# The sampled tokens are expanded to 2D tensor with shape
4852
# [num_requests, 1], where each row represents one generated
4953
# token per request.
5054
sampled_token_ids=sampled.view(-1, 1),
5155
logprobs_tensors=logprobs_tensors,
56+
num_nans=num_nans,
5257
)
5358
return sampler_output
5459

@@ -63,7 +68,8 @@ def sample(
6368
# Apply penalties and temperature in place.
6469
apply_penalties_and_temperature(logits, sampling_metadata)
6570
# Apply min_p in place.
66-
apply_min_p(logits, sampling_metadata.min_p)
71+
if sampling_metadata.min_p is not None:
72+
apply_min_p(logits, sampling_metadata.min_p)
6773
# Apply top_k and/or top_p. This might return a new tensor.
6874
logits = apply_top_k_top_p(
6975
logits, sampling_metadata.top_k, sampling_metadata.top_p

0 commit comments

Comments
 (0)