5 changes: 5 additions & 0 deletions lmdeploy/messages.py
@@ -129,6 +129,11 @@ class GenerationConfig:
# router replay
return_routed_experts: bool = False

# ngram
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0
repetition_ngram_window_size: int = 1024

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
stop_token_ids/bad_token_ids."""
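The three new fields gate the n-gram repetition filter: generation is forced to stop once the trailing n-gram of length `repetition_ngram_size` has occurred at least `repetition_ngram_threshold` times within the last `repetition_ngram_window_size` tokens. A minimal usage sketch, assuming lmdeploy's `pipeline` API (the model name and field values are illustrative):

```python
from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')  # placeholder model
gen_config = GenerationConfig(
    repetition_ngram_size=4,            # length of the trailing n-gram to match
    repetition_ngram_threshold=2,       # matches required before forcing a stop
    repetition_ngram_window_size=1024,  # only scan the most recent tokens
)
response = pipe(['Write a short story.'], gen_config=gen_config)
```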
192 changes: 179 additions & 13 deletions lmdeploy/pytorch/engine/logits_process.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple
from functools import lru_cache
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from lmdeploy.messages import LogitsProcessor
@@ -29,7 +31,7 @@ def _process_bad_words_(scores: torch.Tensor,
return scores


def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor):
def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):
"""Process repetition penalty."""
score = torch.gather(scores, 1, input_ids)
penalty = penalty.to(score.dtype)
@@ -68,6 +70,142 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value:
return scores


@lru_cache
def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1):
"""Cached all-ones tensor, used as the reference value for isclose()."""
return torch.ones(fill, dtype=dtype, device=device)


def ngram(
token_ids: torch.Tensor,
n: torch.Tensor | None,
threshold: torch.Tensor,
max_n: int,
window_size: torch.Tensor | None,
max_window_size: int,
):
"""Compute n-gram matches between sliding windows and a target sequence.

For each batch element, performs cosine-similarity checking between:
- All sliding windows of length `max_n` from the full sequence
- The last `max_n` tokens of the sequence (target window)

A match is counted when both:
1. Cosine similarity ≈ 1 (normalized vectors match)
2. Vector lengths match (preventing zero/normalization artifacts)

Parameters
----------
token_ids : torch.Tensor
Input token IDs of shape (batch_size, seq_len).
Values are typically ≥0 (0 may represent padding/special tokens).
n : torch.Tensor | None
Effective n-gram length for each batch element, shape (batch_size,).
When provided, positions beyond `n` in the last `max_n` tokens are masked;
when None, every batch element uses `max_n`.
threshold : torch.Tensor
Minimum number of matching windows required for validity, shape (batch_size,).
max_n : int
Maximum n-gram length (window size for matching).
window_size : torch.Tensor | None
Effective window size for each batch element, shape (batch_size,).
When provided, only the last `window_size` tokens are considered for matching;
when None, every batch element uses `max_window_size`.
max_window_size : int
Maximum window size for matching.

Returns
-------
matched_mask : torch.Tensor
Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating
which sliding windows match the target n-gram.
found : torch.Tensor
Boolean tensor of shape (batch_size,) indicating whether each batch
element has at least `threshold` matches.
"""

batch_size, seq_len = token_ids.size()
if seq_len < max_n:
# Not enough tokens to form a single n-gram
matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device)
found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device)
return matched_mask, found
# token id 1 maps to log2(1) == 0, so add 1 to keep values nonzero and avoid div 0 when normalizing
token_ids = token_ids.to(torch.float32).log2() + 1

# Trim to max_window_size
if seq_len >= max_window_size:
token_ids = token_ids[:, -max_window_size:]
max_window_size = token_ids.size(1)

same_window = window_size is None
if not same_window:
# fill -1 for window_size < max_window_size
mask = torch.arange(max_window_size,
device=token_ids.device).unsqueeze(0) >= (max_window_size - window_size.unsqueeze(1))
token_ids.masked_fill_(~mask, -1)

# normalize ids
same_n = n is None
norm = token_ids[:, -max_n:]
if not same_n:
# fill 0 for n < max_n
mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1))
norm = norm * mask.to(torch.float32)
norm = norm.norm(2, dim=-1, keepdim=True)
normed_ids = token_ids / norm

# concatenate p1 and p2 so we can check cosine distance and vector length in one conv1d
normed_n_ids = normed_ids[:, -max_n:]
normed_ids_p2 = normed_ids * normed_ids
ones_ids = torch.ones_like(normed_n_ids)
if not same_n:
# fill 0 for n < max_n
normed_n_ids = normed_n_ids * mask.to(torch.float32)
ones_ids = ones_ids * mask.to(torch.float32)
normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0)
normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

# check cos distance & check vector length
match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
match_norm, match_ones = match_norm.chunk(2, dim=0)

# both match result should be close to 1
one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1)
matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor)

# threshold
count = matched_mask.sum(-1)
found = (count >= threshold) & (threshold > 0)

return matched_mask, found
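For intuition, here is a naive per-sequence reference of what the vectorized `ngram` above computes (a hypothetical helper, not part of the PR; the conv1d version matches windows on log2-transformed values rather than comparing tokens directly, trading exact comparison for one batched op):

```python
def count_trailing_ngram_repeats(ids: list, n: int, window: int) -> int:
    """Count occurrences of the trailing n-gram within the last `window` tokens."""
    ids = ids[-window:]
    if len(ids) < n:
        return 0
    tail = ids[-n:]
    return sum(1 for i in range(len(ids) - n + 1) if ids[i:i + n] == tail)

# Mirrors matched_mask.sum(-1); `found` is then count >= threshold.
assert count_trailing_ngram_repeats([5, 6, 7, 5, 6, 7, 5, 6, 7], n=3, window=1024) == 3
```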


def _filter_repetition_ngram_(
scores: torch.Tensor,
stop_words: torch.Tensor,
generated_ids: torch.Tensor,
n: torch.Tensor | None,
threshold: torch.Tensor,
max_n: int,
ngram_window_size: torch.Tensor | None,
max_ngram_window_size: int,
):
"""Filter ngram.

if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist.
"""
if stop_words is None or stop_words.numel() == 0:
return scores
# use the first stop word of each sequence
_, found = ngram(generated_ids, n, threshold, max_n, ngram_window_size, max_ngram_window_size)
stop_words = stop_words[:, 0]
# fill all scores -inf
scores.masked_fill_(found[:, None], -float('inf'))
# set stop words to 0
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)
return scores
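The forcing trick above is compact, so here is a standalone toy of its effect (shapes and values illustrative): after the fill/gather/scatter sequence, a sequence flagged by `found` has -inf everywhere except a 0 on its stop word, so sampling can only emit the stop word.

```python
import torch

scores = torch.randn(2, 8)            # (batch, vocab)
stop_words = torch.tensor([7, 7])     # first stop word of each sequence
found = torch.tensor([True, False])   # repetition detected per sequence

scores.masked_fill_(found[:, None], -float('inf'))    # suppress everything...
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)  # ...except the stop word

assert scores[0].argmax().item() == 7  # flagged sequence can only emit its stop word
```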


def _multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
@@ -84,7 +222,7 @@ def _multinomial_sampling(scores: torch.Tensor,
class SamplingInputsDelta:
num_ignore_eos: torch.Tensor = None
random_offsets: torch.Tensor = None
all_ids: Optional[torch.Tensor] = None
all_ids: None | torch.Tensor = None


@dataclass
@@ -104,16 +242,28 @@ class SamplingInputs:
min_top_p: float = 1.0
response_formats: Tuple[str] = ()
logits_processors: List[List[LogitsProcessor]] = None
max_num_logprobs: Optional[int] = None
all_ids: Optional[torch.Tensor] = None
max_num_logprobs: None | int = None
all_ids: None | torch.Tensor = None
num_ignore_eos: torch.Tensor = None
batch_size: int = 0
session_ctx: Optional[List[Dict[str, Any]]] = None
session_to_cleanup: Optional[List[int]] = None
session_ctx: None | List[Dict[str, Any]] = None
session_to_cleanup: None | List[int] = None
# for repetition_penalty and ngram
generated_ids: torch.Tensor | None = None
generated_ids_cpu: np.ndarray | None = None

# n gram
repetition_ngram_size: torch.Tensor | None = None
repetition_ngram_threshold: torch.Tensor | None = None
repetition_ngram_window_size: torch.Tensor | None = None
max_repetition_ngram_size: int = 0
max_repetition_ngram_window_size: int = 0

def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
out_dict = dict()
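# generated_ids may exist only as a CPU numpy copy; materialize it as a tensor before the device move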
if self.generated_ids is None and self.generated_ids_cpu is not None:
self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
for f in fields(self):
k = f.name
v = getattr(self, k)
@@ -168,8 +318,8 @@ class FusedLogitsProcessor:
def __init__(
self,
sampling_inputs: SamplingInputs,
logprobs_mode: Optional[str] = None,
guided_decoding_manager: Optional[GuidedDecodingManager] = None,
logprobs_mode: None | str = None,
guided_decoding_manager: None | GuidedDecodingManager = None,
):
self.sampling_inputs: SamplingInputs = sampling_inputs
self.logprobs_mode = logprobs_mode
@@ -189,18 +339,18 @@ async def _wait_stream_once(self):
if not stream.query():
await asyncio.sleep(0)

async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
async def __call__(self, scores: torch.Tensor) -> torch.Tensor:
r"""
Args:
scores (torch.FloatTensor):
scores (torch.Tensor):
Prediction scores of a language modeling head.
These can be logits for each vocabulary token when not using
beam search, or log softmax for each vocabulary token
when using beam search.


Return:
torch.FloatTensor: The processed prediction scores.
torch.Tensor: The processed prediction scores.

"""

@@ -238,7 +388,23 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:

repetition_penalty = sampling_inputs.repetition_penalty
if repetition_penalty is not None:
scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty)
generated_ids = sampling_inputs.generated_ids
scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)

if sampling_inputs.max_repetition_ngram_size > 0:
generated_ids = sampling_inputs.generated_ids
assert generated_ids is not None
assert sampling_inputs.repetition_ngram_threshold is not None
scores = _filter_repetition_ngram_(
scores,
sampling_inputs.stop_words,
generated_ids,
sampling_inputs.repetition_ngram_size,
sampling_inputs.repetition_ngram_threshold,
sampling_inputs.max_repetition_ngram_size,
sampling_inputs.repetition_ngram_window_size,
sampling_inputs.max_repetition_ngram_window_size,
)

temperature = sampling_inputs.temperature
if temperature is not None:
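Putting the pieces together, a self-contained toy run of the `ngram` helper from this file (values illustrative): the trailing 3-gram `(5, 6, 7)` matches the sliding windows at offsets 0, 3, and 6, so the count is 3 and `found` fires for a threshold of 2.

```python
import torch

token_ids = torch.tensor([[5, 6, 7, 5, 6, 7, 5, 6, 7]])
matched_mask, found = ngram(
    token_ids,
    n=None,                      # None -> every sequence uses max_n
    threshold=torch.tensor([2]),
    max_n=3,
    window_size=None,            # None -> scan up to max_window_size tokens
    max_window_size=1024,
)
assert matched_mask.sum().item() == 3 and bool(found[0])
```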
24 changes: 16 additions & 8 deletions lmdeploy/pytorch/messages.py
@@ -2,7 +2,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List

import numpy as np
import torch
@@ -56,13 +56,18 @@ class SamplingParam:
bad_words: List[int] = field(default_factory=list)
max_new_tokens: int = 512
min_new_tokens: int = 0
response_format: Optional[str] = None
logits_processors: Optional[List[LogitsProcessor]] = None
response_format: None | str = None
logits_processors: None | List[LogitsProcessor] = None
out_logits: bool = False
out_last_hidden_states: bool = False
num_logprobs: int = -1
return_routed_experts: bool = False

# ngram
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0
repetition_ngram_window_size: int = 1024

@classmethod
def from_gen_config(cls, gen_config: GenerationConfig):
"""From gen config."""
@@ -144,6 +149,9 @@ def from_gen_config(cls, gen_config: GenerationConfig):
out_logits=(output_logits is not None),
num_logprobs=logprobs,
return_routed_experts=gen_config.return_routed_experts,
repetition_ngram_size=gen_config.repetition_ngram_size,
repetition_ngram_threshold=gen_config.repetition_ngram_threshold,
repetition_ngram_window_size=gen_config.repetition_ngram_window_size,
)


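The new fields flow straight through `from_gen_config`; a quick round-trip sketch (assuming default values for everything else):

```python
from lmdeploy.messages import GenerationConfig
from lmdeploy.pytorch.messages import SamplingParam

gen_config = GenerationConfig(repetition_ngram_size=3, repetition_ngram_threshold=2)
param = SamplingParam.from_gen_config(gen_config)
assert param.repetition_ngram_size == 3
assert param.repetition_ngram_threshold == 2
assert param.repetition_ngram_window_size == 1024  # default window
```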
@@ -262,7 +270,7 @@ def add_sequence(self,
adapter_name: str = None,
multimodals: MultiModalInputs = None,
input_embeddings: List[InputEmbeddings] = None,
migration_request: Optional[MigrationRequest] = None,
migration_request: None | MigrationRequest = None,
resp_cache: bool = False,
preserve_cache: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
@@ -604,7 +612,7 @@ class SchedulerSequence:
model_meta: Dict[str, Any] = None

# For Disaggregation
migration_request: Optional[MigrationRequest] = None
migration_request: None | MigrationRequest = None
resp_cache: bool = False
preserve_cache: bool = False

@@ -698,7 +706,7 @@ def routed_experts(self) -> np.ndarray:
else:
return None

def append_routed_experts(self, routed_experts: Union[Tensor, np.ndarray]):
def append_routed_experts(self, routed_experts: Tensor | np.ndarray):
"""Append routed experts."""
if not self.return_routed_experts:
return
@@ -756,7 +764,7 @@ def logits(self):
"""Get logits."""
return self.all_logits.get_logits()

def append_logits(self, logits: Union[Tensor, np.ndarray]):
def append_logits(self, logits: Tensor | np.ndarray):
"""Append logits."""
if not self.return_logits:
return
@@ -776,7 +784,7 @@ def get_input_multimodals(self):
def record_event(
self,
event_type: EventType,
timestamp: Optional[float] = None,
timestamp: None | float = None,
) -> None:
self.engine_events.append(EngineEvent.new_event(event_type, timestamp))
