Skip to content

Commit 980b681

Browse files
committed
Add: Float6 vector-vector x86 kernels
On Skylake: 5-6 GB/s. On Genoa: 9-10 GB/s.
1 parent 17aed31 commit 980b681

File tree

4 files changed

+312
-0
lines changed

4 files changed

+312
-0
lines changed

include/numkong/numkong.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,6 +2534,14 @@ NK_INTERNAL void nk_find_kernel_punned_e5m2_(nk_capability_t v, nk_kernel_kind_t
25342534
NK_INTERNAL void nk_find_kernel_punned_e2m3_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_punned_t *m,
25352535
nk_capability_t *c) {
25362536
typedef nk_kernel_punned_t m_t;
2537+
#if NK_TARGET_GENOA
2538+
if (v & nk_cap_genoa_k) switch (k) {
2539+
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_e2m3_genoa, *c = nk_cap_genoa_k; return;
2540+
case nk_kernel_sqeuclidean_k: *m = (m_t)&nk_sqeuclidean_e2m3_genoa, *c = nk_cap_genoa_k; return;
2541+
case nk_kernel_angular_k: *m = (m_t)&nk_angular_e2m3_genoa, *c = nk_cap_genoa_k; return;
2542+
default: break;
2543+
}
2544+
#endif
25372545
#if NK_TARGET_NEONFHM
25382546
if (v & nk_cap_neonfhm_k) switch (k) {
25392547
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e2m3_neonfhm, *c = nk_cap_neonfhm_k; return;
@@ -2543,6 +2551,14 @@ NK_INTERNAL void nk_find_kernel_punned_e2m3_(nk_capability_t v, nk_kernel_kind_t
25432551
case nk_kernel_dots_symmetric_k: *m = (m_t)&nk_dots_symmetric_e2m3_neonfhm, *c = nk_cap_neonfhm_k; return;
25442552
default: break;
25452553
}
2554+
#endif
2555+
#if NK_TARGET_SKYLAKE
2556+
if (v & nk_cap_skylake_k) switch (k) {
2557+
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_e2m3_skylake, *c = nk_cap_skylake_k; return;
2558+
case nk_kernel_sqeuclidean_k: *m = (m_t)&nk_sqeuclidean_e2m3_skylake, *c = nk_cap_skylake_k; return;
2559+
case nk_kernel_angular_k: *m = (m_t)&nk_angular_e2m3_skylake, *c = nk_cap_skylake_k; return;
2560+
default: break;
2561+
}
25462562
#endif
25472563
if (v & nk_cap_serial_k) switch (k) {
25482564
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e2m3_serial, *c = nk_cap_serial_k; return;
@@ -2560,6 +2576,14 @@ NK_INTERNAL void nk_find_kernel_punned_e2m3_(nk_capability_t v, nk_kernel_kind_t
25602576
NK_INTERNAL void nk_find_kernel_punned_e3m2_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_punned_t *m,
25612577
nk_capability_t *c) {
25622578
typedef nk_kernel_punned_t m_t;
2579+
#if NK_TARGET_GENOA
2580+
if (v & nk_cap_genoa_k) switch (k) {
2581+
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_e3m2_genoa, *c = nk_cap_genoa_k; return;
2582+
case nk_kernel_sqeuclidean_k: *m = (m_t)&nk_sqeuclidean_e3m2_genoa, *c = nk_cap_genoa_k; return;
2583+
case nk_kernel_angular_k: *m = (m_t)&nk_angular_e3m2_genoa, *c = nk_cap_genoa_k; return;
2584+
default: break;
2585+
}
2586+
#endif
25632587
#if NK_TARGET_NEONFHM
25642588
if (v & nk_cap_neonfhm_k) switch (k) {
25652589
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e3m2_neonfhm, *c = nk_cap_neonfhm_k; return;
@@ -2569,6 +2593,14 @@ NK_INTERNAL void nk_find_kernel_punned_e3m2_(nk_capability_t v, nk_kernel_kind_t
25692593
case nk_kernel_dots_symmetric_k: *m = (m_t)&nk_dots_symmetric_e3m2_neonfhm, *c = nk_cap_neonfhm_k; return;
25702594
default: break;
25712595
}
2596+
#endif
2597+
#if NK_TARGET_SKYLAKE
2598+
if (v & nk_cap_skylake_k) switch (k) {
2599+
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_e3m2_skylake, *c = nk_cap_skylake_k; return;
2600+
case nk_kernel_sqeuclidean_k: *m = (m_t)&nk_sqeuclidean_e3m2_skylake, *c = nk_cap_skylake_k; return;
2601+
case nk_kernel_angular_k: *m = (m_t)&nk_angular_e3m2_skylake, *c = nk_cap_skylake_k; return;
2602+
default: break;
2603+
}
25722604
#endif
25732605
if (v & nk_cap_serial_k) switch (k) {
25742606
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e3m2_serial, *c = nk_cap_serial_k; return;

include/numkong/spatial/genoa.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,128 @@ NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_
293293
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
294294
}
295295

296+
/**
 *  @brief Sum of squared element-wise differences between two `nk_e2m3_t` vectors.
 *
 *  The inputs are consumed 32 elements per step. Every step widens both operands with
 *  `nk_e2m3x32_to_bf16x32_ice_`, subtracts them with `nk_substract_bf16x32_genoa_`, and feeds the
 *  difference to `_mm512_dpbf16_ps` as both factors, accumulating into sixteen `f32` lanes.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the horizontal sum of the accumulator lanes.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_genoa(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 accumulator_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m256i a_chunk_e2m3x32, b_chunk_e2m3x32;
        if (n >= 32) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e2m3x32 = _mm256_loadu_epi8(a);
            b_chunk_e2m3x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            // NOTE(review): presumably a zero byte converts to 0.0 and so adds nothing - confirm in the converter.
            __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
            a_chunk_e2m3x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e2m3x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512i a_wide_bf16x32 = nk_e2m3x32_to_bf16x32_ice_(a_chunk_e2m3x32);
        __m512i b_wide_bf16x32 = nk_e2m3x32_to_bf16x32_ice_(b_chunk_e2m3x32);
        __m512i delta_bf16x32 = nk_substract_bf16x32_genoa_(a_wide_bf16x32, b_wide_bf16x32);
        accumulator_f32x16 = _mm512_dpbf16_ps(accumulator_f32x16, (__m512bh)(delta_bf16x32),
                                              (__m512bh)(delta_bf16x32));
    } while (n);

    *result = nk_reduce_add_f32x16_skylake_(accumulator_f32x16);
}
320+
321+
/**
 *  @brief Applies `nk_sqrt_f32_haswell_` to the value `nk_sqeuclidean_e2m3_genoa` computes for the same inputs.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the square-rooted value.
 */
NK_PUBLIC void nk_euclidean_e2m3_genoa(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq_f32;
    nk_sqeuclidean_e2m3_genoa(a, b, n, &distance_sq_f32);
    *result = nk_sqrt_f32_haswell_(distance_sq_f32);
}
325+
326+
/**
 *  @brief Angular kernel for two `nk_e2m3_t` vectors.
 *
 *  Accumulates three sums in parallel - a.b, a.a and b.b - using `_mm512_dpbf16_ps` on operands widened by
 *  `nk_e2m3x32_to_bf16x32_ice_`, 32 elements per step. The three reduced sums are then handed to
 *  `nk_angular_normalize_f32_haswell_`, whose return value is stored in `result`.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_e2m3_genoa(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m512 aa_f32x16 = _mm512_setzero_ps();
    __m512 bb_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m256i a_chunk_e2m3x32, b_chunk_e2m3x32;
        if (n >= 32) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e2m3x32 = _mm256_loadu_epi8(a);
            b_chunk_e2m3x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
            a_chunk_e2m3x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e2m3x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512i a_wide_bf16x32 = nk_e2m3x32_to_bf16x32_ice_(a_chunk_e2m3x32);
        __m512i b_wide_bf16x32 = nk_e2m3x32_to_bf16x32_ice_(b_chunk_e2m3x32);
        ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, (__m512bh)(a_wide_bf16x32), (__m512bh)(b_wide_bf16x32));
        aa_f32x16 = _mm512_dpbf16_ps(aa_f32x16, (__m512bh)(a_wide_bf16x32), (__m512bh)(a_wide_bf16x32));
        bb_f32x16 = _mm512_dpbf16_ps(bb_f32x16, (__m512bh)(b_wide_bf16x32), (__m512bh)(b_wide_bf16x32));
    } while (n);

    // Collapse each accumulator to a scalar before normalization.
    nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
    nk_f32_t aa_f32 = nk_reduce_add_f32x16_skylake_(aa_f32x16);
    nk_f32_t bb_f32 = nk_reduce_add_f32x16_skylake_(bb_f32x16);
    *result = nk_angular_normalize_f32_haswell_(ab_f32, aa_f32, bb_f32);
}
356+
357+
/**
 *  @brief Sum of squared element-wise differences between two `nk_e3m2_t` vectors.
 *
 *  The inputs are consumed 32 elements per step. Every step widens both operands with
 *  `nk_e3m2x32_to_bf16x32_ice_`, subtracts them with `nk_substract_bf16x32_genoa_`, and feeds the
 *  difference to `_mm512_dpbf16_ps` as both factors, accumulating into sixteen `f32` lanes.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the horizontal sum of the accumulator lanes.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_genoa(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 accumulator_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m256i a_chunk_e3m2x32, b_chunk_e3m2x32;
        if (n >= 32) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e3m2x32 = _mm256_loadu_epi8(a);
            b_chunk_e3m2x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            // NOTE(review): presumably a zero byte converts to 0.0 and so adds nothing - confirm in the converter.
            __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
            a_chunk_e3m2x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e3m2x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512i a_wide_bf16x32 = nk_e3m2x32_to_bf16x32_ice_(a_chunk_e3m2x32);
        __m512i b_wide_bf16x32 = nk_e3m2x32_to_bf16x32_ice_(b_chunk_e3m2x32);
        __m512i delta_bf16x32 = nk_substract_bf16x32_genoa_(a_wide_bf16x32, b_wide_bf16x32);
        accumulator_f32x16 = _mm512_dpbf16_ps(accumulator_f32x16, (__m512bh)(delta_bf16x32),
                                              (__m512bh)(delta_bf16x32));
    } while (n);

    *result = nk_reduce_add_f32x16_skylake_(accumulator_f32x16);
}
381+
382+
/**
 *  @brief Applies `nk_sqrt_f32_haswell_` to the value `nk_sqeuclidean_e3m2_genoa` computes for the same inputs.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the square-rooted value.
 */
NK_PUBLIC void nk_euclidean_e3m2_genoa(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq_f32;
    nk_sqeuclidean_e3m2_genoa(a, b, n, &distance_sq_f32);
    *result = nk_sqrt_f32_haswell_(distance_sq_f32);
}
386+
387+
/**
 *  @brief Angular kernel for two `nk_e3m2_t` vectors.
 *
 *  Accumulates three sums in parallel - a.b, a.a and b.b - using `_mm512_dpbf16_ps` on operands widened by
 *  `nk_e3m2x32_to_bf16x32_ice_`, 32 elements per step. The three reduced sums are then handed to
 *  `nk_angular_normalize_f32_haswell_`, whose return value is stored in `result`.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_e3m2_genoa(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m512 aa_f32x16 = _mm512_setzero_ps();
    __m512 bb_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m256i a_chunk_e3m2x32, b_chunk_e3m2x32;
        if (n >= 32) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e3m2x32 = _mm256_loadu_epi8(a);
            b_chunk_e3m2x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
            a_chunk_e3m2x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e3m2x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512i a_wide_bf16x32 = nk_e3m2x32_to_bf16x32_ice_(a_chunk_e3m2x32);
        __m512i b_wide_bf16x32 = nk_e3m2x32_to_bf16x32_ice_(b_chunk_e3m2x32);
        ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, (__m512bh)(a_wide_bf16x32), (__m512bh)(b_wide_bf16x32));
        aa_f32x16 = _mm512_dpbf16_ps(aa_f32x16, (__m512bh)(a_wide_bf16x32), (__m512bh)(a_wide_bf16x32));
        bb_f32x16 = _mm512_dpbf16_ps(bb_f32x16, (__m512bh)(b_wide_bf16x32), (__m512bh)(b_wide_bf16x32));
    } while (n);

    // Collapse each accumulator to a scalar before normalization.
    nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
    nk_f32_t aa_f32 = nk_reduce_add_f32x16_skylake_(aa_f32x16);
    nk_f32_t bb_f32 = nk_reduce_add_f32x16_skylake_(bb_f32x16);
    *result = nk_angular_normalize_f32_haswell_(ab_f32, aa_f32, bb_f32);
}
417+
296418
#if defined(__cplusplus)
297419
} // extern "C"
298420
#endif

include/numkong/spatial/skylake.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,128 @@ NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, n
577577
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
578578
}
579579

580+
/**
 *  @brief Sum of squared element-wise differences between two `nk_e2m3_t` vectors.
 *
 *  The inputs are consumed 16 elements per step. Every step converts both operands to sixteen `f32` lanes
 *  with `nk_e2m3x16_to_f32x16_skylake_`, subtracts them, and accumulates the squared difference with a
 *  fused multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the horizontal sum of the accumulator lanes.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 accumulator_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m128i a_chunk_e2m3x16, b_chunk_e2m3x16;
        if (n >= 16) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
            b_chunk_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
            a += 16;
            b += 16;
            n -= 16;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            // NOTE(review): presumably a zero byte converts to 0.0 and so adds nothing - confirm in the converter.
            __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
            a_chunk_e2m3x16 = _mm_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e2m3x16 = _mm_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512 a_wide_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_chunk_e2m3x16);
        __m512 b_wide_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_chunk_e2m3x16);
        __m512 delta_f32x16 = _mm512_sub_ps(a_wide_f32x16, b_wide_f32x16);
        accumulator_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, accumulator_f32x16);
    } while (n);

    *result = nk_reduce_add_f32x16_skylake_(accumulator_f32x16);
}
604+
605+
/**
 *  @brief Applies `nk_sqrt_f32_haswell_` to the value `nk_sqeuclidean_e2m3_skylake` computes for the same inputs.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the square-rooted value.
 */
NK_PUBLIC void nk_euclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq_f32;
    nk_sqeuclidean_e2m3_skylake(a, b, n, &distance_sq_f32);
    *result = nk_sqrt_f32_haswell_(distance_sq_f32);
}
609+
610+
/**
 *  @brief Angular kernel for two `nk_e2m3_t` vectors.
 *
 *  Accumulates three sums in parallel - a.b, a.a and b.b - with fused multiply-adds on operands converted by
 *  `nk_e2m3x16_to_f32x16_skylake_`, 16 elements per step. The three reduced sums are then handed to
 *  `nk_angular_normalize_f32_haswell_`, whose return value is stored in `result`.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m512 aa_f32x16 = _mm512_setzero_ps();
    __m512 bb_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m128i a_chunk_e2m3x16, b_chunk_e2m3x16;
        if (n >= 16) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
            b_chunk_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
            a += 16;
            b += 16;
            n -= 16;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
            a_chunk_e2m3x16 = _mm_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e2m3x16 = _mm_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512 a_wide_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_chunk_e2m3x16);
        __m512 b_wide_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_chunk_e2m3x16);
        ab_f32x16 = _mm512_fmadd_ps(a_wide_f32x16, b_wide_f32x16, ab_f32x16);
        aa_f32x16 = _mm512_fmadd_ps(a_wide_f32x16, a_wide_f32x16, aa_f32x16);
        bb_f32x16 = _mm512_fmadd_ps(b_wide_f32x16, b_wide_f32x16, bb_f32x16);
    } while (n);

    // Collapse each accumulator to a scalar before normalization.
    nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
    nk_f32_t aa_f32 = nk_reduce_add_f32x16_skylake_(aa_f32x16);
    nk_f32_t bb_f32 = nk_reduce_add_f32x16_skylake_(bb_f32x16);
    *result = nk_angular_normalize_f32_haswell_(ab_f32, aa_f32, bb_f32);
}
640+
641+
/**
 *  @brief Sum of squared element-wise differences between two `nk_e3m2_t` vectors.
 *
 *  The inputs are consumed 16 elements per step. Every step converts both operands to sixteen `f32` lanes
 *  with `nk_e3m2x16_to_f32x16_skylake_`, subtracts them, and accumulates the squared difference with a
 *  fused multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the horizontal sum of the accumulator lanes.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 accumulator_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m128i a_chunk_e3m2x16, b_chunk_e3m2x16;
        if (n >= 16) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
            b_chunk_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
            a += 16;
            b += 16;
            n -= 16;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            // NOTE(review): presumably a zero byte converts to 0.0 and so adds nothing - confirm in the converter.
            __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
            a_chunk_e3m2x16 = _mm_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e3m2x16 = _mm_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512 a_wide_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_chunk_e3m2x16);
        __m512 b_wide_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_chunk_e3m2x16);
        __m512 delta_f32x16 = _mm512_sub_ps(a_wide_f32x16, b_wide_f32x16);
        accumulator_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, accumulator_f32x16);
    } while (n);

    *result = nk_reduce_add_f32x16_skylake_(accumulator_f32x16);
}
665+
666+
/**
 *  @brief Applies `nk_sqrt_f32_haswell_` to the value `nk_sqeuclidean_e3m2_skylake` computes for the same inputs.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the square-rooted value.
 */
NK_PUBLIC void nk_euclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq_f32;
    nk_sqeuclidean_e3m2_skylake(a, b, n, &distance_sq_f32);
    *result = nk_sqrt_f32_haswell_(distance_sq_f32);
}
670+
671+
/**
 *  @brief Angular kernel for two `nk_e3m2_t` vectors.
 *
 *  Accumulates three sums in parallel - a.b, a.a and b.b - with fused multiply-adds on operands converted by
 *  `nk_e3m2x16_to_f32x16_skylake_`, 16 elements per step. The three reduced sums are then handed to
 *  `nk_angular_normalize_f32_haswell_`, whose return value is stored in `result`.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of elements in each vector.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m512 aa_f32x16 = _mm512_setzero_ps();
    __m512 bb_f32x16 = _mm512_setzero_ps();

    // The body runs at least once: for `n == 0` the tail mask has no bits set and both loads yield zeros.
    do {
        __m128i a_chunk_e3m2x16, b_chunk_e3m2x16;
        if (n >= 16) {
            // Full step: plain unaligned loads, then advance past the consumed elements.
            a_chunk_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
            b_chunk_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
            a += 16;
            b += 16;
            n -= 16;
        }
        else {
            // Tail step: only the low `n` bytes are read, the remaining lanes are zero-filled.
            __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
            a_chunk_e3m2x16 = _mm_maskz_loadu_epi8(tail_mask, a);
            b_chunk_e3m2x16 = _mm_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }

        __m512 a_wide_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_chunk_e3m2x16);
        __m512 b_wide_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_chunk_e3m2x16);
        ab_f32x16 = _mm512_fmadd_ps(a_wide_f32x16, b_wide_f32x16, ab_f32x16);
        aa_f32x16 = _mm512_fmadd_ps(a_wide_f32x16, a_wide_f32x16, aa_f32x16);
        bb_f32x16 = _mm512_fmadd_ps(b_wide_f32x16, b_wide_f32x16, bb_f32x16);
    } while (n);

    // Collapse each accumulator to a scalar before normalization.
    nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
    nk_f32_t aa_f32 = nk_reduce_add_f32x16_skylake_(aa_f32x16);
    nk_f32_t bb_f32 = nk_reduce_add_f32x16_skylake_(bb_f32x16);
    *result = nk_angular_normalize_f32_haswell_(ab_f32, aa_f32, bb_f32);
}
701+
580702
#if defined(__cplusplus)
581703
} // extern "C"
582704
#endif

scripts/bench.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,6 +1562,18 @@ int main(int argc, char **argv) {
15621562

15631563
dense_<e4m3_k, f32_k>("dot_e4m3_skylake", nk_dot_e4m3_skylake);
15641564
dense_<e5m2_k, f32_k>("dot_e5m2_skylake", nk_dot_e5m2_skylake);
1565+
dense_<e4m3_k, f32_k>("angular_e4m3_skylake", nk_angular_e4m3_skylake);
1566+
dense_<e4m3_k, f32_k>("sqeuclidean_e4m3_skylake", nk_sqeuclidean_e4m3_skylake);
1567+
dense_<e4m3_k, f32_k>("euclidean_e4m3_skylake", nk_euclidean_e4m3_skylake);
1568+
dense_<e5m2_k, f32_k>("angular_e5m2_skylake", nk_angular_e5m2_skylake);
1569+
dense_<e5m2_k, f32_k>("sqeuclidean_e5m2_skylake", nk_sqeuclidean_e5m2_skylake);
1570+
dense_<e5m2_k, f32_k>("euclidean_e5m2_skylake", nk_euclidean_e5m2_skylake);
1571+
dense_<e2m3_k, f32_k>("angular_e2m3_skylake", nk_angular_e2m3_skylake);
1572+
dense_<e2m3_k, f32_k>("sqeuclidean_e2m3_skylake", nk_sqeuclidean_e2m3_skylake);
1573+
dense_<e2m3_k, f32_k>("euclidean_e2m3_skylake", nk_euclidean_e2m3_skylake);
1574+
dense_<e3m2_k, f32_k>("angular_e3m2_skylake", nk_angular_e3m2_skylake);
1575+
dense_<e3m2_k, f32_k>("sqeuclidean_e3m2_skylake", nk_sqeuclidean_e3m2_skylake);
1576+
dense_<e3m2_k, f32_k>("euclidean_e3m2_skylake", nk_euclidean_e3m2_skylake);
15651577

15661578
dense_<i8_k, i32_k>("dot_i8_skylake", nk_dot_i8_skylake);
15671579
dense_<u8_k, u32_k>("dot_u8_skylake", nk_dot_u8_skylake);
@@ -1657,6 +1669,18 @@ int main(int argc, char **argv) {
16571669

16581670
dense_<e4m3_k, f32_k>("dot_e4m3_genoa", nk_dot_e4m3_genoa);
16591671
dense_<e5m2_k, f32_k>("dot_e5m2_genoa", nk_dot_e5m2_genoa);
1672+
dense_<e4m3_k, f32_k>("angular_e4m3_genoa", nk_angular_e4m3_genoa);
1673+
dense_<e4m3_k, f32_k>("sqeuclidean_e4m3_genoa", nk_sqeuclidean_e4m3_genoa);
1674+
dense_<e4m3_k, f32_k>("euclidean_e4m3_genoa", nk_euclidean_e4m3_genoa);
1675+
dense_<e5m2_k, f32_k>("angular_e5m2_genoa", nk_angular_e5m2_genoa);
1676+
dense_<e5m2_k, f32_k>("sqeuclidean_e5m2_genoa", nk_sqeuclidean_e5m2_genoa);
1677+
dense_<e5m2_k, f32_k>("euclidean_e5m2_genoa", nk_euclidean_e5m2_genoa);
1678+
dense_<e2m3_k, f32_k>("angular_e2m3_genoa", nk_angular_e2m3_genoa);
1679+
dense_<e2m3_k, f32_k>("sqeuclidean_e2m3_genoa", nk_sqeuclidean_e2m3_genoa);
1680+
dense_<e2m3_k, f32_k>("euclidean_e2m3_genoa", nk_euclidean_e2m3_genoa);
1681+
dense_<e3m2_k, f32_k>("angular_e3m2_genoa", nk_angular_e3m2_genoa);
1682+
dense_<e3m2_k, f32_k>("sqeuclidean_e3m2_genoa", nk_sqeuclidean_e3m2_genoa);
1683+
dense_<e3m2_k, f32_k>("euclidean_e3m2_genoa", nk_euclidean_e3m2_genoa);
16601684

16611685
curved_<bf16_k, f32_k>("bilinear_bf16_genoa", nk_bilinear_bf16_genoa);
16621686
curved_<bf16_k, f32_k>("mahalanobis_bf16_genoa", nk_mahalanobis_bf16_genoa);
@@ -1762,8 +1786,20 @@ int main(int argc, char **argv) {
17621786

17631787
dense_<e4m3_k, f32_k>("dot_e4m3_serial", nk_dot_e4m3_serial);
17641788
dense_<e5m2_k, f32_k>("dot_e5m2_serial", nk_dot_e5m2_serial);
1789+
dense_<e4m3_k, f32_k>("angular_e4m3_serial", nk_angular_e4m3_serial);
1790+
dense_<e4m3_k, f32_k>("sqeuclidean_e4m3_serial", nk_sqeuclidean_e4m3_serial);
1791+
dense_<e4m3_k, f32_k>("euclidean_e4m3_serial", nk_euclidean_e4m3_serial);
1792+
dense_<e5m2_k, f32_k>("angular_e5m2_serial", nk_angular_e5m2_serial);
1793+
dense_<e5m2_k, f32_k>("sqeuclidean_e5m2_serial", nk_sqeuclidean_e5m2_serial);
1794+
dense_<e5m2_k, f32_k>("euclidean_e5m2_serial", nk_euclidean_e5m2_serial);
17651795
dense_<e2m3_k, f32_k>("dot_e2m3_serial", nk_dot_e2m3_serial);
1796+
dense_<e2m3_k, f32_k>("angular_e2m3_serial", nk_angular_e2m3_serial);
1797+
dense_<e2m3_k, f32_k>("sqeuclidean_e2m3_serial", nk_sqeuclidean_e2m3_serial);
1798+
dense_<e2m3_k, f32_k>("euclidean_e2m3_serial", nk_euclidean_e2m3_serial);
17661799
dense_<e3m2_k, f32_k>("dot_e3m2_serial", nk_dot_e3m2_serial);
1800+
dense_<e3m2_k, f32_k>("angular_e3m2_serial", nk_angular_e3m2_serial);
1801+
dense_<e3m2_k, f32_k>("sqeuclidean_e3m2_serial", nk_sqeuclidean_e3m2_serial);
1802+
dense_<e3m2_k, f32_k>("euclidean_e3m2_serial", nk_euclidean_e3m2_serial);
17671803

17681804
dense_<f16_k, f32_k>("dot_f16_serial", nk_dot_f16_serial);
17691805
dense_<f16_k, f32_k>("angular_f16_serial", nk_angular_f16_serial);

0 commit comments

Comments
 (0)