support repetition ngram logits processor #4288

grimoire wants to merge 3 commits into InternLM:main from
Conversation
Pull request overview
Adds support for an n-gram-based logits processor (intended to force generation of a stop token once repeated n-grams exceed a threshold), wiring new ngram_size / ngram_threshold parameters through sampling inputs and adding a unit test.
Changes:
- Add n-gram matching + _filter_ngram_ into FusedLogitsProcessor.
- Plumb ngram_size/ngram_threshold through GenerationConfig → SamplingParam → SamplingInputs, including new generated-token history gathering.
- Add a unit test for _filter_ngram_.
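
For orientation, a minimal usage sketch of the new parameters (the field names follow the descriptions above; the model name, values, and defaults are illustrative placeholders, not taken from this PR):

```python
from lmdeploy import GenerationConfig, pipeline

# Hypothetical usage sketch: ngram_size / ngram_threshold are the new fields
# this PR adds to GenerationConfig; model path and values are placeholders.
pipe = pipeline('internlm/internlm2_5-7b-chat')
gen_config = GenerationConfig(
    max_new_tokens=512,
    stop_words=['<|im_end|>'],  # a stop word must exist for the forced stop to take effect
    ngram_size=8,               # length of the n-gram matched against the generated history
    ngram_threshold=4,          # repetitions of the same n-gram before a stop token is forced
)
print(pipe(['Tell me a story.'], gen_config=gen_config))
```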
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/pytorch/engine/test_logits_process.py | Adds coverage for the new n-gram filtering behavior. |
| lmdeploy/pytorch/engine/logits_process.py | Implements n-gram matching/filtering and integrates it into fused logits processing. |
| lmdeploy/pytorch/strategies/ar/sampling.py | Gathers per-request n-gram params and generated-token history for GPU-side processing. |
| lmdeploy/pytorch/strategies/dllm/sampling.py | Repeats new sampling attributes across DLLM blocks and expands generated-id history. |
| lmdeploy/pytorch/messages.py | Adds ngram_size / ngram_threshold to SamplingParam to carry runtime settings. |
| lmdeploy/messages.py | Adds ngram_size / ngram_threshold to user-facing GenerationConfig. |
```python
def test_filter_ngram():
    from lmdeploy.pytorch.engine.logits_process import _filter_ngram_
```
PR description appears to still be the default template (Motivation/Modification/etc. are not filled in). Please add a brief motivation and summarize the intended behavior of the new n-gram logits processor (including how ngram_size/ngram_threshold interact with stop_words) to make review and downstream usage clearer.
```python
):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # use first stop words
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    stop_words = stop_words[:, 0]
    # fill all scores -inf
    scores.masked_fill_(found[:, None], -float('inf'))
    # set stop words to 0
    stop_scores = scores.gather(1, stop_words[:, None])
    stop_scores.masked_fill_(found[:, None], 0)
    scores.scatter_(1, stop_words[:, None], stop_scores)
```
_filter_ngram_ always uses stop_words[:, 0] without consulting stop_mask. When a sequence has no stop words, __get_bad_words pads its row with 0s and a false mask, so this function still treats token 0 as a stop word and can force generation of token 0 whenever found is true. Either pass and use stop_mask (select the first valid stop word per batch and skip batches that have none) or precompute a per-sequence stop token id when enabling n-gram stopping.
Suggested change:

```diff
+    stop_mask: torch.Tensor | None = None,
 ):
     """Filter ngram."""
     if stop_words is None or stop_words.numel() == 0:
         return scores
-    # use first stop words
-    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
-    stop_words = stop_words[:, 0]
-    # fill all scores -inf
-    scores.masked_fill_(found[:, None], -float('inf'))
-    # set stop words to 0
-    stop_scores = scores.gather(1, stop_words[:, None])
-    stop_scores.masked_fill_(found[:, None], 0)
-    scores.scatter_(1, stop_words[:, None], stop_scores)
+    # determine which sequences have ngram matches
+    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
+    batch_size = scores.size(0)
+    if stop_mask is not None:
+        # has_stop indicates which batch elements have at least one valid stop word
+        if stop_mask.numel() == 0:
+            return scores
+        has_stop = stop_mask.any(dim=1)
+        if not has_stop.any():
+            # no sequences have valid stop words, nothing to do
+            return scores
+        # compute index of first valid stop word per batch element
+        num_stops = stop_mask.size(1)
+        idxs = torch.arange(
+            num_stops, device=stop_mask.device, dtype=torch.long
+        ).unsqueeze(0).expand_as(stop_mask)
+        idxs = idxs.masked_fill(~stop_mask, num_stops)
+        first_idxs = idxs.argmin(dim=1)
+        batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long)
+        stop_tokens = stop_words[batch_indices, first_idxs]
+        # only apply forcing where both an ngram is found and a valid stop word exists
+        valid_found = found & has_stop
+        if not valid_found.any():
+            return scores
+        scores.masked_fill_(valid_found[:, None], -float('inf'))
+        stop_tokens_exp = stop_tokens[:, None]
+        stop_scores = scores.gather(1, stop_tokens_exp)
+        stop_scores.masked_fill_(valid_found[:, None], 0)
+        scores.scatter_(1, stop_tokens_exp, stop_scores)
+    else:
+        # fallback: use the first stop word in each row, as originally implemented
+        stop_tokens = stop_words[:, 0]
+        # fill all scores -inf where an ngram is found
+        scores.masked_fill_(found[:, None], -float('inf'))
+        # set stop word scores to 0
+        stop_tokens_exp = stop_tokens[:, None]
+        stop_scores = scores.gather(1, stop_tokens_exp)
+        stop_scores.masked_fill_(found[:, None], 0)
+        scores.scatter_(1, stop_tokens_exp, stop_scores)
```
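
The forcing pattern itself (push every logit of a flagged row to -inf, then restore only the stop token) can be sanity-checked in isolation. A minimal sketch with hand-made tensors; the found mask stands in for the output of the ngram helper:

```python
import torch

# Fake inputs: batch of 2, vocab of 5. Pretend sequence 0 triggered the
# n-gram repetition check and sequence 1 did not.
scores = torch.randn(2, 5)
found = torch.tensor([True, False])
stop_words = torch.tensor([3, 3])  # first stop token id per sequence

# Same forcing pattern as _filter_ngram_: fill every logit of a "found"
# row with -inf, then scatter the stop token's logit back to 0 so it is
# the only token that can be sampled.
scores.masked_fill_(found[:, None], -float('inf'))
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)

assert scores[0].argmax() == 3          # sequence 0 is forced to the stop token
assert torch.isfinite(scores[1]).all()  # sequence 1 is untouched
```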
```python
if self.generated_ids_cpu is not None:
    self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
for f in fields(self):
    k = f.name
    v = getattr(self, k)
```
SamplingInputs.to_device() mutates self by setting self.generated_ids before building the output instance, and it forces a full .copy() of generated_ids_cpu. This introduces side effects and an extra memory copy on every device transfer. Prefer keeping to_device pure (compute a local tensor from generated_ids_cpu and put it into out_dict) and avoid the unconditional copy unless it’s required for contiguity.
Suggested change:

```diff
+# Build a local `generated_ids` tensor from `generated_ids_cpu` without
+# mutating `self` and without forcing an unnecessary copy.
 if self.generated_ids_cpu is not None:
-    self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
+    gen_np = self.generated_ids_cpu
+    # Ensure the array is suitable for torch.from_numpy without an
+    # extra copy unless required for contiguity or writability.
+    if (not gen_np.flags['C_CONTIGUOUS']) or (not gen_np.flags['WRITEABLE']):
+        gen_np = np.ascontiguousarray(gen_np)
+    generated_ids = torch.from_numpy(gen_np)
+else:
+    generated_ids = self.generated_ids
 for f in fields(self):
     k = f.name
-    v = getattr(self, k)
+    if k == 'generated_ids':
+        v = generated_ids
+    else:
+        v = getattr(self, k)
```
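
As background for the copy concern, a standalone illustration of torch.from_numpy semantics (plain NumPy/PyTorch, not lmdeploy code): the tensor shares memory with the source array, so the host-side copy only matters when the CPU buffer may be mutated or reused before the device transfer.

```python
import numpy as np
import torch

# torch.from_numpy returns a tensor that shares memory with the array ...
arr = np.array([1, 2, 3], dtype=np.int64)
t_shared = torch.from_numpy(arr)
arr[0] = 99
assert t_shared[0] == 99          # the tensor sees the in-place update

# ... while copying first decouples them, at the cost of an extra allocation.
arr2 = np.array([1, 2, 3], dtype=np.int64)
t_copied = torch.from_numpy(arr2.copy())
arr2[0] = 99
assert t_copied[0] == 1           # the copy is unaffected

# A later .to(device) call copies into device memory anyway, so the host-side
# copy is only needed if the CPU buffer is mutated before the transfer.
```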
Repetition stopping is implemented here as a logits processor. If we turned it into an engine-level feature, the implementation would be much easier and performance would be better.