Skip to content

Commit 5ba15e2

Browse files
committed
Add: Relaxed SIMD backend for WASM
With new FMA instructions in Relaxed WASM SIMD, we can implement fast dot products and spatial metrics for USearch in the browser & on edge. Wasmer is about to gain support: wasmerio/wasmer#6151
1 parent 378daae commit 5ba15e2

File tree

14 files changed

+1465
-27
lines changed

14 files changed

+1465
-27
lines changed

CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ project(
6262

6363
set(CMAKE_C_STANDARD 99)
6464
set(CMAKE_C_STANDARD_REQUIRED YES)
65-
set(CMAKE_C_EXTENSIONS NO)
65+
# Enable GNU extensions only for Emscripten builds (EM_ASM, used for runtime detection, needs them); every other platform keeps extensions off
66+
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
67+
set(CMAKE_C_EXTENSIONS YES)
68+
else()
69+
set(CMAKE_C_EXTENSIONS NO)
70+
endif()
6671

6772
set(CMAKE_CXX_STANDARD 23)
6873
set(CMAKE_CXX_STANDARD_REQUIRED YES)

cmake/toolchain-wasm.cmake

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# WASM/Emscripten toolchain for NumKong
2+
# Usage: cmake -B build-wasm -DCMAKE_TOOLCHAIN_FILE=cmake/toolchain-wasm.cmake
# Requires the EMSDK environment variable to be set at configure time (checked below).
3+
4+
set(CMAKE_SYSTEM_NAME Emscripten)
5+
set(CMAKE_SYSTEM_PROCESSOR wasm32)
6+
7+
# Require the EMSDK environment variable; abort configuration with install instructions otherwise
8+
if(NOT DEFINED ENV{EMSDK})
9+
message(FATAL_ERROR
10+
"EMSDK environment variable not set.\n"
11+
"Install Emscripten: https://emscripten.org/docs/getting_started/downloads.html\n"
12+
"Then run: source $EMSDK/emsdk_env.sh")
13+
endif()
14+
15+
# Point CMake at the Emscripten compiler drivers and archive tools under $EMSDK/upstream/emscripten
16+
set(EMSCRIPTEN_ROOT "$ENV{EMSDK}/upstream/emscripten")
17+
set(CMAKE_C_COMPILER "${EMSCRIPTEN_ROOT}/emcc")
18+
set(CMAKE_CXX_COMPILER "${EMSCRIPTEN_ROOT}/em++")
19+
set(CMAKE_AR "${EMSCRIPTEN_ROOT}/emar")
20+
set(CMAKE_RANLIB "${EMSCRIPTEN_ROOT}/emranlib")
21+
22+
# Required WASM SIMD flags: 128-bit SIMD plus the relaxed-SIMD extension,
# used as the initial C and C++ compile flags
23+
set(WASM_SIMD_FLAGS "-msimd128 -mrelaxed-simd")
24+
set(CMAKE_C_FLAGS_INIT "${WASM_SIMD_FLAGS}")
25+
set(CMAKE_CXX_FLAGS_INIT "${WASM_SIMD_FLAGS}")
26+
27+
# Enable GNU extensions for EM_ASM support (required for runtime detection)
# NOTE(review): CACHE ... FORCE overwrites any value the user supplied with -D; confirm this is intended
28+
set(CMAKE_C_EXTENSIONS ON CACHE BOOL "" FORCE)
29+
set(CMAKE_CXX_EXTENSIONS ON CACHE BOOL "" FORCE)
30+
31+
# Optimization: -O3 with LTO and NDEBUG for Release; -O0 with debug info and ASSERTIONS=2 for Debug
# NOTE(review): these assign the per-config flag variables directly, unlike the *_INIT variables used above - confirm
# NOTE(review): "-s ASSERTIONS=2" is passed on the compile line here - TODO confirm it also takes effect at link time
32+
set(CMAKE_C_FLAGS_RELEASE "-O3 -DNDEBUG -flto")
33+
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG -flto")
34+
set(CMAKE_C_FLAGS_DEBUG "-O0 -g -s ASSERTIONS=2")
35+
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -s ASSERTIONS=2")
36+
37+
# Linker flags for Node.js execution: growable memory (64MB initial, 2GB maximum), 5MB stack,
# exporting _main plus the ccall/cwrap runtime methods
38+
set(CMAKE_EXE_LINKER_FLAGS_INIT
39+
"-s ALLOW_MEMORY_GROWTH=1 \
40+
-s INITIAL_MEMORY=64MB \
41+
-s MAXIMUM_MEMORY=2GB \
42+
-s STACK_SIZE=5MB \
43+
-s EXPORTED_FUNCTIONS='[\"_main\"]' \
44+
-s EXPORTED_RUNTIME_METHODS='[\"ccall\",\"cwrap\"]'")
45+
46+
# Check the Emscripten version (3.1.27+ expected for relaxed SIMD).
# The version is the first X.Y.Z match in the `emcc --version` output; an older
# version only produces a WARNING and configuration continues.
47+
execute_process(
48+
COMMAND ${CMAKE_C_COMPILER} --version
49+
OUTPUT_VARIABLE EMCC_VERSION_OUTPUT
50+
OUTPUT_STRIP_TRAILING_WHITESPACE)
51+
string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" EMCC_VERSION "${EMCC_VERSION_OUTPUT}")
52+
53+
if(EMCC_VERSION VERSION_LESS "3.1.27")
54+
message(WARNING "Emscripten ${EMCC_VERSION} < 3.1.27. Upgrade recommended for relaxed SIMD.")
55+
endif()
56+
57+
message(STATUS "NumKong WASM: Emscripten ${EMCC_VERSION}")
58+
# Printed unconditionally, including when the version warning above has fired
message(STATUS "NumKong WASM: Relaxed SIMD enabled")

include/numkong/cast/wasm.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/**
2+
* @file wasm.h
3+
* @brief WASM SIMD (v128) type conversion helpers for BF16/F16 to F32.
4+
* @author Ash Vardanian
5+
* @date January 31, 2026
6+
*/
7+
8+
#ifndef NK_CAST_WASM_H
9+
#define NK_CAST_WASM_H
10+
11+
#if NK_TARGET_V128RELAXED
12+
#include "numkong/types.h"
13+
#include "numkong/cast/serial.h" // For scalar fallback
14+
15+
#if defined(__cplusplus)
16+
extern "C" {
17+
#endif
18+
19+
/**
 * @brief Widens the four 16-bit BF16 values stored in `bf16_vec.u64` into four 32-bit lanes.
 *
 * Each 16-bit value is zero-extended to 32 bits and shifted left by 16, so the BF16 bits end up
 * in the upper half of the lane and the lower 16 bits are zero. The result is returned through
 * the `v128` member of `nk_b128_vec_t`.
 */
NK_INTERNAL nk_b128_vec_t nk_bf16x4_to_f32x4_wasm_(nk_b64_vec_t bf16_vec) {
20+
// Load 4x u16 (64 bits) into lower half of v128, zero upper half
21+
v128_t bf16_u16x4_in_u64 = wasm_v128_load64_zero(&bf16_vec.u64);
22+
23+
// Widen u16 → u32: [u16, u16, u16, u16, 0, 0, 0, 0] → [u32, u32, u32, u32]
24+
// Uses zero-extension (upper 16 bits of each u32 become 0)
25+
v128_t bf16_u32x4_low = wasm_u32x4_extend_low_u16x8(bf16_u16x4_in_u64);
26+
27+
// Shift left by 16 bits: moves BF16 into F32 position
28+
// Before shift (per u32 lane): [0000000000000000|S|EEEEEEEE|MMMMMMM]
29+
// After shift (F32 layout):    [S|EEEEEEEE|MMMMMMM|0000000000000000]
30+
nk_b128_vec_t result;
31+
result.v128 = wasm_i32x4_shl(bf16_u32x4_low, 16);
32+
return result;
33+
}
34+
35+
/**
 * @brief Converts the four 16-bit F16 values stored in `f16_vec.u64` into four F32 bit patterns.
 *
 * The sign, exponent and mantissa fields are extracted per lane, a candidate result is built for
 * each class of input (normal, zero, denormal, infinity/NaN), and lane masks select the matching
 * candidate. The result is returned through the `v128` member of `nk_b128_vec_t`.
 */
NK_INTERNAL nk_b128_vec_t nk_f16x4_to_f32x4_wasm_(nk_b64_vec_t f16_vec) {
36+
// Load 4x u16 into v128, zero-extend to u32x4
37+
v128_t f16_u16x4_in_u64 = wasm_v128_load64_zero(&f16_vec.u64);
38+
v128_t f16_u32x4 = wasm_u32x4_extend_low_u16x8(f16_u16x4_in_u64);
39+
40+
// Extract bit fields
41+
v128_t sign_u32x4 = wasm_v128_and(f16_u32x4, wasm_i32x4_splat(0x8000)); // Bit 15
42+
v128_t exp_u32x4 = wasm_v128_and(wasm_u32x4_shr(f16_u32x4, 10), wasm_i32x4_splat(0x1F)); // Bits 14-10
43+
v128_t mant_u32x4 = wasm_v128_and(f16_u32x4, wasm_i32x4_splat(0x03FF)); // Bits 9-0
44+
45+
// Shift sign to F32 position (bit 31)
46+
v128_t sign_f32_u32x4 = wasm_i32x4_shl(sign_u32x4, 16);
47+
48+
// Normal (exp ∈ [1, 30])
49+
// Rebias exponent: F16 bias=15, F32 bias=127 → add 112
50+
// Shift mantissa: 10 bits → 23 bits (shift left by 13)
// Computed for every lane; lanes with exp=0 or exp=31 are overwritten by the blends below
51+
v128_t exp_rebiased_u32x4 = wasm_i32x4_add(exp_u32x4, wasm_i32x4_splat(112));
52+
v128_t normal_exp_u32x4 = wasm_i32x4_shl(exp_rebiased_u32x4, 23);
53+
v128_t normal_mant_u32x4 = wasm_i32x4_shl(mant_u32x4, 13);
54+
v128_t normal_bits_u32x4 = wasm_v128_or(sign_f32_u32x4, wasm_v128_or(normal_exp_u32x4, normal_mant_u32x4));
55+
56+
// Zero (exp=0, mant=0)
57+
v128_t zero_bits_u32x4 = sign_f32_u32x4; // Just sign bit
58+
59+
// Infinity/NaN (exp=31)
60+
// Infinity: 0x7F800000 | sign
61+
// NaN: 0x7F800000 | sign | (mant << 13) [preserves NaN payload]
62+
v128_t inf_nan_bits_u32x4 = wasm_v128_or(
63+
sign_f32_u32x4, wasm_v128_or(wasm_i32x4_splat(0x7F800000), wasm_i32x4_shl(mant_u32x4, 13)));
64+
65+
// Denormal (exp=0, mant≠0) - FPU-based normalization
66+
// F16 denormal value = 2^-14 × (0.mantissa_bits) = mantissa_bits × 2^-24
67+
//
68+
// Strategy: Use FPU to normalize by converting to float and multiplying by magic constant
69+
// 1. Convert mantissa (integer) to F32: cvt_u32_to_f32(mant)
70+
// 2. Multiply by 2^-24 (magic constant 0x33800000 in F32)
71+
// 3. FPU normalizes automatically, giving correct F32 representation
72+
// 4. Reinterpret as bits and apply sign
73+
74+
// Convert mantissa u32 → f32 (each lane independently)
75+
v128_t mant_f32x4 = wasm_f32x4_convert_u32x4(mant_u32x4);
76+
77+
// Multiply by 2^-24 (F32 hex: 0x33800000)
78+
v128_t magic_f32x4 = wasm_f32x4_splat(0x1p-24f); // 2^-24 in hex float notation
79+
v128_t denorm_normalized_f32x4 = wasm_f32x4_mul(mant_f32x4, magic_f32x4);
80+
81+
// Reinterpret f32x4 as u32x4 bits (v128_t is polymorphic - just assign)
82+
v128_t denorm_bits_u32x4 = denorm_normalized_f32x4;
83+
84+
// Apply sign (OR with sign bit, since denorm result is always positive)
85+
denorm_bits_u32x4 = wasm_v128_or(denorm_bits_u32x4, sign_f32_u32x4);
86+
87+
// Build masks: one per-lane comparison result for each field test
88+
v128_t exp_zero_mask = wasm_i32x4_eq(exp_u32x4, wasm_i32x4_splat(0));
89+
v128_t mant_zero_mask = wasm_i32x4_eq(mant_u32x4, wasm_i32x4_splat(0));
90+
v128_t exp_max_mask = wasm_i32x4_eq(exp_u32x4, wasm_i32x4_splat(31));
91+
92+
v128_t is_zero_mask = wasm_v128_and(exp_zero_mask, mant_zero_mask); // exp=0 AND mant=0
93+
v128_t is_denormal_mask = wasm_v128_andnot(exp_zero_mask, mant_zero_mask); // exp=0 AND mant≠0, i.e. andnot(a, b) = a & ~b
94+
95+
// Blend the results, starting from the normal-path bits.
// The zero, denormal and exp=31 masks are mutually exclusive, so the order of the blends does not matter.
96+
v128_t result_u32x4 = normal_bits_u32x4;
97+
98+
// Apply zero where exp=0 && mant=0
99+
result_u32x4 = wasm_v128_bitselect(zero_bits_u32x4, result_u32x4, is_zero_mask);
100+
101+
// Apply denormal where exp=0 && mant≠0
102+
result_u32x4 = wasm_v128_bitselect(denorm_bits_u32x4, result_u32x4, is_denormal_mask);
103+
104+
// Apply inf/NaN where exp=31
105+
result_u32x4 = wasm_v128_bitselect(inf_nan_bits_u32x4, result_u32x4, exp_max_mask);
106+
107+
nk_b128_vec_t result;
108+
result.v128 = result_u32x4;
109+
return result;
110+
}
111+
112+
#if defined(__cplusplus)
113+
}
114+
#endif
115+
116+
#endif // NK_TARGET_V128RELAXED
117+
#endif // NK_CAST_WASM_H

include/numkong/dot.h

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,21 @@ NK_INTERNAL void nk_dot_i8x32_finalize_sierra(
880880

881881
#endif // NK_TARGET_SIERRA
882882

883+
#if NK_TARGET_V128RELAXED
// WASM dot-product entry points, declared only when NK_TARGET_V128RELAXED is non-zero.
// The f16 and bf16 variants write an f32 result; i8 writes i32 and u8 writes u32.
884+
/** @copydoc nk_dot_f32 */
885+
NK_PUBLIC void nk_dot_f32_wasm(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
886+
/** @copydoc nk_dot_f64 */
887+
NK_PUBLIC void nk_dot_f64_wasm(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
888+
/** @copydoc nk_dot_f16 */
889+
NK_PUBLIC void nk_dot_f16_wasm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
890+
/** @copydoc nk_dot_bf16 */
891+
NK_PUBLIC void nk_dot_bf16_wasm(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
892+
/** @copydoc nk_dot_i8 */
893+
NK_PUBLIC void nk_dot_i8_wasm(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
894+
/** @copydoc nk_dot_u8 */
895+
NK_PUBLIC void nk_dot_u8_wasm(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
896+
#endif // NK_TARGET_V128RELAXED
897+
883898
/**
884899
* @brief Returns the output dtype for dot products.
885900
*/
@@ -917,11 +932,14 @@ NK_INTERNAL nk_dtype_t nk_dot_output_dtype(nk_dtype_t dtype) {
917932
#include "numkong/dot/spacemit.h"
918933
#include "numkong/dot/sifive.h"
919934
#include "numkong/dot/xuantie.h"
935+
#include "numkong/dot/wasm.h"
920936

921937
#if !NK_DYNAMIC_DISPATCH
922938

923939
NK_PUBLIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result) {
924-
#if NK_TARGET_SPACEMIT
940+
#if NK_TARGET_V128RELAXED
941+
nk_dot_i8_wasm(a, b, n, result);
942+
#elif NK_TARGET_SPACEMIT
925943
nk_dot_i8_spacemit(a, b, n, result);
926944
#elif NK_TARGET_NEONSDOT
927945
nk_dot_i8_neonsdot(a, b, n, result);
@@ -936,7 +954,9 @@ NK_PUBLIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32
936954
#endif
937955
}
938956
NK_PUBLIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
939-
#if NK_TARGET_SPACEMIT
957+
#if NK_TARGET_V128RELAXED
958+
nk_dot_u8_wasm(a, b, n, result);
959+
#elif NK_TARGET_SPACEMIT
940960
nk_dot_u8_spacemit(a, b, n, result);
941961
#elif NK_TARGET_NEONSDOT
942962
nk_dot_u8_neonsdot(a, b, n, result);
@@ -973,7 +993,9 @@ NK_PUBLIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk
973993
#endif
974994
}
975995
NK_PUBLIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
976-
#if NK_TARGET_SIFIVE
996+
#if NK_TARGET_V128RELAXED
997+
nk_dot_f16_wasm(a, b, n, result);
998+
#elif NK_TARGET_SIFIVE
977999
nk_dot_f16_sifive(a, b, n, result);
9781000
#elif NK_TARGET_SPACEMIT
9791001
nk_dot_f16_spacemit(a, b, n, result);
@@ -992,7 +1014,9 @@ NK_PUBLIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_
9921014
#endif
9931015
}
9941016
NK_PUBLIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
995-
#if NK_TARGET_GENOA
1017+
#if NK_TARGET_V128RELAXED
1018+
nk_dot_bf16_wasm(a, b, n, result);
1019+
#elif NK_TARGET_GENOA
9961020
nk_dot_bf16_genoa(a, b, n, result);
9971021
#elif NK_TARGET_SPACEMIT
9981022
nk_dot_bf16_spacemit(a, b, n, result);
@@ -1071,7 +1095,9 @@ NK_PUBLIC void nk_dot_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n,
10711095
#endif
10721096
}
10731097
NK_PUBLIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result) {
1074-
#if NK_TARGET_SPACEMIT
1098+
#if NK_TARGET_V128RELAXED
1099+
nk_dot_f32_wasm(a, b, n, result);
1100+
#elif NK_TARGET_SPACEMIT
10751101
nk_dot_f32_spacemit(a, b, n, result);
10761102
#elif NK_TARGET_SVE
10771103
nk_dot_f32_sve(a, b, n, result);
@@ -1086,7 +1112,9 @@ NK_PUBLIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_
10861112
#endif
10871113
}
10881114
NK_PUBLIC void nk_dot_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
1089-
#if NK_TARGET_SPACEMIT
1115+
#if NK_TARGET_V128RELAXED
1116+
nk_dot_f64_wasm(a, b, n, result);
1117+
#elif NK_TARGET_SPACEMIT
10901118
nk_dot_f64_spacemit(a, b, n, result);
10911119
#elif NK_TARGET_SVE
10921120
nk_dot_f64_sve(a, b, n, result);

0 commit comments

Comments
 (0)