From d942dfd8cc1a5221d5ae1204a11e70f017efbfea Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 23 Jan 2026 18:19:28 +0800 Subject: [PATCH 1/3] support ngram logits processor --- lmdeploy/messages.py | 4 + lmdeploy/pytorch/engine/logits_process.py | 156 +++++++++++++++++-- lmdeploy/pytorch/messages.py | 22 ++- lmdeploy/pytorch/strategies/ar/sampling.py | 41 ++++- lmdeploy/pytorch/strategies/dllm/sampling.py | 24 ++- tests/pytorch/engine/test_logits_process.py | 25 +++ 6 files changed, 251 insertions(+), 21 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 8b54705e31..00da23baf6 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -129,6 +129,10 @@ class GenerationConfig: # router replay return_routed_experts: bool = False + # ngram + ngram_size: int = 0 + ngram_threshold: int = 0 + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to stop_token_ids/bad_token_ids.""" diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index d6b5542581..9190e14925 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple +from functools import lru_cache +from typing import Any, Dict, List, Tuple +import numpy as np import torch from lmdeploy.messages import LogitsProcessor @@ -29,7 +31,7 @@ def _process_bad_words_(scores: torch.Tensor, return scores -def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor): +def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor): """Process repetition penalty.""" score = torch.gather(scores, 1, input_ids) penalty = penalty.to(score.dtype) @@ -68,6 +70,116 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: return scores +@lru_cache(maxsize=1) +def _ngram_one(dtype: torch.dtype, device: torch.device): + return torch.ones(1, dtype=dtype, device=device) + + +def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max_n: int, same_n: bool = False): + """Compute n-gram matches between sliding windows and a target sequence. + + For each batch, performs cosine similarity checking between: + - All sliding windows of length `max_n` from the full sequence + - The last `max_n` tokens of the sequence (target window) + + A match is counted when both: + 1. Cosine similarity ≈ 1 (normalized vectors match) + 2. Vector lengths match (preventing zero/normalization artifacts) + + Parameters + ---------- + token_ids : torch.Tensor + Input token IDs of shape (batch_size, seq_len). + Values are typically ≥0 (0 may represent padding/special tokens). + n : torch.Tensor + Effective n-gram length for each batch element, shape (batch_size,). + When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked. + threshold : torch.Tensor + Minimum number of matching windows required for validity, shape (batch_size,). + max_n : int + Maximum n-gram length (window size for matching). + same_n : bool, default False + If True, use full `max_n`-length windows regardless of `n`. + If False, mask positions where index < (max_n - n) in the target window. 
+ + Returns + ------- + matched_mask : torch.Tensor + Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating + which sliding windows match the target n-gram. + found : torch.Tensor + Boolean tensor of shape (batch_size,) indicating whether each batch + element has at least `threshold` matches. + """ + + batch_size, seq_len = token_ids.size() + if seq_len < max_n: + # Not enough tokens to form a single n-gram + matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device) + found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) + return matched_mask, found + # token_ids could be 0, so we add 1 to avoid div 0 + token_ids = token_ids.to(torch.float32) + 1 + + # normalize ids + norm = token_ids[:, -max_n:] + if not same_n: + # fill 0 for n < max_n + mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1)) + norm = norm * mask.to(torch.float32) + norm = norm.norm(2, dim=-1, keepdim=True) + normed_ids = token_ids / norm + + # concate p1 and p2 so we can check distance and vector in one conv1d + normed_n_ids = normed_ids[:, -max_n:] + normed_ids_p2 = normed_ids * normed_ids + ones_ids = torch.ones_like(normed_n_ids) + if not same_n: + # fill 0 for n < max_n + normed_n_ids = normed_n_ids * mask.to(torch.float32) + ones_ids = ones_ids * mask.to(torch.float32) + normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0) + normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0) + + # check cos distance & check vector length + match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0] + match_norm, match_ones = match_norm.chunk(2, dim=0) + + # both match result should be close to 1 + one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device) + matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) + + # threshold + count = matched_mask.sum(-1) + found = (count >= threshold) & (threshold > 0) + + return matched_mask, found + + +def _filter_ngram_( + scores: torch.Tensor, + stop_words: torch.Tensor, + generated_ids: torch.Tensor, + n: torch.Tensor, + threshold: torch.Tensor, + max_n: int, + same_n: bool = False, +): + """Filter ngram.""" + if stop_words is None or stop_words.numel() == 0: + return scores + # use first stop words + _, found = ngram(generated_ids, n, threshold, max_n, same_n) + stop_words = stop_words[:, 0] + # fill all scores -inf + scores.masked_fill_(found[:, None], -float('inf')) + # set stop words to 0 + stop_scores = scores.gather(1, stop_words[:, None]) + stop_scores.masked_fill_(found[:, None], 0) + scores.scatter_(1, stop_words[:, None], stop_scores) + return scores + + def _multinomial_sampling(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, @@ -84,7 +196,7 @@ def _multinomial_sampling(scores: torch.Tensor, class SamplingInputsDelta: num_ignore_eos: torch.Tensor = None random_offsets: torch.Tensor = None - all_ids: Optional[torch.Tensor] = None + all_ids: None | torch.Tensor = None @dataclass @@ -104,16 +216,27 @@ class SamplingInputs: min_top_p: float = 1.0 response_formats: Tuple[str] = () logits_processors: List[List[LogitsProcessor]] = None - max_num_logprobs: Optional[int] = None - all_ids: Optional[torch.Tensor] = None + max_num_logprobs: None | int = None + all_ids: None | torch.Tensor = None num_ignore_eos: torch.Tensor = None batch_size: int = 0 - session_ctx: Optional[List[Dict[str, Any]]] = None - session_to_cleanup: Optional[List[int]] = None + session_ctx: None 
| List[Dict[str, Any]] = None + session_to_cleanup: None | List[int] = None + # for repetition_penalty and ngram + generated_ids: torch.Tensor | None = None + generated_ids_cpu: np.ndarray | None = None + + # n gram + ngram_size: torch.Tensor = None + ngram_threshold: torch.Tensor = None + max_ngram_size: int = 0 + ngram_same_n: bool = False def to_device(self, device: str, non_blocking: bool = False): """To device.""" out_dict = dict() + if self.generated_ids_cpu is not None: + self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) for f in fields(self): k = f.name v = getattr(self, k) @@ -168,8 +291,8 @@ class FusedLogitsProcessor: def __init__( self, sampling_inputs: SamplingInputs, - logprobs_mode: Optional[str] = None, - guided_decoding_manager: Optional[GuidedDecodingManager] = None, + logprobs_mode: None | str = None, + guided_decoding_manager: None | GuidedDecodingManager = None, ): self.sampling_inputs: SamplingInputs = sampling_inputs self.logprobs_mode = logprobs_mode @@ -238,7 +361,20 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: repetition_penalty = sampling_inputs.repetition_penalty if repetition_penalty is not None: - scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty) + generated_ids = sampling_inputs.generated_ids + scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) + + if sampling_inputs.max_ngram_size > 0: + generated_ids = sampling_inputs.generated_ids + scores = _filter_ngram_( + scores, + sampling_inputs.stop_words, + generated_ids, + sampling_inputs.ngram_size, + sampling_inputs.ngram_threshold, + sampling_inputs.max_ngram_size, + sampling_inputs.ngram_same_n, + ) temperature = sampling_inputs.temperature if temperature is not None: diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index c020403fa8..fb71465d1e 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -2,7 +2,7 @@ import enum from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List import numpy as np import torch @@ -56,13 +56,17 @@ class SamplingParam: bad_words: List[int] = field(default_factory=list) max_new_tokens: int = 512 min_new_tokens: int = 0 - response_format: Optional[str] = None - logits_processors: Optional[List[LogitsProcessor]] = None + response_format: None | str = None + logits_processors: None | List[LogitsProcessor] = None out_logits: bool = False out_last_hidden_states: bool = False num_logprobs: int = -1 return_routed_experts: bool = False + # ngram + ngram_size: int = 0 + ngram_threshold: int = 0 + @classmethod def from_gen_config(cls, gen_config: GenerationConfig): """From gen config.""" @@ -144,6 +148,8 @@ def from_gen_config(cls, gen_config: GenerationConfig): out_logits=(output_logits is not None), num_logprobs=logprobs, return_routed_experts=gen_config.return_routed_experts, + ngram_size=gen_config.ngram_size, + ngram_threshold=gen_config.ngram_threshold, ) @@ -262,7 +268,7 @@ def add_sequence(self, adapter_name: str = None, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, - migration_request: Optional[MigrationRequest] = None, + migration_request: None | MigrationRequest = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" @@ -604,7 +610,7 @@ class SchedulerSequence: model_meta: Dict[str, Any] = None 
# For Disaggregation - migration_request: Optional[MigrationRequest] = None + migration_request: None | MigrationRequest = None resp_cache: bool = False preserve_cache: bool = False @@ -698,7 +704,7 @@ def routed_experts(self) -> np.ndarray: else: return None - def append_routed_experts(self, routed_experts: Union[Tensor, np.ndarray]): + def append_routed_experts(self, routed_experts: Tensor | np.ndarray): """Append routed experts.""" if not self.return_routed_experts: return @@ -756,7 +762,7 @@ def logits(self): """Get logits.""" return self.all_logits.get_logits() - def append_logits(self, logits: Union[Tensor, np.ndarray]): + def append_logits(self, logits: Tensor | np.ndarray): """Append logits.""" if not self.return_logits: return @@ -776,7 +782,7 @@ def get_input_multimodals(self): def record_event( self, event_type: EventType, - timestamp: Optional[float] = None, + timestamp: None | float = None, ) -> None: self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 4db051ddcf..cb95308af8 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import List +import numpy as np import torch from torch.profiler import record_function @@ -15,7 +16,7 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs): """Gather history.""" - if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): + if not any(sampling_inputs.logits_processors): return None batch = len(seqs) max_len = max(seq.num_valid_ids for seq in seqs) @@ -29,6 +30,22 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) return output +def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None: + """Gather history.""" + if sampling_inputs.repetition_penalty is None and sampling_inputs.max_ngram_size == 0: + return None + batch = len(seqs) + max_len = max(seq.num_new_tokens for seq in seqs) + output = np.full((batch, max_len), pad_id, dtype=np.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_new_tokens + if h_len == 0: + continue + h_ids = seq.generated_ids + output[idx, -h_len:] = h_ids + return output + + def _get_num_ignore_eos(seqs: SeqList): """Get num ignore eos.""" ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] @@ -61,6 +78,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: num_logprobs = [None] * batch_size session_to_cleanup = self.session_to_cleanup self.session_to_cleanup = [] + ngram_sizes = [None] * batch_size + ngram_thresholds = [None] * batch_size def __gather_params(): """Gather params.""" @@ -84,6 +103,8 @@ def __gather_params(): stop_words[idx] = sw logits_processors[idx] = param.logits_processors num_logprobs[idx] = param.num_logprobs + ngram_sizes[idx] = param.ngram_size + ngram_thresholds[idx] = param.ngram_threshold def __get_topp(top_p): """Get topp.""" @@ -165,6 +186,19 @@ def __get_bad_words(bad_words): 'seq_id': seq.seq_id, } for seq in seqs] + # ngram + max_ngram_size = max(ngram_sizes) + if max_ngram_size == 0: + ngram_sizes = None + ngram_thresholds = None + ngram_same_n = True + else: + ngram_sizes = torch.tensor(ngram_sizes) + ngram_thresholds = torch.tensor(ngram_thresholds) + ngram_same_n = (ngram_sizes == max_ngram_size).all().item() + if ngram_same_n: + 
ngram_sizes = None + sampling_input = SamplingInputs( temperature=temperature, bad_words=bad_words, @@ -185,10 +219,15 @@ def __get_bad_words(bad_words): batch_size=batch_size, session_ctx=session_ctx, session_to_cleanup=session_to_cleanup, + ngram_size=ngram_sizes, + ngram_threshold=ngram_thresholds, + max_ngram_size=max_ngram_size, + ngram_same_n=ngram_same_n, ) pad_token_id = self.pad_token_id sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) + sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input) sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) return sampling_input diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index 5a027e922d..d7c8bc4716 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import List +import numpy as np import torch from torch.profiler import record_function @@ -42,15 +43,34 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: 'random_offsets', 'all_ids', 'num_ignore_eos', + 'ngram_size', + 'ngram_threshold', ] for name in update_attr_names: attr = getattr(out, name) if attr is None: continue - repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) - attr = attr[None].repeat(*repeats).flatten(0, 1) + if attr.dim() == 1: + repeats = (dllm_block_length, 1) + attr = attr[None].repeat(*repeats).flatten(0, 1) + elif attr.dim() == 2: + repeats = (1, dllm_block_length, 1) + attr = attr[:, None].repeat(*repeats).flatten(0, 1) + else: + repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) + attr = attr[None].repeat(*repeats).flatten(0, 1) setattr(out, name, attr) + # update generated_ids_cpu + if out.generated_ids_cpu is not None: + generated_ids_cpu = out.generated_ids_cpu + if generated_ids_cpu.shape[1] == 0: + out.generated_ids_cpu = np.repeat(generated_ids_cpu, dllm_block_length, axis=0) + else: + generated_ids_cpu = np.repeat(generated_ids_cpu[:, None], dllm_block_length, axis=1) + generated_ids_cpu = np.reshape(generated_ids_cpu, (-1, generated_ids_cpu.shape[-1])) + out.generated_ids_cpu = generated_ids_cpu + if len(out.response_formats) > 0: new_resp_formats = [] for resp in out.response_formats: diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index b901879be4..7a85a62d4f 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -124,3 +124,28 @@ def test_filter_minp_sorted(): out = _filter_minp_sorted_(scores, min_p) torch.testing.assert_close(out, gt) + + +def test_filter_ngram(): + from lmdeploy.pytorch.engine.logits_process import _filter_ngram_ + + generated_ids = torch.tensor( + [[2, 3, 4, 1, 2, 3, 4, 2, 3, 4], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7]], + dtype=torch.int64) + n = torch.tensor([3, 3, 2], dtype=torch.int64) + threshold = torch.tensor([3, 3, 3], dtype=torch.int64) + + batch_size = generated_ids.size(0) + max_n = n.max().item() + same_n = n.eq(max_n).all().item() + vocab_size = 100 + + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, same_n) + + assert not scores[1].isinf().any().item() + assert scores[0].isinf().sum().item() == vocab_size - 1 + assert scores[2].isinf().sum().item() == vocab_size - 1 
+ assert scores[0, stop_words[0, 0]] == 0 + assert scores[2, stop_words[2, 0]] == 0 From 0f0b8dfba9214c8f9e3c327f08a58a68a9742fdd Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 26 Jan 2026 12:28:51 +0800 Subject: [PATCH 2/3] add window size --- lmdeploy/messages.py | 1 + lmdeploy/pytorch/engine/logits_process.py | 72 +++++++++++++++------ lmdeploy/pytorch/messages.py | 2 + lmdeploy/pytorch/strategies/ar/sampling.py | 12 +++- tests/pytorch/engine/test_logits_process.py | 61 ++++++++++++++--- 5 files changed, 116 insertions(+), 32 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 00da23baf6..c3b394f2b1 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -132,6 +132,7 @@ class GenerationConfig: # ngram ngram_size: int = 0 ngram_threshold: int = 0 + ngram_window_size: int = 1024 def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 9190e14925..0773bacafa 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -70,12 +70,19 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: return scores -@lru_cache(maxsize=1) -def _ngram_one(dtype: torch.dtype, device: torch.device): - return torch.ones(1, dtype=dtype, device=device) +@lru_cache +def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1): + return torch.ones(fill, dtype=dtype, device=device) -def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max_n: int, same_n: bool = False): +def ngram( + token_ids: torch.Tensor, + n: torch.Tensor | None, + threshold: torch.Tensor, + max_n: int, + window_size: torch.Tensor | None, + max_window_size: int, +): """Compute n-gram matches between sliding windows and a target sequence. For each batch, performs cosine similarity checking between: @@ -98,9 +105,11 @@ def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max Minimum number of matching windows required for validity, shape (batch_size,). max_n : int Maximum n-gram length (window size for matching). - same_n : bool, default False - If True, use full `max_n`-length windows regardless of `n`. - If False, mask positions where index < (max_n - n) in the target window. + window_size: torch.Tensor | None + Effective window size for each batch element, shape (batch_size,). + When `same_n=False`, only the last `window_size` tokens are considered for matching. + max_window_size: int + Maximum window size for matching. 
Returns ------- @@ -119,9 +128,22 @@ def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) return matched_mask, found # token_ids could be 0, so we add 1 to avoid div 0 - token_ids = token_ids.to(torch.float32) + 1 + token_ids = token_ids.to(torch.float32).log2() + 1 + + # Trim to max_window_size + if seq_len >= max_window_size: + token_ids = token_ids[:, -max_window_size:] + max_window_size = token_ids.size(1) + + same_window = window_size is None + if not same_window: + # fill -1 for window_size < max_window_size + mask = torch.arange(max_window_size, + device=token_ids.device).unsqueeze(0) >= (max_window_size - window_size.unsqueeze(1)) + token_ids.masked_fill_(~mask, -1) # normalize ids + same_n = n is None norm = token_ids[:, -max_n:] if not same_n: # fill 0 for n < max_n @@ -146,7 +168,7 @@ def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max match_norm, match_ones = match_norm.chunk(2, dim=0) # both match result should be close to 1 - one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device) + one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1) matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) # threshold @@ -160,16 +182,20 @@ def _filter_ngram_( scores: torch.Tensor, stop_words: torch.Tensor, generated_ids: torch.Tensor, - n: torch.Tensor, + n: torch.Tensor | None, threshold: torch.Tensor, max_n: int, - same_n: bool = False, + ngram_window_size: torch.Tensor | None, + max_ngram_window_size: int, ): - """Filter ngram.""" + """Filter ngram. + + if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist. + """ if stop_words is None or stop_words.numel() == 0: return scores # use first stop words - _, found = ngram(generated_ids, n, threshold, max_n, same_n) + _, found = ngram(generated_ids, n, threshold, max_n, ngram_window_size, max_ngram_window_size) stop_words = stop_words[:, 0] # fill all scores -inf scores.masked_fill_(found[:, None], -float('inf')) @@ -227,15 +253,16 @@ class SamplingInputs: generated_ids_cpu: np.ndarray | None = None # n gram - ngram_size: torch.Tensor = None - ngram_threshold: torch.Tensor = None + ngram_size: torch.Tensor | None = None + ngram_threshold: torch.Tensor | None = None + ngram_window_size: torch.Tensor | None = None max_ngram_size: int = 0 - ngram_same_n: bool = False + max_ngram_window_size: int = 0 def to_device(self, device: str, non_blocking: bool = False): """To device.""" out_dict = dict() - if self.generated_ids_cpu is not None: + if self.generated_ids is None and self.generated_ids_cpu is not None: self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) for f in fields(self): k = f.name @@ -312,10 +339,10 @@ async def _wait_stream_once(self): if not stream.query(): await asyncio.sleep(0) - async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: + async def __call__(self, scores: torch.Tensor) -> torch.Tensor: r""" Args: - scores (torch.FloatTensor): + scores (torch.Tensor): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token @@ -323,7 +350,7 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: Return: - torch.FloatTensor: The processed prediction scores. + torch.Tensor: The processed prediction scores. 
""" @@ -366,6 +393,8 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: if sampling_inputs.max_ngram_size > 0: generated_ids = sampling_inputs.generated_ids + assert generated_ids is not None + assert sampling_inputs.ngram_threshold is not None scores = _filter_ngram_( scores, sampling_inputs.stop_words, @@ -373,7 +402,8 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: sampling_inputs.ngram_size, sampling_inputs.ngram_threshold, sampling_inputs.max_ngram_size, - sampling_inputs.ngram_same_n, + sampling_inputs.ngram_window_size, + sampling_inputs.max_ngram_window_size, ) temperature = sampling_inputs.temperature diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index fb71465d1e..60f9bc1f33 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -66,6 +66,7 @@ class SamplingParam: # ngram ngram_size: int = 0 ngram_threshold: int = 0 + ngram_window_size: int = 1024 @classmethod def from_gen_config(cls, gen_config: GenerationConfig): @@ -150,6 +151,7 @@ def from_gen_config(cls, gen_config: GenerationConfig): return_routed_experts=gen_config.return_routed_experts, ngram_size=gen_config.ngram_size, ngram_threshold=gen_config.ngram_threshold, + ngram_window_size=gen_config.ngram_window_size, ) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index cb95308af8..e69d52edec 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -80,6 +80,7 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: self.session_to_cleanup = [] ngram_sizes = [None] * batch_size ngram_thresholds = [None] * batch_size + ngram_window_sizes = [None] * batch_size def __gather_params(): """Gather params.""" @@ -105,6 +106,7 @@ def __gather_params(): num_logprobs[idx] = param.num_logprobs ngram_sizes[idx] = param.ngram_size ngram_thresholds[idx] = param.ngram_threshold + ngram_window_sizes[idx] = param.ngram_window_size if param.ngram_window_size > 0 else 1 << 62 def __get_topp(top_p): """Get topp.""" @@ -188,16 +190,21 @@ def __get_bad_words(bad_words): # ngram max_ngram_size = max(ngram_sizes) + max_ngram_window_size = max(ngram_window_sizes) if max_ngram_size == 0: ngram_sizes = None ngram_thresholds = None - ngram_same_n = True + ngram_window_sizes = None else: ngram_sizes = torch.tensor(ngram_sizes) ngram_thresholds = torch.tensor(ngram_thresholds) + ngram_window_sizes = torch.tensor(ngram_window_sizes) ngram_same_n = (ngram_sizes == max_ngram_size).all().item() if ngram_same_n: ngram_sizes = None + ngram_same_window_size = (ngram_window_sizes == max_ngram_window_size).all().item() + if ngram_same_window_size: + ngram_window_sizes = None sampling_input = SamplingInputs( temperature=temperature, @@ -221,8 +228,9 @@ def __get_bad_words(bad_words): session_to_cleanup=session_to_cleanup, ngram_size=ngram_sizes, ngram_threshold=ngram_thresholds, + ngram_window_size=ngram_window_sizes, max_ngram_size=max_ngram_size, - ngram_same_n=ngram_same_n, + max_ngram_window_size=max_ngram_window_size, ) pad_token_id = self.pad_token_id diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index 7a85a62d4f..aa9ef347f0 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -128,24 +128,67 @@ def test_filter_minp_sorted(): def test_filter_ngram(): from lmdeploy.pytorch.engine.logits_process import _filter_ngram_ + vocab_size = 
100 - generated_ids = torch.tensor( - [[2, 3, 4, 1, 2, 3, 4, 2, 3, 4], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7]], - dtype=torch.int64) + def _get_emtas(n, window_size): + batch_size = generated_ids.size(0) + max_n = int(n.max().item()) + same_n = n.eq(max_n).all().item() + max_window_size = int(window_size.max().item()) + if same_n: + n = None + return batch_size, max_n, max_window_size, n + + # base test + generated_ids = torch.tensor([ + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], + [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], + ], + dtype=torch.int64) n = torch.tensor([3, 3, 2], dtype=torch.int64) threshold = torch.tensor([3, 3, 3], dtype=torch.int64) + window_size = torch.tensor([10, 10, 10], dtype=torch.int64) - batch_size = generated_ids.size(0) - max_n = n.max().item() - same_n = n.eq(max_n).all().item() - vocab_size = 100 - + batch_size, max_n, max_window_size, n = _get_emtas(n, window_size) scores = torch.rand(batch_size, vocab_size) stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) - _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, same_n) + _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) assert not scores[1].isinf().any().item() assert scores[0].isinf().sum().item() == vocab_size - 1 assert scores[2].isinf().sum().item() == vocab_size - 1 assert scores[0, stop_words[0, 0]] == 0 assert scores[2, stop_words[2, 0]] == 0 + + # test no ngram + generated_ids = torch.tensor([ + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]) + n = torch.tensor([3, 0], dtype=torch.int64) + threshold = torch.tensor([3, 0], dtype=torch.int64) + window_size = torch.tensor([10, 10], dtype=torch.int64) + batch_size, max_n, max_window_size, n = _get_emtas(n, window_size) + + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) + assert not scores[1].isinf().any().item() + assert scores[0].isinf().sum().item() == vocab_size - 1 + + # test window + generated_ids = torch.tensor([ + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + [2, 3, 4, 1, 2, 3, 4, 2, 3, 4], + ]) + n = torch.tensor([2, 0], dtype=torch.int64) + threshold = torch.tensor([3, 3], dtype=torch.int64) + window_size = torch.tensor([10, 6], dtype=torch.int64) + batch_size, max_n, max_window_size, n = _get_emtas(n, window_size) + + scores = torch.rand(batch_size, vocab_size) + stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) + _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) + assert not scores[1].isinf().any().item() + assert scores[0].isinf().sum().item() == vocab_size - 1 From d981a3cf9ddb9249e155e505947ffe4525bbb11c Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 26 Jan 2026 14:47:58 +0800 Subject: [PATCH 3/3] rename --- lmdeploy/messages.py | 6 +-- lmdeploy/pytorch/engine/logits_process.py | 28 +++++----- lmdeploy/pytorch/messages.py | 12 ++--- lmdeploy/pytorch/strategies/ar/sampling.py | 58 +++++++++++---------- tests/pytorch/engine/test_logits_process.py | 8 +-- 5 files changed, 57 insertions(+), 55 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index c3b394f2b1..b69ee7bc7b 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -130,9 +130,9 @@ class GenerationConfig: return_routed_experts: bool = False # ngram - 
ngram_size: int = 0 - ngram_threshold: int = 0 - ngram_window_size: int = 1024 + repetition_ngram_size: int = 0 + repetition_ngram_threshold: int = 0 + repetition_ngram_window_size: int = 1024 def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 0773bacafa..849073a5aa 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -178,7 +178,7 @@ def ngram( return matched_mask, found -def _filter_ngram_( +def _filter_repetition_ngram_( scores: torch.Tensor, stop_words: torch.Tensor, generated_ids: torch.Tensor, @@ -253,11 +253,11 @@ class SamplingInputs: generated_ids_cpu: np.ndarray | None = None # n gram - ngram_size: torch.Tensor | None = None - ngram_threshold: torch.Tensor | None = None - ngram_window_size: torch.Tensor | None = None - max_ngram_size: int = 0 - max_ngram_window_size: int = 0 + repetition_ngram_size: torch.Tensor | None = None + repetition_ngram_threshold: torch.Tensor | None = None + repetition_ngram_window_size: torch.Tensor | None = None + max_repetition_ngram_size: int = 0 + max_repetition_ngram_window_size: int = 0 def to_device(self, device: str, non_blocking: bool = False): """To device.""" @@ -391,19 +391,19 @@ async def __call__(self, scores: torch.Tensor) -> torch.Tensor: generated_ids = sampling_inputs.generated_ids scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) - if sampling_inputs.max_ngram_size > 0: + if sampling_inputs.max_repetition_ngram_size > 0: generated_ids = sampling_inputs.generated_ids assert generated_ids is not None - assert sampling_inputs.ngram_threshold is not None - scores = _filter_ngram_( + assert sampling_inputs.repetition_ngram_threshold is not None + scores = _filter_repetition_ngram_( scores, sampling_inputs.stop_words, generated_ids, - sampling_inputs.ngram_size, - sampling_inputs.ngram_threshold, - sampling_inputs.max_ngram_size, - sampling_inputs.ngram_window_size, - sampling_inputs.max_ngram_window_size, + sampling_inputs.repetition_ngram_size, + sampling_inputs.repetition_ngram_threshold, + sampling_inputs.max_repetition_ngram_size, + sampling_inputs.repetition_ngram_window_size, + sampling_inputs.max_repetition_ngram_window_size, ) temperature = sampling_inputs.temperature diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 60f9bc1f33..495478b162 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -64,9 +64,9 @@ class SamplingParam: return_routed_experts: bool = False # ngram - ngram_size: int = 0 - ngram_threshold: int = 0 - ngram_window_size: int = 1024 + repetition_ngram_size: int = 0 + repetition_ngram_threshold: int = 0 + repetition_ngram_window_size: int = 1024 @classmethod def from_gen_config(cls, gen_config: GenerationConfig): @@ -149,9 +149,9 @@ def from_gen_config(cls, gen_config: GenerationConfig): out_logits=(output_logits is not None), num_logprobs=logprobs, return_routed_experts=gen_config.return_routed_experts, - ngram_size=gen_config.ngram_size, - ngram_threshold=gen_config.ngram_threshold, - ngram_window_size=gen_config.ngram_window_size, + repetition_ngram_size=gen_config.repetition_ngram_size, + repetition_ngram_threshold=gen_config.repetition_ngram_threshold, + repetition_ngram_window_size=gen_config.repetition_ngram_window_size, ) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py 
b/lmdeploy/pytorch/strategies/ar/sampling.py index e69d52edec..72278efb6b 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -32,7 +32,7 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None: """Gather history.""" - if sampling_inputs.repetition_penalty is None and sampling_inputs.max_ngram_size == 0: + if sampling_inputs.repetition_penalty is None and sampling_inputs.max_repetition_ngram_size == 0: return None batch = len(seqs) max_len = max(seq.num_new_tokens for seq in seqs) @@ -78,9 +78,9 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: num_logprobs = [None] * batch_size session_to_cleanup = self.session_to_cleanup self.session_to_cleanup = [] - ngram_sizes = [None] * batch_size - ngram_thresholds = [None] * batch_size - ngram_window_sizes = [None] * batch_size + repetition_ngram_sizes = [None] * batch_size + repetition_ngram_thresholds = [None] * batch_size + repetition_ngram_window_sizes = [None] * batch_size def __gather_params(): """Gather params.""" @@ -104,9 +104,10 @@ def __gather_params(): stop_words[idx] = sw logits_processors[idx] = param.logits_processors num_logprobs[idx] = param.num_logprobs - ngram_sizes[idx] = param.ngram_size - ngram_thresholds[idx] = param.ngram_threshold - ngram_window_sizes[idx] = param.ngram_window_size if param.ngram_window_size > 0 else 1 << 62 + repetition_ngram_sizes[idx] = param.repetition_ngram_size + repetition_ngram_thresholds[idx] = param.repetition_ngram_threshold + repetition_ngram_window_sizes[ + idx] = param.repetition_ngram_window_size if param.repetition_ngram_window_size > 0 else 1 << 62 def __get_topp(top_p): """Get topp.""" @@ -188,23 +189,24 @@ def __get_bad_words(bad_words): 'seq_id': seq.seq_id, } for seq in seqs] - # ngram - max_ngram_size = max(ngram_sizes) - max_ngram_window_size = max(ngram_window_sizes) - if max_ngram_size == 0: - ngram_sizes = None - ngram_thresholds = None - ngram_window_sizes = None + # repetition ngram + max_repetition_ngram_size = max(repetition_ngram_sizes) + max_repetition_ngram_window_size = max(repetition_ngram_window_sizes) + if max_repetition_ngram_size == 0: + repetition_ngram_sizes = None + repetition_ngram_thresholds = None + repetition_ngram_window_sizes = None else: - ngram_sizes = torch.tensor(ngram_sizes) - ngram_thresholds = torch.tensor(ngram_thresholds) - ngram_window_sizes = torch.tensor(ngram_window_sizes) - ngram_same_n = (ngram_sizes == max_ngram_size).all().item() - if ngram_same_n: - ngram_sizes = None - ngram_same_window_size = (ngram_window_sizes == max_ngram_window_size).all().item() - if ngram_same_window_size: - ngram_window_sizes = None + repetition_ngram_sizes = torch.tensor(repetition_ngram_sizes) + repetition_ngram_thresholds = torch.tensor(repetition_ngram_thresholds) + repetition_ngram_window_sizes = torch.tensor(repetition_ngram_window_sizes) + repetition_ngram_same_n = (repetition_ngram_sizes == max_repetition_ngram_size).all().item() + if repetition_ngram_same_n: + repetition_ngram_sizes = None + repetition_ngram_same_window_size = ( + repetition_ngram_window_sizes == max_repetition_ngram_window_size).all().item() + if repetition_ngram_same_window_size: + repetition_ngram_window_sizes = None sampling_input = SamplingInputs( temperature=temperature, @@ -226,11 +228,11 @@ def __get_bad_words(bad_words): batch_size=batch_size, session_ctx=session_ctx, 
session_to_cleanup=session_to_cleanup, - ngram_size=ngram_sizes, - ngram_threshold=ngram_thresholds, - ngram_window_size=ngram_window_sizes, - max_ngram_size=max_ngram_size, - max_ngram_window_size=max_ngram_window_size, + repetition_ngram_size=repetition_ngram_sizes, + repetition_ngram_threshold=repetition_ngram_thresholds, + repetition_ngram_window_size=repetition_ngram_window_sizes, + max_repetition_ngram_size=max_repetition_ngram_size, + max_repetition_ngram_window_size=max_repetition_ngram_window_size, ) pad_token_id = self.pad_token_id diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index aa9ef347f0..49a7a14e4d 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -127,7 +127,7 @@ def test_filter_minp_sorted(): def test_filter_ngram(): - from lmdeploy.pytorch.engine.logits_process import _filter_ngram_ + from lmdeploy.pytorch.engine.logits_process import _filter_repetition_ngram_ vocab_size = 100 def _get_emtas(n, window_size): @@ -153,7 +153,7 @@ def _get_emtas(n, window_size): batch_size, max_n, max_window_size, n = _get_emtas(n, window_size) scores = torch.rand(batch_size, vocab_size) stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) - _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) assert not scores[1].isinf().any().item() assert scores[0].isinf().sum().item() == vocab_size - 1 @@ -173,7 +173,7 @@ def _get_emtas(n, window_size): scores = torch.rand(batch_size, vocab_size) stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) - _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) assert not scores[1].isinf().any().item() assert scores[0].isinf().sum().item() == vocab_size - 1 @@ -189,6 +189,6 @@ def _get_emtas(n, window_size): scores = torch.rand(batch_size, vocab_size) stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64) - _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) + _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, window_size, max_window_size) assert not scores[1].isinf().any().item() assert scores[0].isinf().sum().item() == vocab_size - 1
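
---

Usage note (not part of the patches above): a minimal sketch of how the new repetition-ngram options introduced by this series could be enabled through GenerationConfig. Field names follow the PATCH 3/3 renaming (repetition_ngram_size / repetition_ngram_threshold / repetition_ngram_window_size); the model path and prompt are placeholders, and the processor relies on the chat template providing stop words, as noted in the _filter_repetition_ngram_ docstring.

    # sketch only: assumes the patched lmdeploy build with repetition_ngram_* fields
    from lmdeploy import pipeline, GenerationConfig

    pipe = pipeline('internlm/internlm2_5-7b-chat')  # placeholder model

    gen_config = GenerationConfig(
        max_new_tokens=1024,
        # stop generation once the trailing 8-gram has already occurred at least
        # 4 times within the last 1024 generated tokens
        repetition_ngram_size=8,
        repetition_ngram_threshold=4,
        repetition_ngram_window_size=1024,
    )

    resp = pipe(['Describe the picture in detail.'], gen_config=gen_config)
    print(resp[0].text)

When the trailing n-gram recurs at least `repetition_ngram_threshold` times inside the window, the fused logits processor masks every logit to -inf except the first stop word, so the sequence terminates instead of looping.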