144144#define NK_DOTS_H
145145
146146#include "numkong/types.h"
147- #include "numkong/dot.h" // nk_bf16x16_to_f32x16_skylake_
148147
149148#if defined(__cplusplus )
150149extern "C" {
@@ -253,6 +252,8 @@ NK_DYNAMIC void nk_dots_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_u
253252 * @param[in] stride Row stride in bytes for the input matrix.
254253 * @param[out] result Output symmetric matrix (n_vectors × n_vectors).
255254 * @param[in] result_stride Row stride in bytes for the result matrix.
255+ * @param[in] row_start Starting row offset of results to compute (needed for parallelism).
256+ * @param[in] row_count Number of rows of results to compute (needed for parallelism).
256257 */
257258NK_DYNAMIC void nk_dots_symmetric_bf16 (nk_bf16_t const * vectors , nk_size_t n_vectors , nk_size_t depth , nk_size_t stride ,
258259 nk_f32_t * result , nk_size_t result_stride , nk_size_t row_start ,
@@ -294,6 +295,38 @@ NK_DYNAMIC void nk_dots_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vecto
294295 nk_u32_t * result , nk_size_t result_stride , nk_size_t row_start ,
295296 nk_size_t row_count );
296297
298+ /**
299+ * @brief Compacts f32 GEMM output to bf16 (in-place).
300+ *
301+ * After computing C_f32 = A × Bᵀ in f32, narrows each value to bf16 with rounding.
302+ * The operation is done in-place: reads f32 values and writes bf16 to the same buffer.
303+ * Output is tightly packed with stride = n × sizeof(bf16).
304+ *
305+ * @param c Buffer containing f32 values, will be overwritten with bf16 output (m × n).
306+ * @param m Number of rows.
307+ * @param n Number of columns.
308+ * @param c_stride Row stride of input f32 matrix in bytes.
309+ */
310+ NK_DYNAMIC void nk_dots_compact_bf16 (void * c , nk_size_t m , nk_size_t n , nk_size_t c_stride );
311+
312+ /**
313+ * @brief Compacts i32 GEMM output to normalized i8 (in-place).
314+ *
315+ * After computing C_i32 = A × Bᵀ in i32, normalizes to cosine similarity in [-128, 127].
316+ * Uses squared norms for normalization: result[i,j] = 127 × C[i,j] / sqrt(a_squared_norms[i] × b_squared_norms[j]).
317+ * The operation is done in-place: reads i32 values and writes i8 to the same buffer.
318+ * Output is tightly packed with stride = n × sizeof(i8).
319+ *
320+ * @param c Buffer containing i32 values, will be overwritten with i8 output (m × n).
321+ * @param m Number of rows.
322+ * @param n Number of columns.
323+ * @param c_stride Row stride of input i32 matrix in bytes.
324+ * @param a_squared_norms Squared L2 norms for A rows (length m).
325+ * @param b_squared_norms Squared L2 norms for B rows (length n).
326+ */
327+ NK_DYNAMIC void nk_dots_compact_i8 (void * c , nk_size_t m , nk_size_t n , nk_size_t c_stride ,
328+ nk_i32_t const * a_squared_norms , nk_i32_t const * b_squared_norms );
329+
297330/** @copydoc nk_dots_packed_size_f32 */
298331NK_PUBLIC nk_size_t nk_dots_packed_size_f32_serial (nk_size_t n , nk_size_t k );
299332/** @copydoc nk_dots_pack_f32 */
@@ -302,9 +335,8 @@ NK_PUBLIC void nk_dots_pack_f32_serial(nk_f32_t const *b, nk_size_t n, nk_size_t
302335NK_PUBLIC void nk_dots_packed_f32_serial (nk_f32_t const * a , void const * b_packed , nk_f32_t * c , nk_size_t m , nk_size_t n ,
303336 nk_size_t k , nk_size_t a_stride , nk_size_t c_stride );
304337/** @copydoc nk_dots_symmetric_f32 */
305- NK_PUBLIC void nk_dots_symmetric_f32_serial (nk_f32_t const * vectors , nk_size_t n_vectors , nk_size_t depth ,
306- nk_size_t stride , nk_f32_t * result , nk_size_t result_stride ,
307- nk_size_t row_start , nk_size_t row_count );
338+ NK_PUBLIC void nk_dots_symmetric_f32_serial (nk_f32_t const * vectors , nk_size_t n_vectors , nk_size_t depth , nk_size_t stride , nk_f32_t * result ,
339+ nk_size_t result_stride , nk_size_t row_start , nk_size_t row_count );
308340
309341/** @copydoc nk_dots_packed_size_f64 */
310342NK_PUBLIC nk_size_t nk_dots_packed_size_f64_serial (nk_size_t n , nk_size_t k );
@@ -330,38 +362,6 @@ NK_PUBLIC void nk_dots_symmetric_f16_serial(nk_f16_t const *vectors, nk_size_t n
330362 nk_size_t stride , nk_f32_t * result , nk_size_t result_stride ,
331363 nk_size_t row_start , nk_size_t row_count );
332364
333- /**
334- * @brief Compacts f32 GEMM output to bf16 (in-place).
335- *
336- * After computing C_f32 = A × Bᵀ in f32, truncates to bf16 with rounding.
337- * The operation is done in-place: reads f32 values and writes bf16 to the same buffer.
338- * Output is tightly packed with stride = n × sizeof(bf16).
339- *
340- * @param c Buffer containing f32 values, will be overwritten with bf16 output (m × n).
341- * @param m Number of rows.
342- * @param n Number of columns.
343- * @param c_stride Row stride of input f32 matrix in bytes.
344- */
345- NK_DYNAMIC void nk_dots_compact_bf16 (void * c , nk_size_t m , nk_size_t n , nk_size_t c_stride );
346-
347- /**
348- * @brief Compacts i32 GEMM output to normalized i8 (in-place).
349- *
350- * After computing C_i32 = A × Bᵀ in i32, normalizes to cosine similarity in [-128, 127].
351- * Uses squared norms for normalization: result[i,j] = 127 × C[i,j] / sqrt(a_norm[i] × b_norm[j]).
352- * The operation is done in-place: reads i32 values and writes i8 to the same buffer.
353- * Output is tightly packed with stride = n × sizeof(i8).
354- *
355- * @param c Buffer containing i32 values, will be overwritten with i8 output (m × n).
356- * @param m Number of rows.
357- * @param n Number of columns.
358- * @param c_stride Row stride of input i32 matrix in bytes.
359- * @param a_squared_norms Squared L2 norms for A rows (length m).
360- * @param b_squared_norms Squared L2 norms for B rows (length n).
361- */
362- NK_DYNAMIC void nk_dots_compact_i8 (void * c , nk_size_t m , nk_size_t n , nk_size_t c_stride ,
363- nk_i32_t const * a_squared_norms , nk_i32_t const * b_squared_norms );
364-
365365/** @copydoc nk_dots_packed_size_bf16 */
366366NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_serial (nk_size_t n , nk_size_t k );
367367/** @copydoc nk_dots_pack_bf16 */
0 commit comments