backends/apple/metal/runtime/shims/et_metal_ops.mm (115 changes: 92 additions & 23 deletions)
@@ -2492,50 +2492,117 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(

ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Converted tensor handles to ET tensors");

// Validate A tensor: ndim, dtype, contiguity
if (a_tensor->dim() != 2) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor must be 2-D, got %d", (int)a_tensor->dim());
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 2D tensor, got %d", (int)a_tensor->dim());
return Error::InvalidArgument;
}
auto a_dtype = a_tensor->scalar_type();
if (a_dtype != exec_aten::ScalarType::Float &&
a_dtype != exec_aten::ScalarType::BFloat16) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 32-bit or 16-bit float tensor, got dtype %d", (int)a_dtype);
return Error::InvalidArgument;
}
// Check A is contiguous (stride[1] == 1 and stride[0] == size[1])
if (a_tensor->strides()[1] != 1 || a_tensor->strides()[0] != a_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be contiguous, strides=[%lld, %lld]",
(long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
return Error::InvalidArgument;
}


// Validate B tensor: ndim, dtype (uint8), contiguity
if (b_tensor->dim() != 2) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor must be 2-D, got %d", (int)b_tensor->dim());
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be 2D tensor, got %d", (int)b_tensor->dim());
return Error::InvalidArgument;
}
if (b_tensor->scalar_type() != exec_aten::ScalarType::Byte) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be uint8 tensor, got dtype %d", (int)b_tensor->scalar_type());
return Error::InvalidArgument;
}
// Check B is contiguous
if (b_tensor->strides()[1] != 1 || b_tensor->strides()[0] != b_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be contiguous, strides=[%lld, %lld]",
(long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
return Error::InvalidArgument;
}

// Get dimensions: A is [M, K], B is [N, K/2] (4-bit packed, 2 values per byte)
int32_t M = static_cast<int32_t>(a_tensor->sizes()[0]);
int32_t K = static_cast<int32_t>(a_tensor->sizes()[1]);
int32_t N = static_cast<int32_t>(b_tensor->sizes()[0]);
constexpr int nbit = 4;

ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: M=%d, K=%d, N=%d, group_size=%lld", M, K, N, group_size);

// Validate alignment requirements
// B.size(1) should be (K / 8) * nbit for 4-bit packing
int64_t expected_b_size1 = (K / 8) * nbit;
if (b_tensor->sizes()[1] != expected_b_size1) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B.size(1) == %lld, got %lld",
(long long)expected_b_size1, (long long)b_tensor->sizes()[1]);
return Error::InvalidArgument;
}

// Validate K alignment
if (K % 8 != 0) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: K (%d) must be divisible by 8", K);
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect K to be multiple of 8, got %d", K);
return Error::InvalidArgument;
}

// Validate N alignment
if (N % 4 != 0) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: N (%d) must be divisible by 4", N);
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N);
return Error::InvalidArgument;
}

// Validate S tensor: 2D with S.size(0) == N, contiguous, dtype matches A
if (s_tensor->dim() != 2 || s_tensor->sizes()[0] != N) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
N, (int)s_tensor->dim(), (long long)s_tensor->sizes()[0]);
return Error::InvalidArgument;
}
if (s_tensor->scalar_type() != a_dtype) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S dtype to match A dtype %d, got %d",
(int)a_dtype, (int)s_tensor->scalar_type());
return Error::InvalidArgument;
}
if (s_tensor->strides()[1] != 1 || s_tensor->strides()[0] != s_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be contiguous, strides=[%lld, %lld]",
(long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
return Error::InvalidArgument;
}

// Validate Z tensor: 2D with Z.size(0) == N, contiguous, dtype matches A
if (z_tensor->dim() != 2 || z_tensor->sizes()[0] != N) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld",
N, (int)z_tensor->dim(), (long long)z_tensor->sizes()[0]);
return Error::InvalidArgument;
}
if (z_tensor->scalar_type() != a_dtype) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z dtype to match A dtype %d, got %d",
(int)a_dtype, (int)z_tensor->scalar_type());
return Error::InvalidArgument;
}
if (z_tensor->strides()[1] != 1 || z_tensor->strides()[0] != z_tensor->sizes()[1]) {
ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be contiguous, strides=[%lld, %lld]",
(long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);
return Error::InvalidArgument;
}
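
Note that the same row-major predicate (dim == 2, inner stride 1, outer stride equal to the row length) is applied verbatim to A, B, S, and Z above. A hypothetical helper capturing that repeated check, for illustration only and not part of this PR (written here against the shim's ET tensor type, assumed to be exec_aten::Tensor):

// Illustrative only: a 2-D tensor is row-major contiguous iff
// strides == [sizes[1], 1].
static bool is_2d_row_contiguous(const exec_aten::Tensor* t) {
  return t->dim() == 2 && t->strides()[1] == 1 &&
      t->strides()[0] == t->sizes()[1];
}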

// Log shapes and strides for all tensors
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)a_tensor->sizes()[0], (long long)a_tensor->sizes()[1],
(long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)b_tensor->sizes()[0], (long long)b_tensor->sizes()[1],
(long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)s_tensor->sizes()[0], (long long)s_tensor->sizes()[1],
(long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]);
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Z tensor shape=[%lld, %lld], strides=[%lld, %lld]",
(long long)z_tensor->sizes()[0], (long long)z_tensor->sizes()[1],
(long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]);

// Determine data type
int32_t dtype = static_cast<int32_t>(a_tensor->scalar_type());
size_t element_size;
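
For reference, the B checks above encode the 4-bit packing arithmetic: with nbit = 4, two quantized values share one byte, so a row of K weights occupies (K / 8) * nbit == K / 2 bytes. A minimal sketch of that layout, assuming a low-nibble-first order (the kernel's actual nibble order is not visible in this diff, and pack_row_4bit is illustrative only):

#include <cstdint>
#include <vector>

// Illustrative: pack K 4-bit values (each 0..15) into K/2 bytes.
// Assumes the low nibble holds the even-indexed value; the real
// layout used by the Metal kernel may differ.
static std::vector<uint8_t> pack_row_4bit(const std::vector<uint8_t>& q) {
  std::vector<uint8_t> packed(q.size() / 2);  // K is even here, since K % 8 == 0
  for (size_t i = 0; i < packed.size(); ++i) {
    packed[i] = static_cast<uint8_t>(
        (q[2 * i] & 0xF) | ((q[2 * i + 1] & 0xF) << 4));
  }
  return packed;
}

For K = 4096 this yields 2048 bytes per row, matching the B.size(1) == (K / 8) * nbit check above.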
@@ -2652,6 +2719,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight(
// Dispatch based on kernel type (matching torchao dispatch patterns)
if (use_qmv_fast) {
// dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1)
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
kernel_func->dispatchThreadgroups(
M, // gridX
(N + 7) / 8, // gridY
@@ -2661,6 +2729,7 @@
1); // threadsZ
} else {
// dispatch_mm_Mr1xNr4_per_TG: dispatchThreads with grid (N/4 * 32, 1, M), group (32, 1, 1)
ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str());
uint64_t grid_dims[3] = {static_cast<uint64_t>(N / 4 * 32), 1, static_cast<uint64_t>(M)};
uint64_t group_dims[3] = {32, 1, 1};
kernel_func->dispatchArrayWithGroupSize(grid_dims, 3, group_dims, 3);
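
As a worked example of the two dispatch shapes (the values M = 1, N = 4096 are illustrative, not from this PR): the qmv_fast path launches (M, (N + 7) / 8, 1) = (1, 512, 1) threadgroups of 32 x 2 x 1 threads, while the fallback path launches N / 4 * 32 = 32768 threads along X in groups of 32, replicated M times along Z.

// Illustrative arithmetic only; mirrors the dispatch parameters above.
const int M = 1, N = 4096;
const int qmv_groups_y = (N + 7) / 8;                // 512 threadgroups along Y
const uint64_t mm_threads_x = uint64_t(N) / 4 * 32;  // 32768 threads along X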