Commit ecc7081

fix lut gemm
1 parent d5aa986 commit ecc7081

3 files changed, +165 -93 lines changed

onnxruntime/core/mlas/lib/qlutgemm.cpp

Lines changed: 104 additions & 75 deletions
@@ -25,33 +25,53 @@ Module Name:
 #include <memory>
 #include <string>
 #include <thread>
+#include <mutex>
 #include <unordered_map>
 
-/** T-MAC GEMM kernel Config */
+/**
+ * Global cache for T-MAC kernel parameters, indexed by configuration.
+ * This map and its associated mutex ensure thread-safe parameter management
+ * across concurrent MLAS calls.
+ */
 static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
+static std::mutex tmac_kernel_configs_mutex;
 
-const MlasTMACKernelParams&
+static std::string
+GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
+{
+    // Generate a unique cache key based on the GEMM and quantization configuration.
+    return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" +
+           std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
+}
+
+MlasTMACKernelParams
 MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return tmac_kernel_configs[key];
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+    auto it = tmac_kernel_configs.find(key);
+    if (it != tmac_kernel_configs.end()) {
+        return it->second;
     }
-    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
+    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + key);
 }
 
 void MLASCALL
 MlasClearLutGemmKernelConfig()
 {
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
     tmac_kernel_configs.clear();
 }
 
 void MLASCALL
 MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return;
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) {
+            return;
+        }
     }
 
     MlasTMACKernelParams params;
@@ -121,7 +141,10 @@ MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size,
     params.has_zero_point = has_zero_point;
     params.one_scale = false;  // TODO(vraspar): support one scale case for bitnet
 
-    tmac_kernel_configs[key] = params;
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        tmac_kernel_configs[key] = params;
+    }
     return;
 }
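The locking discipline introduced above is the standard mutex-guarded cache idiom. A minimal, self-contained sketch of the same lookup/insert pattern (stand-in `KernelParams` type and key, not the ORT declarations):

```cpp
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct KernelParams { int bm; };  // stand-in for MlasTMACKernelParams

static std::unordered_map<std::string, KernelParams> g_configs;
static std::mutex g_configs_mutex;

// Return by value: a reference into the map could dangle if another thread
// clears or rehashes the map after the lock is released.
KernelParams GetParams(const std::string& key) {
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    auto it = g_configs.find(key);
    if (it != g_configs.end()) {
        return it->second;  // copied out while still holding the lock
    }
    throw std::runtime_error("params not initialized for key: " + key);
}

void InitParams(const std::string& key, const KernelParams& params) {
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    g_configs.emplace(key, params);  // no-op if the key is already present
}

int main() {
    InitParams("256_4096_2_64_0", KernelParams{32});
    return GetParams("256_4096_2_64_0").bm == 32 ? 0 : 1;
}
```

Note that `MlasInitLutGemmKernelConfig` releases the lock between the existence check and the final insert, so two threads can both compute `params` for the same key; the race is benign because both compute identical values and the second insert simply overwrites them.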

@@ -222,53 +245,52 @@ LutGemmPackQuantBData(
     const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
     memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize);  // TODO: is this needed?
 
-    MlasTrySimpleParallel(
-        ThreadPool, Iterations,
-        [&](ptrdiff_t tid) {
-            size_t im = static_cast<size_t>(tid);
-            for (size_t ib = 0; ib < bits; ib++) {
-                for (size_t ik = 0; ik < K / g; ik++) {
-                    // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
-                    size_t new_im = im / simd_n_out;
-                    size_t new_isno = im % simd_n_out;
-                    size_t new_ib = ib;
-                    size_t new_ik = ik;
-                    size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
-
-                    // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
-                    new_im = new_idx / c1_nb0;
-                    size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
-                    size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
-                    new_ik = (new_idx % c1_nb2);
-                    new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
-
-                    // # 0 1 2 3 4 5
-                    // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
-                    new_im = new_idx / c2_nb0;
-                    size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
-                    new_isni = (new_idx % c2_nb1) / c2_nb2;
-                    new_ing = (new_idx % c2_nb2) / c2_nb3;
-                    new_ik = (new_idx % c2_nb3) / c2_nb4;
-                    size_t new_ikf = (new_idx % c2_nb4);
-                    new_idx = new_im * c2_fac0 +
-                              new_ik * c2_fac1 +
-                              new_ibm * c2_fac2 +
-                              new_ikf * c2_fac3 +
-                              new_isni * ngroups_per_elem +
-                              new_ing;
-                    new_idx = new_idx / ngroups_per_elem;
-                    size_t buf_idx = im * bits * K / g + ib * K / g + ik;
-                    uint8_t buf_val = buf[buf_idx];
-
-                    // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
-                    PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
-                        static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
-                        (buf_val << (new_ing * g))
-                    );
-                }
+    // NOTE: The second packing loop is intentionally serialized to avoid data races.
+    // T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1.
+    // Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location.
+    for (size_t im = 0; im < Iterations; im++) {
+        for (size_t ib = 0; ib < bits; ib++) {
+            for (size_t ik = 0; ik < K / g; ik++) {
+                // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
+                size_t new_im = im / simd_n_out;
+                size_t new_isno = im % simd_n_out;
+                size_t new_ib = ib;
+                size_t new_ik = ik;
+                size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
+
+                // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
+                new_im = new_idx / c1_nb0;
+                size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
+                size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
+                new_ik = (new_idx % c1_nb2);
+                new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
+
+                // # 0 1 2 3 4 5
+                // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
+                new_im = new_idx / c2_nb0;
+                size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
+                new_isni = (new_idx % c2_nb1) / c2_nb2;
+                new_ing = (new_idx % c2_nb2) / c2_nb3;
+                new_ik = (new_idx % c2_nb3) / c2_nb4;
+                size_t new_ikf = (new_idx % c2_nb4);
+                new_idx = new_im * c2_fac0 +
+                          new_ik * c2_fac1 +
+                          new_ibm * c2_fac2 +
+                          new_ikf * c2_fac3 +
+                          new_isni * ngroups_per_elem +
+                          new_ing;
+                new_idx = new_idx / ngroups_per_elem;
+                size_t buf_idx = im * bits * K / g + ib * K / g + ik;
+                uint8_t buf_val = buf[buf_idx];
+
+                // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
+                PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
+                    static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
+                    (buf_val << (new_ing * g))
+                );
             }
         }
-    );
+    }
 }
 
 // Internal helper: calculates packed scales and zero points size in floats
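Why the serialization matters: with `ngroups_per_elem > 1`, distinct `(im, ib, ik)` iterations land on the same `new_idx` byte and accumulate shifted bit-groups into it via read-modify-write. A toy sketch of that collision (hypothetical sizes, g = 4, two groups per byte):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    const size_t g = 4;                 // bits per group
    const size_t ngroups_per_elem = 2;  // two 4-bit groups share one byte
    std::vector<uint8_t> src = {0x1, 0x2, 0x3, 0x4};
    std::vector<uint8_t> packed(src.size() / ngroups_per_elem, 0);

    // Iterations i and i + 1 read-modify-write the same packed byte, so
    // splitting them across threads (as the old parallel loop did) races.
    for (size_t i = 0; i < src.size(); i++) {
        const size_t ng = i % ngroups_per_elem;  // bit-plane within the byte
        packed[i / ngroups_per_elem] = static_cast<uint8_t>(
            packed[i / ngroups_per_elem] + (src[i] << (ng * g)));
    }
    return packed[0] == 0x21 && packed[1] == 0x43 ? 0 : 1;  // 0 on success
}
```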
@@ -472,16 +494,15 @@ size_t
 CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params)
 {
     MLAS_UNREFERENCED_PARAMETER(n);
-    constexpr size_t kAllockAligment = 64;
     const size_t lut_scales_size = k / tmac_params.act_group_size;
 
-    size_t wsize = k * m * 4 * sizeof(int8_t);         // 4 bytes per k element for 2-bit LUT
-    wsize += lut_scales_size * m * 2 * sizeof(float);  // scales + biases
-
-    wsize = ((wsize - 1) / kAllockAligment + 1) * kAllockAligment;
+    // The AVX2 kernel (g=4) expects 16 entries (16 bytes) per group of 4 activations.
+    // This effectively requires 4 bytes per activation in the K dimension.
+    size_t lut_size_bytes = m * k * 4;
+    size_t scales_size_bytes = m * lut_scales_size * sizeof(float);
+    size_t biases_size_bytes = m * lut_scales_size * sizeof(float);
 
-    // TODO(vrapar): add temp buffer for FP16
-    return wsize;
+    return lut_size_bytes + scales_size_bytes + biases_size_bytes + 256;  // + alignment/safety padding
 }
 
 void MLASCALL
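The size computed here is consumed by `MlasLutGemm`, which carves the allocation into three consecutive regions: LUT bytes, then per-group scales, then per-group biases. A sketch of that layout arithmetic with made-up sizes (the 256-byte tail mirrors the padding above):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>

int main() {
    const size_t m = 2, k = 256, act_group_size = 32;   // hypothetical problem
    const size_t lut_scales_size = k / act_group_size;  // groups per row
    const size_t lut_size_bytes = m * k * 4;            // 4 LUT bytes per activation
    const size_t scales_bytes = m * lut_scales_size * sizeof(float);
    const size_t biases_bytes = m * lut_scales_size * sizeof(float);
    const size_t total = lut_size_bytes + scales_bytes + biases_bytes + 256;

    auto buffer = std::make_unique<int8_t[]>(total);
    std::memset(buffer.get(), 0, total);

    // Same carving order as the caller: LUT first, then scales, then biases.
    int8_t* qlut = buffer.get();
    float* lut_scales = reinterpret_cast<float*>(qlut + lut_size_bytes);
    float* lut_biases = lut_scales + lut_scales_size * m;
    return lut_biases > lut_scales ? 0 : 1;
}
```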
@@ -532,17 +553,23 @@ MlasLutGemm(
     // n_tiles_num = m * bits / bm;
 
     // TODO(vraspar): support other bitwidths
+    // For T-MAC, kernel properties (bm, n_tiles_num) are primarily driven by the number of output features (N).
+    // Initialization during packing (LutGemmPackQuantBDataSize) uses N as the major dimension,
+    // so we must match that here to ensure consistent weight tiling.
+    MlasInitLutGemmKernelConfig(N, K, 2, BlkLen, HasZeroPoint);
     const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint);
     const size_t lut_scales_size = K / tmac_params.act_group_size;
+    const size_t lut_size_bytes = static_cast<size_t>(M) * static_cast<size_t>(K) * 4;
     size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params);
 
     // make buffer of lut_buffer_size bytes
     // TODO(vraspar): other way to do it
     auto lut_buffer = std::make_unique<int8_t[]>(lut_buffer_size);
+    memset(lut_buffer.get(), 0, lut_buffer_size);
 
     int8_t* qlut = reinterpret_cast<int8_t*>(lut_buffer.get());
-    float* lut_scales = reinterpret_cast<float*>(qlut + K * M * 4);                  // after lut
-    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
+    float* lut_scales = reinterpret_cast<float*>(qlut + lut_size_bytes);             // after lut
+    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
 
     const auto* a_float = reinterpret_cast<const float*>(A);  // Activation data
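One subtlety in the hunk above: `MlasGetLutGemmKernelParams` now returns by value, but the call site still binds the result to a `const&`. That stays well-defined because C++ extends the lifetime of a temporary bound to a const reference to the reference's scope. A minimal demonstration:

```cpp
#include <string>

struct Params { std::string name; };

Params MakeParams() { return Params{"tile-config"}; }  // returns a copy

int main() {
    // Lifetime extension: the temporary lives as long as the reference.
    const Params& p = MakeParams();
    return p.name == "tile-config" ? 0 : 1;
}
```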

@@ -558,11 +585,9 @@
 
     for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
         const size_t row_offset = ine11 * K;
-        const size_t lut_offset = ine11 * K * 4;  // 4 bytes per K element for 2-bit LUT
-        const size_t scale_bias_offset = ine11 * lut_scales_size;
-
-        // Call the dispatch function for this row
-        // ggml_tmac_mul_mat_task_init
+        // Call the LUT generation kernel for this activation row.
+        // We use a 4-byte stride (per activation) for the LUT entries to satisfy
+        // the memory layout requirements of the computation kernel.
         Dispatch->GenerateLUT(
             const_cast<float*>(a_float + row_offset),  // Input activation for this row
             qlut + lut_offset,                         // Output LUT for this row
@@ -571,7 +596,8 @@
             M,
             K,
             N,
-            tmac_params.act_group_size
+            tmac_params.act_group_size,
+            tmac_params.act_group_size * 4
         );
     }
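The stride arithmetic behind the new `GenerateLUT` argument: a 2-bit, g = 4 table stores 16 one-byte entries per group of 4 activations, i.e. 4 bytes per activation, so row r of the LUT starts at r * K * 4 and each activation group spans act_group_size * 4 bytes. A sketch with illustrative sizes:

```cpp
#include <cstddef>

int main() {
    const size_t K = 512;                   // depth
    const size_t act_group_size = 64;       // activations per quantization group
    const size_t bytes_per_activation = 4;  // 16 LUT entries per 4 activations

    const size_t row = 3;                                             // batch row
    const size_t row_offset = row * K * bytes_per_activation;         // row base
    const size_t lut_stride = act_group_size * bytes_per_activation;  // per group

    // Group i of this row begins at row_offset + i * lut_stride.
    return row_offset + 2 * lut_stride == row * K * 4 + 512 ? 0 : 1;
}
```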

@@ -657,15 +683,17 @@
 
     // Process all batch items in this chunk
     for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
-        // Calculate LUT offsets for this batch item
+        // Calculate LUT offsets with 4-byte stride (per activation) for consistent access.
         const size_t qlut_offset = K * ine11 * 4;
         const size_t lut_scales_offset = lut_scales_size * ine11;
 
         // Calculate output offset
         const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;
 
-        // Call the dispatch function to compute this tile
-        // Note M and N are swapped in TMAC terminology
+        // Call the dispatch function to compute this tile.
+        // We pass one batch item at a time (M=1) and ChunkSize0 output features.
+        // TotalN is passed specifically to allow the kernel to find the correct
+        // parameters (bm, tiles) used during weight packing.
         Dispatch->ComputeGemm(
             packed_weights + w_offset,    // Weight tile
             QuantBScale + scales_offset,  // Weight scales for this tile
@@ -674,8 +702,9 @@
             lut_biases + lut_scales_offset,  // LUT biases
             act_output + dst_offset,         // Output location
             static_cast<int>(K),             // K dimension
-            static_cast<int>(N),  // N dimension
-            static_cast<int>(1),  // M dimension (processing one batch item at a time)
+            static_cast<int>(1),                    // M dimension (batch size = 1)
+            static_cast<int>(ir0_end - ir0_start),  // N dimension (output features in chunk)
+            static_cast<int>(N),                    // TotalN (total output features in weights)
             BlkLen,       // Weight quantization group size
             HasZeroPoint  // Whether zero points are used
         );
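The indexing in this loop, in isolation: each batch row `ine11` owns `OutputRows` contiguous floats of output, and chunk `ichunk0` covers output features starting at `ichunk0 * ChunkSize0`. A toy sketch of the offset formula (hypothetical sizes):

```cpp
#include <cstddef>
#include <vector>

int main() {
    const size_t OutputRows = 128;  // total output features per batch row
    const size_t ChunkSize0 = 32;   // output features per chunk
    const size_t BatchRows = 4;

    std::vector<float> output(BatchRows * OutputRows, 0.0f);

    for (size_t ichunk0 = 0; ichunk0 < OutputRows / ChunkSize0; ichunk0++) {
        for (size_t ine11 = 0; ine11 < BatchRows; ine11++) {
            // Same formula as the dispatch loop: row base + chunk base.
            const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;
            output[dst_offset] = 1.0f;  // the kernel fills ChunkSize0 values here
        }
    }
    return output[OutputRows * 1 + ChunkSize0 * 2] == 1.0f ? 0 : 1;
}
```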

onnxruntime/core/mlas/lib/qlutgemm.h

Lines changed: 13 additions & 6 deletions
@@ -42,7 +42,11 @@ struct MlasTMACKernelParams {
     bool one_scale;
 };
 
-const MlasTMACKernelParams&
+/**
+ * Retrieves the T-MAC kernel configuration for a given GEMM problem.
+ * Returns the parameters by value to ensure thread-safety across concurrent calls.
+ */
+MlasTMACKernelParams
 MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point);
 
 typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
@@ -53,19 +57,21 @@ typedef void(MLAS_QNBIT_GEMM_LUT_GEN)(
     size_t M,
     size_t K,
     size_t N,
-    size_t act_group_size
+    size_t act_group_size,
+    size_t lut_stride  // Stride (in bytes) between consecutive LUT entries along the batch dimension.
 );
 
 typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
-    const uint8_t* weights,
-    const float* scales,
+    const uint8_t* A,
+    const float* Scales,
     const int8_t* LUT,
     const float* LUT_Scales,
     const float* LUT_Biases,
     float* C,
     int K,
-    int M, // batch size (number of rows in activation)
-    int N,
+    int M,       // Batch size (current activation rows).
+    int N,       // Number of output features to compute in this tile/chunk.
+    int TotalN,  // Total number of output features in the weights (used for parameter mapping).
     size_t BlkLen,
     bool HasZeroPoint
 );
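How these typedefs are consumed: MLAS keeps per-ISA kernels behind a dispatch struct of function pointers that platform init fills in. A stripped-down sketch; the struct, the stub bodies, and the leading `GenerateLUT` parameters are assumptions based on the call sites, not the real declarations:

```cpp
#include <cstddef>
#include <cstdint>

// Assumed-shape stand-ins mirroring the two typedefs above.
typedef void(LutGenFn)(float* act, int8_t* lut, float* scales, float* biases,
                       size_t M, size_t K, size_t N,
                       size_t act_group_size, size_t lut_stride);
typedef void(LutGemmFn)(const uint8_t* A, const float* Scales, const int8_t* LUT,
                        const float* LUT_Scales, const float* LUT_Biases, float* C,
                        int K, int M, int N, int TotalN,
                        size_t BlkLen, bool HasZeroPoint);

// Dispatch table in the MLAS style: one pointer per kernel entry point.
struct LutGemmDispatch {
    LutGenFn* GenerateLUT = nullptr;
    LutGemmFn* ComputeGemm = nullptr;
};

// No-op stubs; platform init would install real (e.g. AVX2) kernels instead.
static void GenerateLUTStub(float*, int8_t*, float*, float*,
                            size_t, size_t, size_t, size_t, size_t) {}
static void ComputeGemmStub(const uint8_t*, const float*, const int8_t*,
                            const float*, const float*, float*,
                            int, int, int, int, size_t, bool) {}

int main() {
    LutGemmDispatch d;
    d.GenerateLUT = GenerateLUTStub;
    d.ComputeGemm = ComputeGemmStub;
    return (d.GenerateLUT && d.ComputeGemm) ? 0 : 1;
}
```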
