
Conversation

@jzakrzew (Contributor) commented Dec 12, 2025

TLDR

  • nvidia/llama-nemotron-embed-1b-v2
  • nvidia/llama-nemotron-rerank-1b-v2
    • examples/pooling/score/offline_using_template.py
    • examples/pooling/score/online_using_template.py

Purpose

This PR allows users to specify a custom prompt template for score/rerank models by providing the --chat-template CLI argument or setting chat_template in tokenizer_config.json.

Motivation: The current mechanism for setting custom score templates (SupportsScoreTemplate) is architecture-specific: it requires modifying the model class itself. This change decouples the prompt template from the model class, enabling support for any model that requires a custom score template without model-specific code changes.

Immediate use case: the nvidia/llama-nemotron-rerank-1b-v2 model, which uses the Llama architecture but requires a custom score template, can now run correctly on vLLM with minor config.json modifications.
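
Alternatively, the template can be attached to the tokenizer config itself instead of using the CLI flag; a minimal sketch with the transformers tokenizer API (the output path is illustrative, and the template string is the one used in the serving example below):

from transformers import AutoTokenizer

# Attach the score template as the tokenizer's chat template.
tokenizer = AutoTokenizer.from_pretrained("nvidia/llama-nemotron-rerank-1b-v2")
tokenizer.chat_template = (
    'question:{{ messages[0]["query"] }} \n \n passage:{{ messages[1]["query"] }}'
)

# save_pretrained() persists the template (in tokenizer_config.json or a separate
# chat_template.jinja file, depending on the transformers version), so the saved
# directory can then be served without --chat-template.
tokenizer.save_pretrained("./llama-nemotron-rerank-1b-v2-with-score-template")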

Running nvidia/llama-nemotron-rerank-1b-v2 with examples provided in the model's README, using FP32 precision:

Running without the custom template:

vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --dtype float32 --port 8000 --pooler-config '{"pooling_type": "MEAN"}'

Running with a custom template:

echo -ne 'question:{{ messages[0]["query"] }} \n \n passage:{{ messages[1]["query"] }}' > score_template.jinja
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --dtype float32 --port 8000 --pooler-config '{"pooling_type": "MEAN"}' --chat-template score_template.jinja
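
For reference, the scores below can be reproduced with a plain HTTP request once the server is up; a rough client-side sketch, assuming vLLM's score API (POST /v1/score with text_1/text_2 fields) — adjust the route or fields if your version's endpoint or response schema differs:

import requests

# Hypothetical client sketch against the server started above.
response = requests.post(
    "http://localhost:8000/v1/score",
    json={
        "model": "nvidia/llama-nemotron-rerank-1b-v2",
        "text_1": "how much protein should a female eat?",
        "text_2": [
            "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams...",
            "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top o...",
        ],
    },
)
for item in response.json()["data"]:
    print(item["index"], item["score"])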

Without a custom template ("Score" is the reference value from the model's README example, "vllm Score" is the value returned by vLLM):

  Query: how much protein should a female eat?
  Document: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams...
  Score: 20.7482
  vllm Score: 6.0918
  Query: how much protein should a female eat?
  Document: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top o...
  Score: -23.0923
  vllm Score: -7.7225
  Query: how much protein should a female eat?
  Document: Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the su...
  Score: -0.3436
  vllm Score: -1.1680

With a custom template:

  Query: how much protein should a female eat?
  Document: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams...
  Score: 20.7482
  vllm Score: 20.7482
  Query: how much protein should a female eat?
  Document: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top o...
  Score: -23.0923
  vllm Score: -23.0923
  Query: how much protein should a female eat?
  Document: Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the su...
  Score: -0.3436
  vllm Score: -0.3436

Test Plan

tests/entrypoints/pooling/score/test_utils.py
tests/models/language/pooling_mteb_test/test_nemotron.py

Test Result

pass

TODO

  1. Template-aware prompt truncation, to avoid cutting off important instructions.

#30550 (comment)

Since vLLM does not allow truncation by default, this should not be a problem.

#30550 (comment)

  2. Steps to standardize the template scheme and inputs for reranking

#30550 (comment)
#30550 (comment)
#30550 (comment)
#30550 (comment)
#30550 (comment)

  3. The score template should be explicitly specified in, for example, sbert_config.json.

Attempting to use tokenizer_config.json templates would most likely break these models, as often they just inherit the template from the original LLM.
#30550 (comment)

  4. Templates for embedding models

#30550 (comment)

  5. It's a bit confusing that chat_template and score_template are currently mixed in the vLLM code.

#30550 (comment)

  6. Add more models to testing and examples:
  • bge_reranker_v2_gemma
  • mxbai_rerank
  • qwen3_reranker
  • nemotron_rerank √

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

mergify bot commented Dec 12, 2025

Documentation preview: https://vllm--30550.org.readthedocs.build/en/30550/

@mergify mergify bot added the documentation and frontend labels Dec 12, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request introduces a --score-template CLI argument, allowing users to provide a custom Jinja2 template for score/rerank models. This is a valuable feature for decoupling prompt formatting from model-specific code. The implementation is mostly solid, with new CLI arguments, documentation, and tests. However, I've identified a high-severity issue related to code reuse that impacts maintainability and user experience. Specifically, chat-template-specific utilities are being reused for score templates, which can lead to confusing error messages. I've suggested a refactoring to create more generic template-handling functions.

mergify bot commented Dec 12, 2025

Hi @jzakrzew, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@jzakrzew jzakrzew force-pushed the score-template-cli-arg branch from 1f64fa9 to 9258e17 Compare December 12, 2025 12:22
@noooop noooop self-assigned this Dec 12, 2025
@noooop (Collaborator) commented Dec 12, 2025

@Samoed (Contributor) commented Dec 12, 2025

Hi! We have a separate class for handling instruction-based models, with an example for Qwen3. However, this approach is a bit naive, since there's no standard way of doing this yet. Maybe @tomaarsen has some thoughts on standardizing prompt templates for cross-encoders.

It has always been unclear to me why there are no models that define their prompts in Jinja templates that could be used more automatically.

@noooop (Collaborator) commented Dec 12, 2025

hello @tomaarsen

Please take a look at this thread.

@tomaarsen commented Dec 12, 2025

Thanks for pinging me @noooop & @Samoed.
In Sentence Transformers, I've been wanting to support the modern Causal-style rerankers, as this modern format is becoming a lot more prevalent. For my codebase, there were always two main concerns:

  1. The CrossEncoder (a.k.a. reranker) class relies on the transformers AutoModelForSequenceClassification. Models loaded with this factory are actually rather similar to AutoModelForCausalLM models, but their classifier head predicts a single class, akin to regression, instead of predicting scores for all tokens in the vocabulary and then either 1) taking the score for "yes" or "1", or 2) taking the difference of the scores for "yes" and "no", or "1" and "0" (which is also why you can make https://huggingface.co/tomaarsen/Qwen3-Reranker-0.6B-seq-cls out of https://huggingface.co/Qwen/Qwen3-Reranker-0.6B).
  2. These models rely on a template, rather than the traditional text-pair input that was possible with AutoModelForSequenceClassification.

I think working on this support is so important that I'm working on a major refactor of the Sentence Transformers codebase, notably around the CrossEncoder, to help modularize it. This allows me to very easily support models that don't rely on AutoModelForSequenceClassification via built-in or custom modules that are executed sequentially. That solves concern 1 that I had, which is rather unrelated to this vLLM issue, but might give some more context to my recent work on the rerankers.
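
A rough, purely illustrative sketch of the yes/no logit trick from point 1 (not the exact Sentence Transformers or vLLM code path; the prompt string is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B")

prompt = "...the templated query/document pair goes here..."  # placeholder
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    # Next-token logits at the final position of the prompt.
    logits = model(**inputs).logits[0, -1]

yes_id = tokenizer.convert_tokens_to_ids("yes")
no_id = tokenizer.convert_tokens_to_ids("no")
# Score as the yes/no logit difference; converting such a model to a
# one-class sequence-classification head mimics this kind of reduction.
score = (logits[yes_id] - logits[no_id]).item()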

For concern 2, a very simple solution is to rely on the transformers chat_template, where the two texts to rerank are passed as two "messages" in a single chat. The format could then look something akin to:

<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: Given a web search query, retrieve relevant passages that answer the query
<Query>: {{ messages[0]["content"] }}
<Document>: {{ messages[1]["content"] }}<|im_end|>
<|im_start|>assistant
<think>\n\n</think>\n\n\n

with messages such as:

messages = [
    {
        "role": "query",
        "content": "What is the capital of France?",
    },
    {
        "role": "document",
        "content": "Paris is the capital of France.",
    },
]
tokenized = tokenizer.apply_chat_template(messages, ...)

(This one matches the format required for https://huggingface.co/Qwen/Qwen3-Reranker-0.6B, I believe)

An additional benefit here is that we can take advantage of a "system prompt" of sorts as an instruction/prompt for the reranker. In the above chat template, I hardcoded Given a web search query, retrieve relevant passages that answer the query, but that is the instruction that users could modify to match their use cases.

Some of the obvious advantages are that the transformers chat_template is pretty commonly used and known, that most LLM-based models would already have a chat template copied from the base model that can very reasonably be updated (the chat template for generation is useless on a reranker, after all), etc.

But, my primary hesitation at the current stage is that apply_chat_template doesn't allow for any template-aware truncation. I suppose this is more often an issue on the LLM side, but the current template-unaware truncation will strip off crucial template tokens (e.g. <think>\n\n</think>\n\n\n), at which point the reranker scores become absolutely useless. This wouldn't be much of an issue if reranker models didn't impose any sequence length limits, but most do (e.g. 32k for Qwen3, 1024 for BAAI, 8k/32k for Mixedbread, etc.).

This tempts me to write a more "manual" templating implementation, where I can apply the truncation to the second input (often the 'document' in a query-document setting). A recurring issue I've found with my initial attempts is that you can't fully tokenize the template separately from the actual texts, as many template tokens will want to "merge" with actual text tokens (e.g. ": Hello" becoming a single token in "... Query: {{ messages[0]["query"] }} ...").
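
A rough sketch of such a "manual" approach, truncating only the document tokens (the helper is hypothetical and illustrative only; it simplifies the template to a prefix and still suffers from the boundary-merging issue just described):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B")

def render_with_truncated_document(query: str, document: str, max_length: int) -> str:
    # Everything except the document is treated as fixed overhead.
    prefix = f"Query: {query}\nDocument: "
    overhead = len(tokenizer(prefix)["input_ids"])
    budget = max(max_length - overhead, 0)

    # Truncate only the document tokens, then splice the decoded text back in.
    # Caveat: re-tokenizing prefix + decoded document can merge tokens at the
    # boundary, which is exactly the issue described above.
    doc_ids = tokenizer(document, add_special_tokens=False)["input_ids"][:budget]
    return prefix + tokenizer.decode(doc_ids)

text = render_with_truncated_document(
    "What is the capital of France?", "Paris is the capital of France. " * 1000, 1024
)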

Those are my thoughts for now. @noooop , where do you stand regarding:

  1. using transformers' chat_template as the source of truth for the templating?
  2. whether truncation limits are a problem, or whether you think we can ignore it? Your experience with LLMs that more frequently rely on chat templates might be useful there.
  • Tom Aarsen

@Samoed (Contributor) commented Dec 12, 2025

> using transformers' chat_template as the source of truth for the templating?

Generally I think it should be like this, but for now there are no such models. Even Qwen just inherits the template from the original LLM.

"role": "document",

Generally I think this is a good approach, but I'm afraid some libraries won't allow custom role names. You could probably use the name field for this, but that is a bit unintuitive too.

@noooop (Collaborator) commented Dec 13, 2025

@DarkLight1337 @hmellor

What are your thoughts?


> Those are my thoughts for now. @noooop , where do you stand regarding:
>
>   1. using transformers' chat_template as the source of truth for the templating?
>   2. whether truncation limits are a problem, or whether you think we can ignore it? Your experience with LLMs that more frequently rely on chat templates might be useful there.

This is also my concern, which is why I'd like to seek your advice. After all, Sentence Transformers and MTEB are upstream of vLLM, and vLLM only supports a very limited number of CrossEncoder models.


Just to mention

Currently, vLLM does not perform truncation by default, following the OpenAI API behavior for /v1/embeddings.

#24235 (comment)

@DarkLight1337 (Member) commented

> using transformers' chat_template as the source of truth for the templating?

This makes sense. If the HF Hub repo has an incorrect chat template, you can override it in vLLM by passing --chat-template.

> whether truncation limits are a problem, or whether you think we can ignore it? Your experience with LLMs that more frequently rely on chat templates might be useful there.

As @noooop mentioned, since we don't allow truncation by default, it should not be a problem.

@jzakrzew (Contributor, Author) commented

OK, I'll modify the PR so that it uses --chat-template and apply_hf_chat_template from vllm.entrypoints.chat_utils instead of calling jinja2 directly.

@tomaarsen commented Dec 15, 2025

I think that's the right move. I'll also move to chat_template. We should aim for a format that works conveniently, e.g.:

  1. Using "query", "document", and perhaps "prompt" for the roles of the messages perhaps? For reference, in Sentence Transformers/MTEB, models are called with nested lists like [["What is the capital of China?", "The capital of China is Beijing."], ...], so I have to know how to convert that to the messages structure.
  2. Or assume that message[0]["content"] is the query and message[1]["content"] is the document? This becomes very tricky if you want flexibility in the prompt/instruction as will be possible with Sentence Transformers (model.predict([["What is the capital of China?", "The capital of China is Beijing."], ...], prompt_name="default")). Index 0 for prompts, 1 for query, and 2 for the document is already stronger, but still too arbitrary in my opinion.
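
A hypothetical sketch of that conversion, assuming the "prompt"/"query"/"document" role convention from point 1 (the helper name and structure are illustrative, not an existing API):

def pairs_to_messages(pairs, instruction=None):
    # Convert Sentence Transformers-style [query, document] pairs into
    # chat-template messages with "prompt"/"query"/"document" roles.
    conversations = []
    for query, document in pairs:
        messages = []
        if instruction is not None:
            messages.append({"role": "prompt", "content": instruction})
        messages.append({"role": "query", "content": query})
        messages.append({"role": "document", "content": document})
        conversations.append(messages)
    return conversations

conversations = pairs_to_messages(
    [["What is the capital of China?", "The capital of China is Beijing."]],
    instruction="Given a web search query, retrieve relevant passages that answer the query",
)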

Does vLLM support prompts/instructions?

Edit: As mentioned by @Samoed, the above approaches are not very robust to listwise rerankers, which take multiple documents.

  • Tom Aarsen

@Samoed (Contributor) commented Dec 15, 2025

Also, support for chunked content could be added, e.g.:

{
  "role": "user",
  "content": [
    {
      "type": "query",
      "text": {  # query/text
        "value": "How does AI work? Explain it in simple terms.",
        "annotations": []
      }
    },
    {
      "type": "document",
      "text": {  # document/text
        "value": "AI works like ..."
      }
    }
  ]
}

But I'm not sure whether it's possible to handle this from Jinja, or whether it would work with other libraries.
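
For what it's worth, iterating over such typed content parts is expressible in plain Jinja; a rough illustrative sketch (whether other libraries accept this message structure is a separate question):

from jinja2 import Template

# Illustrative only: walk the typed content parts from the structure above.
template = Template(
    "{% for part in messages[0]['content'] %}"
    "{{ part['type'] }}: {{ part['text']['value'] }}\n"
    "{% endfor %}"
)

rendered = template.render(messages=[{
    "role": "user",
    "content": [
        {"type": "query", "text": {"value": "How does AI work? Explain it in simple terms."}},
        {"type": "document", "text": {"value": "AI works like ..."}},
    ],
}])
print(rendered)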

> For reference, in Sentence Transformers/MTEB, models are called with nested lists like [["What is the capital of China?", "The capital of China is Beijing."], ...],

By the way, this won't work for listwise reranking, and support for it should be added (I created an issue in MTEB: embeddings-benchmark/mteb#3744).

@Samoed (Contributor) commented Dec 15, 2025

I think we need to ask someone from the tokenizers/chat template maintainers for a better way to handle this.

@DarkLight1337 (Member) commented

cc @hmellor

@jzakrzew jzakrzew force-pushed the score-template-cli-arg branch from 9258e17 to 40808e9 Compare December 15, 2025 15:11
@jzakrzew jzakrzew changed the title [Frontend] Support passing custom score template as a CLI argument to vllm serve [Frontend] Support using chat template as custom score template for reranking models Dec 15, 2025
mergify bot commented Dec 15, 2025

Hi @jzakrzew, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@tomaarsen commented Dec 19, 2025

Yes. For context, transformers used to store the chat_template in its tokenizer_config.json. You can still see this on older models, e.g. https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/tokenizer_config.json#L2053
However, as this is very difficult to read properly, transformers started saving these in a file called chat_template.jinja instead (source). When you load a tokenizer with transformers, it'll pull from this file automatically, and you can use:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", device_map="auto", dtype=torch.bfloat16)

messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
print(tokenizer.decode(tokenized_chat[0]))

which prints:
<|system|>
You are a friendly chatbot who always responds in the style of a pirate</s>
<|user|>
How many helicopters can a human eat in one sitting?</s>
<|assistant|>

Docs: https://huggingface.co/docs/transformers/main/en/chat_templating#using-applychattemplate

Sentence Transformers relies on transformers, and will soon use this transformers.apply_chat_template to convert messages into the correct format for modern rerankers.

  • Tom Aarsen

@DarkLight1337 (Member) left a comment

LGTM now, thanks for the detailed discussion!

@vllm-project vllm-project deleted a comment from mergify bot Dec 22, 2025
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 22, 2025
@noooop noooop enabled auto-merge (squash) December 23, 2025 08:23
auto-merge was automatically disabled December 23, 2025 09:29

Head branch was pushed to by a user without write access

@jzakrzew (Contributor, Author) commented

@noooop Sorry, I just wanted to clarify one comment; I did not notice you had enabled auto-merge.

@noooop noooop enabled auto-merge (squash) December 23, 2025 09:47
@noooop noooop merged commit 23daef5 into vllm-project:main Dec 23, 2025
58 checks passed
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…eranking models (vllm-project#30550)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Dec 30, 2025
…eranking models (vllm-project#30550)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…eranking models (vllm-project#30550)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

  • documentation: Improvements or additions to documentation
  • frontend
  • llama: Related to Llama models
  • new-model: Requests to new models
  • ready: ONLY add when PR is ready to merge/full CI is needed


7 participants