Merged. Changes from all commits: 65 commits by manuelcandales, Jan 31 to Feb 5, 2026, all titled "Update" (39db621 through 61dc1e2).
6 changes: 3 additions & 3 deletions backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -1229,7 +1229,7 @@ inline U qdot_safe(

// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
- const int in_vec_size_g = in_vec_size / group_size;
+ const int in_vec_size_g = (in_vec_size + group_size - 1) / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
@@ -1283,8 +1283,8 @@ inline U qdot_safe(

U s = sl[0];
U b = bl[0];
- result[row] +=
-     qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+ result[row] += qdot_safe<U, values_per_thread, bits>(
+     wl, x_thread, s, b, sum, remaining);
}
}

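Taken together, the two hunks above handle an input-vector length that is not a multiple of the quantization group size: the group count in_vec_size_g is now rounded up instead of truncated, and the accumulation calls the qdot_safe variant, which receives a remaining element count so the trailing partial group is not read past its end. A minimal Python sketch of that idea (illustration only, not the Metal kernel; the function names below are hypothetical):

# Illustration only (Python), mirroring the change above; names are hypothetical.

def num_groups_floor(in_vec_size: int, group_size: int) -> int:
    # Old behavior: truncating division drops a trailing partial group.
    return in_vec_size // group_size

def num_groups_ceil(in_vec_size: int, group_size: int) -> int:
    # New behavior: ceiling division counts the trailing partial group.
    return (in_vec_size + group_size - 1) // group_size

def qdot_safe_sketch(w, x, remaining):
    # Accumulate only over the first `remaining` valid elements
    # (the scale/bias handling of the real kernel is omitted for brevity).
    return sum(w[i] * x[i] for i in range(remaining))

# With in_vec_size = 8 and group_size = 32 (as in the new tests below),
# floor division yields 0 groups and the data would be skipped entirely,
# while ceiling division yields 1 partial group of 8 valid elements.
print(num_groups_floor(8, 32), num_groups_ceil(8, 32))  # 0 1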
50 changes: 50 additions & 0 deletions backends/apple/metal/tests/test_modules.py
@@ -283,6 +283,56 @@ def forward(self, x: torch.Tensor):
}


# -------------------------------------------------------------------------
class LinearInt4_QMV_IMPL_small_odd(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(8, 3, bias=True)

def forward(self, x: torch.Tensor):
return self.linear(x)


MODULE_REGISTRY["linear_int4_qmv_impl_small_odd"] = {
"model_class": LinearInt4_QMV_IMPL_small_odd,
"input_shapes": [(1, 8)],
"description": "Linear int4 quantization dispatching to qmv_impl",
"qlinear": "fpa4w",
"qlinear_group_size": 32,
"compare_to_unquantized": False,
"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


# -------------------------------------------------------------------------
class LinearInt4_QMV_IMPL_small_even(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(8, 10, bias=True)

def forward(self, x: torch.Tensor):
return self.linear(x)


MODULE_REGISTRY["linear_int4_qmv_impl_small_even"] = {
"model_class": LinearInt4_QMV_IMPL_small_even,
"input_shapes": [(1, 8)],
"description": "Linear int4 quantization dispatching to qmv_impl",
"qlinear": "fpa4w",
"qlinear_group_size": 32,
"compare_to_unquantized": False,
"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


# -------------------------------------------------------------------------
# Convolution Modules
# -------------------------------------------------------------------------
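Both new registry entries use in_features = 8 with qlinear_group_size = 32, so the quantized weight ends up with a single partial group and the model should dispatch to the qmv_impl path patched above; the odd (3) and even (10) output sizes cover both row-alignment cases. A rough sketch of how such an entry could be driven, assuming the existing harness in test_modules.py does something along these lines (run_registered_module is a hypothetical helper, not part of the file):

# Illustration only; `run_registered_module` is a hypothetical helper.
import torch

def run_registered_module(registry, name):
    cfg = registry[name]
    if cfg.get("skip"):
        return None  # e.g. torchao is not available
    model = cfg["model_class"]().eval()
    inputs = [torch.randn(*shape) for shape in cfg["input_shapes"]]
    with torch.no_grad():
        return model(*inputs)

# out = run_registered_module(MODULE_REGISTRY, "linear_int4_qmv_impl_small_odd")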