1 change: 0 additions & 1 deletion .github/workflows/basic-tests-linux-uv.yml
@@ -79,5 +79,4 @@ jobs:
shell: bash
run: |
  source .venv/bin/activate
  uv pip install transformers
  pytest pkg/llms_from_scratch/tests/
12 changes: 12 additions & 0 deletions ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
@@ -177,6 +177,10 @@ class RoPEConfig:
    max_position_embeddings: int = 8192
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads
    rope_parameters = {"rope_type": "default", "rope_theta": theta_base}

    def standardize_rope_params(self):
        return

config = RoPEConfig()
rot_emb = LlamaRotaryEmbedding(config=config)
@@ -242,6 +246,10 @@ class RoPEConfig:
    max_position_embeddings: int = 8192
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads
    rope_parameters = {"rope_type": "default", "rope_theta": theta_base}

    def standardize_rope_params(self):
        return

config = RoPEConfig()
rot_emb = LlamaRotaryEmbedding(config=config)
@@ -320,6 +328,10 @@ class RoPEConfig:
    max_position_embeddings: int = 8192
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads
    rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}

    def standardize_rope_params(self):
        return

config = RoPEConfig()

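For context, a minimal, self-contained sketch of how a patched RoPEConfig like the one above is consumed. It assumes Transformers >= 5.0 reads `rope_parameters` from the config object and that `LlamaRotaryEmbedding` can be called as `rot_emb(x, position_ids)` to return per-position cos/sin tensors; `head_dim`, `num_heads`, and `theta_base` below are placeholder values, not taken from the tests.

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, theta_base = 64, 8, 500_000.0  # placeholder values

class RoPEConfig:
    rope_type = "default"
    rope_theta = theta_base
    max_position_embeddings = 8192
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads
    # New-style RoPE parameters read by Transformers >= 5.0.
    rope_parameters = {"rope_type": "default", "rope_theta": theta_base}

    def standardize_rope_params(self):
        # rope_parameters is already in the standardized format, so this is a no-op.
        return

rot_emb = LlamaRotaryEmbedding(config=RoPEConfig())
x = torch.zeros(1, 4, num_heads, head_dim)   # dummy activations; only dtype/device matter
position_ids = torch.arange(4).unsqueeze(0)  # shape (batch, seq_len)
cos, sin = rot_emb(x, position_ids)          # assumed to yield cos/sin of shape (1, 4, head_dim)
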
21 changes: 20 additions & 1 deletion pkg/llms_from_scratch/tests/test_llama3.py
@@ -111,6 +111,25 @@ class RoPEConfig:
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads

    def __init__(self):
        # Transformers >=5.0.0 expects `rope_parameters` on the instance.
        self.rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}

    def standardize_rope_params(self):
        params = dict(getattr(self, "rope_parameters", {}) or {})
        if "rope_type" not in params:
            params["rope_type"] = getattr(self, "rope_type", "default")
        if "rope_theta" not in params:
            params["rope_theta"] = getattr(self, "rope_theta")
        # Handle older key name used in this repo.
        if (
            "original_max_position_embeddings" not in params
            and "original_context_length" in params
        ):
            params["original_max_position_embeddings"] = params["original_context_length"]
        self.rope_parameters = params
        return params

config = RoPEConfig()

rot_emb = LlamaRotaryEmbedding(config=config)
@@ -304,4 +323,4 @@ def test_llama3_base_equivalence_with_transformers():
ours_logits = ours(x)
theirs_logits = theirs(x).logits.to(ours_logits.dtype)

torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
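As a reference point, a small sketch of the key normalization that `standardize_rope_params` performs above: the repo's Llama 3 RoPE dict uses `original_context_length`, whereas Transformers >= 5.0 expects `original_max_position_embeddings`. The concrete values below are the usual Llama 3.1 scaling defaults and are assumptions for illustration, not taken from this diff.

hf_rope_params = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_context_length": 8192,  # repo-style key name
}

params = {**hf_rope_params, "rope_theta": 500_000.0}
if ("original_max_position_embeddings" not in params
        and "original_context_length" in params):
    params["original_max_position_embeddings"] = params["original_context_length"]

assert params["original_max_position_embeddings"] == 8192
assert params["rope_type"] == "llama3"
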
57 changes: 53 additions & 4 deletions pkg/llms_from_scratch/tests/test_qwen3.py
@@ -28,6 +28,7 @@
import shutil
import tempfile
import platform
from collections.abc import Mapping
import pytest
import torch
import torch.nn as nn
@@ -59,6 +60,36 @@ def extra_repr(self):
transformers_installed = importlib.util.find_spec("transformers") is not None


def _hf_ids(obj):
    """Normalize HF chat-template outputs across Transformers versions."""
    if isinstance(obj, Mapping):
        if "input_ids" in obj:
            obj = obj["input_ids"]
        elif "ids" in obj:
            obj = obj["ids"]
    elif hasattr(obj, "keys") and hasattr(obj, "__getitem__"):
        # Some HF containers behave like mappings but don't register as Mapping.
        try:
            if "input_ids" in obj:
                obj = obj["input_ids"]
            elif "ids" in obj:
                obj = obj["ids"]
        except Exception:
            pass
    if hasattr(obj, "input_ids"):
        obj = obj.input_ids
    if hasattr(obj, "ids"):
        obj = obj.ids
    if isinstance(obj, torch.Tensor):
        obj = obj.tolist()
    if isinstance(obj, tuple):
        obj = list(obj)
    # Some HF versions return a batched structure even for a single prompt.
    if isinstance(obj, list) and obj and isinstance(obj[0], list) and len(obj) == 1:
        obj = obj[0]
    return list(obj)
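
# A hand-written sanity sketch of the cases _hf_ids is meant to normalize; the inputs
# below are made-up examples (assumptions), not captured HF outputs.
assert _hf_ids({"input_ids": [[1, 2, 3]]}) == [1, 2, 3]  # dict-like output with a 1-element batch
assert _hf_ids(torch.tensor([1, 2, 3])) == [1, 2, 3]     # plain tensor of token ids
assert _hf_ids((1, 2, 3)) == [1, 2, 3]                   # tuple of token ids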


@pytest.fixture
def dummy_input():
    torch.manual_seed(123)
@@ -211,14 +242,28 @@ def test_rope(context_len):

# Generate reference RoPE via HF
class RoPEConfig:
    rope_type = "qwen3"
    # Transformers' RoPE init map does not include "qwen3".
    rope_type = "default"
    factor = 1.0
    dim: int = head_dim
    rope_theta = 1_000_000
    max_position_embeddings = context_len
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads

    def __init__(self):
        # Transformers >=5.0.0 expects `rope_parameters` on the instance.
        self.rope_parameters = {"rope_type": "default", "rope_theta": rope_theta, "factor": 1.0}

    def standardize_rope_params(self):
        params = dict(getattr(self, "rope_parameters", {}) or {})
        if "rope_type" not in params:
            params["rope_type"] = getattr(self, "rope_type", "default")
        if "rope_theta" not in params:
            params["rope_theta"] = getattr(self, "rope_theta")
        self.rope_parameters = params
        return params

config = RoPEConfig()

rot_emb = Qwen3RotaryEmbedding(config=config)
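
# Hedged aside on the rope_type choice above: recent Transformers releases resolve
# rope_type through an internal init map that has a "default" entry but no "qwen3" entry.
# The import path below is an assumption about where that map is exposed.
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
assert "default" in ROPE_INIT_FUNCTIONS and "qwen3" not in ROPE_INIT_FUNCTIONS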
@@ -495,20 +540,21 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):

# Our encode vs HF template
ours = qt.encode(prompt)
ref = hf_tok.apply_chat_template(
ref = _hf_ids(hf_tok.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=add_gen,
    enable_thinking=add_think,
)
))

if add_gen and not add_think:
    pass  # skip edge case as this is not something we use in practice
else:
    assert ours == ref, (repo_id, add_gen, add_think)

# Round-trip decode equality
assert qt.decode(ours) == hf_tok.decode(ref)
if not (add_gen and not add_think):
    assert qt.decode(ours) == hf_tok.decode(ref)

# EOS/PAD parity
assert qt.eos_token_id == hf_tok.eos_token_id
@@ -547,6 +593,7 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
messages, tokenize=True,
add_generation_prompt=add_gen, enable_thinking=add_think
)
ref_ids = _hf_ids(ref_ids)
ref_text = hf_tok.apply_chat_template(
messages, tokenize=False,
add_generation_prompt=add_gen, enable_thinking=add_think
@@ -611,6 +658,7 @@ def test_tokenizer_equivalence():
        add_generation_prompt=states[0],
        enable_thinking=states[1],
    )
    input_token_ids_ref = _hf_ids(input_token_ids_ref)
else:
    input_token_ids_ref = input_token_ids

@@ -665,6 +713,7 @@ def test_multiturn_prefix_stability(repo_id, tok_file, add_gen, add_think):
running, tokenize=True,
add_generation_prompt=add_gen, enable_thinking=add_think
)
ref_ids = _hf_ids(ref_ids)
ref_text = hf_tok.apply_chat_template(
running, tokenize=False,
add_generation_prompt=add_gen, enable_thinking=add_think
4 changes: 1 addition & 3 deletions pyproject.toml
@@ -14,12 +14,10 @@ dependencies = [
"torch>=2.2.2; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version <= '3.12'",
"torch>=2.2.2; sys_platform == 'linux' and python_version <= '3.12'",
"torch>=2.2.2; sys_platform == 'win32' and python_version <= '3.12'",

"tensorflow>=2.16.2; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"tensorflow>=2.18.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
"tensorflow>=2.18.0; sys_platform == 'linux'",
"tensorflow>=2.18.0; sys_platform == 'win32'",

"jupyterlab>=4.0",
"tiktoken>=0.5.1",
"matplotlib>=3.7.1",
@@ -53,7 +51,7 @@ bonus = [
"sentencepiece>=0.1.99",
"thop",
"tokenizers>=0.21.1",
"transformers>=4.33.2",
"transformers>=5.0.0",
"tqdm>=4.65.0",
]
