
Commit 9c4be47

Optional weight tying for Qwen3 and Llama3.2 pretraining (#949)
* optional weight tying for Qwen3 and Llama3.2
* typo
1 parent e0dbec3 commit 9c4be47

File tree

7 files changed: +17 additions, -9 deletions


ch05/07_gpt_to_llama/standalone-llama32.ipynb

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@
 " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Reusable utilities\n",
 " cos, sin = compute_rope_params(\n",
 " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
 " theta_base=cfg[\"rope_base\"],\n",

ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

Lines changed: 1 addition & 1 deletion
@@ -432,7 +432,7 @@
 " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Reusable utilities\n",
 " if cfg[\"head_dim\"] is None:\n",
 " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
 " else:\n",

ch05/11_qwen3/standalone-qwen3-moe.ipynb

Lines changed: 1 addition & 1 deletion
@@ -422,7 +422,7 @@
 " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Reusable utilities\n",
 " if cfg[\"head_dim\"] is None:\n",
 " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
 " else:\n",

ch05/11_qwen3/standalone-qwen3.ipynb

Lines changed: 1 addition & 1 deletion
@@ -388,7 +388,7 @@
 " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Reusable utilities\n",
 " if cfg[\"head_dim\"] is None:\n",
 " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
 " else:\n",

ch05/14_ch05_with_other_llms/ch05-llama32.ipynb

Lines changed: 6 additions & 2 deletions
@@ -113,7 +113,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "86000d74-624a-48f0-86da-f41926cb9e04",
 "metadata": {
 "colab": {
@@ -329,7 +329,11 @@
 " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Uncomment the following code to tie weights\n",
+" # self.out_head.weight = self.tok_emb.weight\n",
+" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
+"\n",
+" # Reusable utilities\n",
 " cos, sin = compute_rope_params(\n",
 " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
 " theta_base=cfg[\"rope_base\"],\n",

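Note: the weight-tying snippet added above (and to ch05-qwen3.ipynb below) points the output head's weight at the token-embedding matrix, so pretraining learns a single vocab_size x emb_dim matrix instead of two, and the subsequent torch.nn.init.normal_ call re-initializes that shared matrix for both layers at once. A minimal, self-contained sketch of the idea, where TinyModel and its sizes are illustrative stand-ins rather than the notebooks' actual Llama 3.2 / Qwen3 model classes:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Illustrative toy model; not the notebooks' actual Llama 3.2 / Qwen3 classes."""
    def __init__(self, vocab_size=128, emb_dim=32):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)             # weight: (vocab_size, emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)   # weight: (vocab_size, emb_dim)
        # Weight tying: both layers now reference the same Parameter object
        self.out_head.weight = self.tok_emb.weight
        # Re-initializing the shared tensor affects the embedding and the head alike
        torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)

    def forward(self, token_ids):
        return self.out_head(self.tok_emb(token_ids))

model = TinyModel()
print(model.out_head.weight is model.tok_emb.weight)   # True: one shared tensor
print(model(torch.tensor([[1, 2, 3]])).shape)           # torch.Size([1, 3, 128])

Because the tying lines are left commented out in the notebooks, the default behavior is unchanged; tying is opt-in for pretraining runs.
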
ch05/14_ch05_with_other_llms/ch05-qwen3.ipynb

Lines changed: 6 additions & 2 deletions
@@ -121,7 +121,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "86000d74-624a-48f0-86da-f41926cb9e04",
 "metadata": {
 "colab": {
@@ -332,7 +332,11 @@
 " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
 " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 "\n",
-" # Reusuable utilities\n",
+" # Uncomment the following code to tie weights\n",
+" # self.out_head.weight = self.tok_emb.weight\n",
+" # torch.nn.init.normal_(self.out_head.weight, mean=0.0, std=0.02)\n",
+"\n",
+" # Reusable utilities\n",
 " if cfg[\"head_dim\"] is None:\n",
 " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
 " else:\n",

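One way to see what tying buys: the number of unique trainable parameters drops by vocab_size x emb_dim, since the logits projection no longer carries its own matrix. A rough check, reusing the hypothetical TinyModel from the sketch above (nn.Module.parameters() yields a shared tensor only once):

tied = TinyModel()                                  # out_head shares tok_emb's matrix
untied = TinyModel()
untied.out_head.weight = nn.Parameter(              # undo the tying: give the head its own matrix
    torch.empty(128, 32).normal_(mean=0.0, std=0.02)
)

n_tied = sum(p.numel() for p in tied.parameters())      # shared matrix counted once
n_untied = sum(p.numel() for p in untied.parameters())
print(n_untied - n_tied)                                 # 128 * 32 = 4096 fewer weights with tying

For the real model configurations the saving scales with vocabulary size times embedding width, which is often the single largest matrix in a small model.
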
pkg/llms_from_scratch/kv_cache/llama3.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def __init__(self, cfg):
 self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
 self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

-# Reusuable utilities
+# Reusable utilities
 cos, sin = compute_rope_params(
 head_dim=cfg["emb_dim"] // cfg["n_heads"],
 theta_base=cfg["rope_base"],
