@@ -25,33 +25,53 @@ Module Name:
 #include <memory>
 #include <string>
 #include <thread>
+#include <mutex>
 #include <unordered_map>

-/** T-MAC GEMM kernel Config */
+/**
+ * Global cache for T-MAC kernel parameters, indexed by configuration.
+ * This map and its associated mutex ensure thread-safe parameter management
+ * across concurrent MLAS calls.
+ */
 static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
+static std::mutex tmac_kernel_configs_mutex;

-const MlasTMACKernelParams&
+static std::string
+GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
+{
+    // Generate a unique cache key based on the GEMM and quantization configuration.
+    return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" +
+           std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
+}
+
+MlasTMACKernelParams
 MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return tmac_kernel_configs[key];
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+    auto it = tmac_kernel_configs.find(key);
+    if (it != tmac_kernel_configs.end()) {
+        return it->second;
     }
-    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized");
+    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + key);
 }

 void MLASCALL
 MlasClearLutGemmKernelConfig()
 {
+    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
     tmac_kernel_configs.clear();
 }

 void MLASCALL
 MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
 {
-    std::string key = std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0");
-    if (tmac_kernel_configs.count(key)) {
-        return;
+    std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) {
+            return;
+        }
     }

     MlasTMACKernelParams params;
@@ -121,7 +141,10 @@ MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size,
     params.has_zero_point = has_zero_point;
     params.one_scale = false;  // TODO(vraspar): support one scale case for bitnet

-    tmac_kernel_configs[key] = params;
+    {
+        std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
+        tmac_kernel_configs[key] = params;
+    }
     return;
 }

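Taken together, the functions above form an init-once registry behind a mutex. A minimal self-contained sketch of the same pattern, with hypothetical names (`KernelParams`, `InitConfig`, `GetConfig`) rather than the MLAS API:

```cpp
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct KernelParams { size_t bm = 0; size_t kfactor = 0; };  // hypothetical stand-in

static std::unordered_map<std::string, KernelParams> g_configs;
static std::mutex g_configs_mutex;

// Return by value so no reference into the map escapes the lock;
// a caller holding a reference would otherwise race with a concurrent clear().
KernelParams GetConfig(const std::string& key) {
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    auto it = g_configs.find(key);
    if (it == g_configs.end()) {
        throw std::runtime_error("config not initialized for key: " + key);
    }
    return it->second;
}

void InitConfig(const std::string& key, const KernelParams& params) {
    std::lock_guard<std::mutex> lock(g_configs_mutex);
    // operator[] inserts or overwrites; double initialization is benign here
    // because two initializers for the same key compute identical params.
    g_configs[key] = params;
}
```

Note the check-then-insert window in `MlasInitLutGemmKernelConfig` (the lock is released while `params` is computed): it is benign because racing initializers for the same key produce identical values, and the final insert itself is serialized.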
@@ -222,53 +245,52 @@ LutGemmPackQuantBData(
     const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
     memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize);  // TODO: is this needed?

-    MlasTrySimpleParallel(
-        ThreadPool, Iterations,
-        [&](ptrdiff_t tid) {
-            size_t im = static_cast<size_t>(tid);
-            for (size_t ib = 0; ib < bits; ib++) {
-                for (size_t ik = 0; ik < K / g; ik++) {
-                    // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
-                    size_t new_im = im / simd_n_out;
-                    size_t new_isno = im % simd_n_out;
-                    size_t new_ib = ib;
-                    size_t new_ik = ik;
-                    size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
-
-                    // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
-                    new_im = new_idx / c1_nb0;
-                    size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
-                    size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
-                    new_ik = (new_idx % c1_nb2);
-                    new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
-
-                    // #                0  1  2  3  4  5
-                    // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
-                    new_im = new_idx / c2_nb0;
-                    size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
-                    new_isni = (new_idx % c2_nb1) / c2_nb2;
-                    new_ing = (new_idx % c2_nb2) / c2_nb3;
-                    new_ik = (new_idx % c2_nb3) / c2_nb4;
-                    size_t new_ikf = (new_idx % c2_nb4);
-                    new_idx = new_im * c2_fac0 +
-                              new_ik * c2_fac1 +
-                              new_ibm * c2_fac2 +
-                              new_ikf * c2_fac3 +
-                              new_isni * ngroups_per_elem +
-                              new_ing;
-                    new_idx = new_idx / ngroups_per_elem;
-                    size_t buf_idx = im * bits * K / g + ib * K / g + ik;
-                    uint8_t buf_val = buf[buf_idx];
-
-                    // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
-                    PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
-                        static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
-                        (buf_val << (new_ing * g))
-                    );
-                }
+    // NOTE: The second packing loop is intentionally serialized to avoid data races.
+    // T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1.
+    // Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location.
+    for (size_t im = 0; im < Iterations; im++) {
+        for (size_t ib = 0; ib < bits; ib++) {
+            for (size_t ik = 0; ik < K / g; ik++) {
+                // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
+                size_t new_im = im / simd_n_out;
+                size_t new_isno = im % simd_n_out;
+                size_t new_ib = ib;
+                size_t new_ik = ik;
+                size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;
+
+                // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
+                new_im = new_idx / c1_nb0;
+                size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
+                size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
+                new_ik = (new_idx % c1_nb2);
+                new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;
+
+                // #                0  1  2  3  4  5
+                // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
+                new_im = new_idx / c2_nb0;
+                size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
+                new_isni = (new_idx % c2_nb1) / c2_nb2;
+                new_ing = (new_idx % c2_nb2) / c2_nb3;
+                new_ik = (new_idx % c2_nb3) / c2_nb4;
+                size_t new_ikf = (new_idx % c2_nb4);
+                new_idx = new_im * c2_fac0 +
+                          new_ik * c2_fac1 +
+                          new_ibm * c2_fac2 +
+                          new_ikf * c2_fac3 +
+                          new_isni * ngroups_per_elem +
+                          new_ing;
+                new_idx = new_idx / ngroups_per_elem;
+                size_t buf_idx = im * bits * K / g + ib * K / g + ik;
+                uint8_t buf_val = buf[buf_idx];
+
+                // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
+                PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
+                    static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
+                    (buf_val << (new_ing * g))
+                );
             }
         }
-    );
+    }
 }

 // Internal helper: calculates packed scales and zero points size in floats
@@ -472,16 +494,15 @@ size_t
 CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params)
 {
     MLAS_UNREFERENCED_PARAMETER(n);
-    constexpr size_t kAllockAligment = 64;
     const size_t lut_scales_size = k / tmac_params.act_group_size;

-    size_t wsize = k * m * 4 * sizeof(int8_t);         // 4 bytes per k element for 2-bit LUT
-    wsize += lut_scales_size * m * 2 * sizeof(float);  // scales + biases
-
-    wsize = ((wsize - 1) / kAllockAligment + 1) * kAllockAligment;
+    // The AVX2 kernel (g=4) expects 16 entries (16 bytes) per group of 4 activations.
+    // This effectively requires 4 bytes per activation in the K dimension.
+    size_t lut_size_bytes = m * k * 4;
+    size_t scales_size_bytes = m * lut_scales_size * sizeof(float);
+    size_t biases_size_bytes = m * lut_scales_size * sizeof(float);

-    // TODO(vrapar): add temp buffer for FP16
-    return wsize;
+    return lut_size_bytes + scales_size_bytes + biases_size_bytes + 256;  // + alignment/safety padding
 }

 void MLASCALL
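As a quick sanity check of that arithmetic, a standalone sketch with illustrative sizes (K = 4096, M = 1, act_group_size = 64; none of these values come from the PR):

```cpp
#include <cstddef>
#include <cstdio>

// Mirrors the CalculateLutBufferSize arithmetic for sample sizes; not the MLAS function itself.
int main() {
    const size_t k = 4096, m = 1, act_group_size = 64;
    const size_t lut_scales_size = k / act_group_size;                     // 64 groups
    const size_t lut_size_bytes = m * k * 4;                               // 16384 bytes
    const size_t scales_size_bytes = m * lut_scales_size * sizeof(float);  // 256 bytes
    const size_t biases_size_bytes = m * lut_scales_size * sizeof(float);  // 256 bytes
    printf("%zu\n", lut_size_bytes + scales_size_bytes + biases_size_bytes + 256);  // 17152
}
```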
@@ -532,17 +553,23 @@ MlasLutGemm(
     // n_tiles_num = m * bits / bm;

     // TODO(vraspar): support other bitwidths
+    // For T-MAC, kernel properties (bm, n_tiles_num) are primarily driven by the number of output features (N).
+    // Initialization during packing (LutGemmPackQuantBDataSize) uses N as the major dimension,
+    // so we must match that here to ensure consistent weight tiling.
+    MlasInitLutGemmKernelConfig(N, K, 2, BlkLen, HasZeroPoint);
     const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint);
     const size_t lut_scales_size = K / tmac_params.act_group_size;
+    const size_t lut_size_bytes = static_cast<size_t>(M) * static_cast<size_t>(K) * 4;
     size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params);

     // make buffer of lut_buffer_size bytes
     // TODO(vraspar): other way to do it
     auto lut_buffer = std::make_unique<int8_t[]>(lut_buffer_size);
+    memset(lut_buffer.get(), 0, lut_buffer_size);

     int8_t* qlut = reinterpret_cast<int8_t*>(lut_buffer.get());
-    float* lut_scales = reinterpret_cast<float*>(qlut + K * M * 4);                  // after lut
-    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales
+    float* lut_scales = reinterpret_cast<float*>(qlut + lut_size_bytes);             // after lut
+    float* lut_biases = reinterpret_cast<float*>(lut_scales + lut_scales_size * M);  // after scales

     const auto* a_float = reinterpret_cast<const float*>(A);  // Activation data
@@ -558,11 +585,12 @@ MlasLutGemm(

     for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
         const size_t row_offset = ine11 * K;
         const size_t lut_offset = ine11 * K * 4;  // 4 bytes per K element for 2-bit LUT
         const size_t scale_bias_offset = ine11 * lut_scales_size;

-        // Call the dispatch function for this row
-        // ggml_tmac_mul_mat_task_init
+        // Call the LUT generation kernel for this activation row.
+        // We use a 4-byte stride (per activation) for the LUT entries to satisfy
+        // the memory layout requirements of the computation kernel.
         Dispatch->GenerateLUT(
             const_cast<float*>(a_float + row_offset),  // Input activation for this row
             qlut + lut_offset,                         // Output LUT for this row
@@ -571,7 +599,8 @@ MlasLutGemm(
             M,
             K,
             N,
-            tmac_params.act_group_size
+            tmac_params.act_group_size,
+            tmac_params.act_group_size * 4
         );
     }

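The extra argument encodes the LUT stride. For the g = 4 table scheme, each group of four activations indexes 2^4 = 16 int8 entries, so the LUT costs 4 bytes per activation and act_group_size * 4 bytes per quantization group. A small sketch of that arithmetic (illustrative act_group_size):

```cpp
#include <cstddef>

constexpr std::size_t g = 4;                        // activations per LUT group
constexpr std::size_t entries_per_group = 1u << g;  // 2^4 = 16 int8 entries
constexpr std::size_t bytes_per_activation = entries_per_group / g;  // 16 / 4 = 4
static_assert(bytes_per_activation == 4, "2-bit T-MAC LUT packs 4 bytes per activation");

constexpr std::size_t act_group_size = 64;          // illustrative value
constexpr std::size_t lut_stride_bytes = act_group_size * bytes_per_activation;  // 256
static_assert(lut_stride_bytes == act_group_size * 4, "matches the stride passed above");
```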
@@ -657,15 +686,17 @@ MlasLutGemm(

         // Process all batch items in this chunk
         for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
-            // Calculate LUT offsets for this batch item
+            // Calculate LUT offsets with 4-byte stride (per activation) for consistent access.
             const size_t qlut_offset = K * ine11 * 4;
             const size_t lut_scales_offset = lut_scales_size * ine11;

             // Calculate output offset
             const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;

-            // Call the dispatch function to compute this tile
-            // Note M and N are swapped in TMAC terminology
+            // Call the dispatch function to compute this tile.
+            // We pass one batch item at a time (M=1) and ChunkSize0 output features.
+            // TotalN is passed specifically to allow the kernel to find the correct
+            // parameters (bm, tiles) used during weight packing.
             Dispatch->ComputeGemm(
                 packed_weights + w_offset,    // Weight tile
                 QuantBScale + scales_offset,  // Weight scales for this tile
@@ -674,8 +705,9 @@ MlasLutGemm(
                 lut_biases + lut_scales_offset,  // LUT biases
                 act_output + dst_offset,         // Output location
                 static_cast<int>(K),             // K dimension
-                static_cast<int>(N),             // N dimension
-                static_cast<int>(1),             // M dimension (processing one batch item at a time)
+                static_cast<int>(1),             // M dimension (batch size = 1)
+                static_cast<int>(ir0_end - ir0_start),  // N dimension (output features in chunk)
+                static_cast<int>(N),             // TotalN (total output features in weights)
                 BlkLen,                          // Weight quantization group size
                 HasZeroPoint                     // Whether zero points are used
             );
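To trace the offset math feeding ComputeGemm, a standalone sketch with hypothetical shapes (M = 2 batch items, K = 4096, OutputRows = 1024, ChunkSize0 = 256, act_group_size = 64; all values illustrative):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const size_t M = 2, K = 4096, OutputRows = 1024, ChunkSize0 = 256;
    const size_t lut_scales_size = K / 64;  // assuming act_group_size = 64

    for (size_t ichunk0 = 0; ichunk0 < OutputRows / ChunkSize0; ichunk0++) {
        for (size_t ine11 = 0; ine11 < M; ine11++) {
            size_t qlut_offset = K * ine11 * 4;                             // bytes into the LUT
            size_t lut_scales_offset = lut_scales_size * ine11;             // floats into scales
            size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0;  // floats into output
            printf("chunk %zu, batch %zu -> qlut %zu, scales %zu, dst %zu\n",
                   ichunk0, ine11, qlut_offset, lut_scales_offset, dst_offset);
        }
    }
}
```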