Skip to content

Commit fc22f44

Browse files
authored
Add string hashing support for dict keys (#391)
This PR adds hashing support for strings, enabling them to be used as dictionary keys. The implementation includes a cached hash value in the string structure and uses the FNV-1a hashing algorithm. Code written by claude following my precise instructions, and reviewed by me.
2 parents db69c10 + ab73992 commit fc22f44

File tree

8 files changed

+59
-11
lines changed

8 files changed

+59
-11
lines changed

spy/backend/c/cwriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def fmt_expr_StrConst(self, const: ast.StrConst) -> C.Expr:
243243
# generate the following:
244244
#
245245
# // global declarations
246-
# static spy_Str SPY_g_str0 = {5, "hello"};
246+
# static spy_Str SPY_g_str0 = {5, 0, "hello"};
247247
# ...
248248
# // literal expr
249249
# &SPY_g_str0 /* "hello" */
@@ -258,7 +258,7 @@ def fmt_expr_StrConst(self, const: ast.StrConst) -> C.Expr:
258258
v = self.cmodw.new_global_var("str") # SPY_g_str0
259259
n = len(utf8)
260260
lit = C.Literal.from_bytes(utf8)
261-
init = "{%d, %s}" % (n, lit)
261+
init = "{%d, 0, %s}" % (n, lit)
262262
self.cmodw.tbc_globals.wl(f"static spy_Str {v} = {init};")
263263
#
264264
# shortstr is what we show in the comment, with a length limit

spy/libspy/include/spy/str.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
typedef struct {
88
size_t length;
9+
int32_t hash;
910
const char utf8[];
1011
} spy_Str;
1112

@@ -27,12 +28,15 @@ spy_Str *WASM_EXPORT(spy_str_getitem)(spy_Str *s, int32_t i);
2728

2829
int32_t WASM_EXPORT(spy_str_len)(spy_Str *s);
2930

31+
int32_t WASM_EXPORT(spy_str_hash)(spy_Str *s);
32+
3033
#define spy_operator$str_add spy_str_add
3134
#define spy_operator$str_mul spy_str_mul
3235
#define spy_operator$str_eq spy_str_eq
3336
#define spy_operator$str_ne spy_str_ne
3437
#define spy_builtins$str$__getitem__ spy_str_getitem
3538
#define spy_builtins$str$__len__ spy_str_len
39+
#define spy_builtins$hash_str spy_str_hash
3640

3741
// __str__ methods of common builtin types
3842
spy_Str *spy_builtins$i32$__str__(int32_t x);

spy/libspy/src/str.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ spy_str_alloc(size_t length) {
77
size_t size = sizeof(spy_Str) + length;
88
spy_Str *res = (spy_Str *)spy_GcAlloc(size).p;
99
res->length = length;
10+
res->hash = 0;
1011
return res;
1112
}
1213

@@ -61,6 +62,25 @@ spy_str_len(spy_Str *s) {
6162
return (int32_t)s->length;
6263
}
6364

65+
int32_t
66+
spy_str_hash(spy_Str *s) {
67+
if (s->hash != 0)
68+
return s->hash;
69+
// FNV-1a hash
70+
uint32_t h = 2166136261u;
71+
for (size_t i = 0; i < s->length; i++) {
72+
h ^= (uint8_t)s->utf8[i];
73+
h *= 16777619u;
74+
}
75+
int32_t result = (int32_t)h;
76+
if (result == -1)
77+
result = -2;
78+
if (result == 0)
79+
result = 1;
80+
s->hash = result;
81+
return result;
82+
}
83+
6484
// Helper function to format and convert to spy_Str
6585
// XXX probably it would be better to implement it directly, instead of
6686
// bringing in all the code needed to support sprintf()

spy/tests/stdlib/test_dict.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,19 @@ def foo() -> dict[i32, i32]:
291291
x = mod.foo()
292292
assert x[42] == 100
293293

294+
def test_str_keys(self):
295+
src = """
296+
from _dict import dict
297+
298+
def test() -> i32:
299+
d = dict[str, i32]()
300+
d["hello"] = 1
301+
d["world"] = 2
302+
return d["hello"] + d["world"]
303+
"""
304+
mod = self.compile(src)
305+
assert mod.test() == 3
306+
294307
def test_literal_mixed_value_types_key_value(self):
295308
# useful for mixed type support
296309
# type of x must be i32

spy/tests/test_libspy.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ def mk_spy_Str(utf8: bytes) -> bytes:
1111
1212
For example, for b'hello' we have the following in-memory repr:
1313
<i 4 bytes of length, little endian
14+
i 4 bytes of hash (0 for uncached)
1415
5s 5 bytes of data (b'hello')
1516
"""
1617
n = len(utf8)
17-
fmt = f"<i{n}s"
18-
return struct.pack(fmt, n, utf8)
18+
fmt = f"<ii{n}s"
19+
return struct.pack(fmt, n, 0, utf8)
1920

2021

2122
class TestLibSPy(CTest):
@@ -47,7 +48,7 @@ def test_str(self):
4748
src = r"""
4849
#include <spy.h>
4950
50-
spy_Str H = {6, "hello "};
51+
spy_Str H = {6, 0, "hello "};
5152
5253
spy_Str *mk_W(void) {
5354
spy_Str *s = spy_str_alloc(5);
@@ -58,13 +59,13 @@ def test_str(self):
5859
test_wasm = self.c_compile(src, exports=["H", "mk_W"])
5960
ll = LLSPyInstance.from_file(test_wasm)
6061
ptr_H = ll.read_global("H")
61-
assert ll.mem.read(ptr_H, 10) == mk_spy_Str(b"hello ")
62+
assert ll.mem.read(ptr_H, 14) == mk_spy_Str(b"hello ")
6263
#
6364
ptr_W = ll.call("mk_W")
64-
assert ll.mem.read(ptr_W, 9) == mk_spy_Str(b"world")
65+
assert ll.mem.read(ptr_W, 13) == mk_spy_Str(b"world")
6566
#
6667
ptr_HW = ll.call("spy_str_add", ptr_H, ptr_W)
67-
assert ll.mem.read(ptr_HW, 15) == mk_spy_Str(b"hello world")
68+
assert ll.mem.read(ptr_HW, 19) == mk_spy_Str(b"hello world")
6869

6970
def test_debug_log(self):
7071
src = r"""

spy/tests/wasm_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def to_py_result(self, w_T: W_Type, res: Any) -> Any:
136136
# res is a spy_Str*
137137
addr = res
138138
length = self.ll.mem.read_i32(addr)
139-
utf8 = self.ll.mem.read(addr + 4, length)
139+
utf8 = self.ll.mem.read(addr + 8, length)
140140
return utf8.decode("utf-8")
141141
elif w_T is RB.w_RawBuffer:
142142
# res is a spy_RawBuffer*

spy/vm/modules/builtins.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ def w_hash_bool(vm: "SPyVM", w_x: W_Bool) -> W_I32:
180180
assert False, "unreachable"
181181

182182

183+
@BUILTINS.builtin_func
184+
def w_hash_str(vm: "SPyVM", w_x: W_Str) -> W_I32:
185+
assert isinstance(w_x, W_Str)
186+
res = vm.ll.call("spy_str_hash", w_x.ptr)
187+
return vm.wrap(res)
188+
189+
183190
@BUILTINS.builtin_func(color="blue", kind="metafunc")
184191
def w_hash(vm: "SPyVM", wam_obj: W_MetaArg) -> W_OpSpec:
185192
w_T = wam_obj.w_static_T
@@ -191,6 +198,8 @@ def w_hash(vm: "SPyVM", wam_obj: W_MetaArg) -> W_OpSpec:
191198
return W_OpSpec(B.w_hash_u8)
192199
elif w_T is B.w_bool:
193200
return W_OpSpec(B.w_hash_bool)
201+
elif w_T is B.w_str:
202+
return W_OpSpec(B.w_hash_str)
194203

195204
if w_fn := w_T.lookup_func("__hash__"):
196205
w_opspec = vm.fast_metacall(w_fn, [wam_obj])

spy/vm/str.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def ll_spy_Str_new(ll: LLWasmInstance, s: str) -> int:
2121
utf8 = s.encode("utf-8")
2222
length = len(utf8)
2323
ptr = ll.call("spy_str_alloc", length)
24-
ll.mem.write(ptr + 4, utf8)
24+
ll.mem.write(ptr + 8, utf8)
2525
return ptr
2626

2727

@@ -34,6 +34,7 @@ class W_Str(W_Object):
3434
resides in the linear memory of the VM:
3535
typedef struct {
3636
size_t length;
37+
int32_t hash;
3738
const char utf8[];
3839
} spy_Str;
3940
"""
@@ -59,7 +60,7 @@ def get_length(self) -> int:
5960

6061
def get_utf8(self) -> bytes:
6162
length = self.get_length()
62-
ba = self.vm.ll.mem.read(self.ptr + 4, length)
63+
ba = self.vm.ll.mem.read(self.ptr + 8, length)
6364
return bytes(ba)
6465

6566
def _as_str(self) -> str:

0 commit comments

Comments
 (0)