|
| 1 | +import subprocess |
| 2 | +import sys |
| 3 | +import tempfile |
| 4 | +from argparse import ArgumentParser |
| 5 | +from collections.abc import Iterator |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +# Workaround for python <3.10, escape characters can't appear in f-strings. |
| 9 | +# Although we require 3.10 in some places, the formatter complains without this. |
| 10 | +newline = "\n" |
| 11 | + |
| 12 | +backslash = '\\' |
| 13 | + |
| 14 | + |
| 15 | +def indent(s): |
| 16 | + return "\n".join(f" {line}" if line else "" for line in s.split("\n")) |
| 17 | + |
| 18 | + |
| 19 | +# skips None for convenience |
| 20 | +def instruction(*args): |
| 21 | + return f"({' '.join(arg for arg in args if arg is not None)})" |
| 22 | + |
| 23 | + |
| 24 | +def atomic_instruction(op, memid, immediate, /, *args, drop): |
| 25 | + if drop: |
| 26 | + return f"(drop {instruction(op, memid, immediate, *args)})" |
| 27 | + return instruction(op, memid, immediate, *args) |
| 28 | + |
| 29 | + |
| 30 | +all_ops = [ |
| 31 | + ("i32.atomic.load", "(i32.const 51)", True), |
| 32 | + ("i32.atomic.store", "(i32.const 51) (i32.const 51)", False), |
| 33 | +] |
| 34 | + |
| 35 | + |
| 36 | +def drop_atomic(instruction): |
| 37 | + first, atomic, last = instruction.partition(".atomic") |
| 38 | + return first + last |
| 39 | + |
| 40 | + |
| 41 | +non_atomic_ops = [(drop_atomic(instruction), arg, drop) for instruction, arg, drop in all_ops if "rmw" not in instruction] |
| 42 | + |
| 43 | + |
| 44 | +def func(memid, immediate, ops=all_ops): |
| 45 | + return f'''(func ${immediate if immediate is not None else "no_immediate"}{"_with_memid" if memid is not None else "_without_memid"} |
| 46 | +{indent(newline.join(atomic_instruction(op, memid, immediate, arg, drop=should_drop) for op, arg, should_drop in ops))} |
| 47 | +)''' |
| 48 | + |
| 49 | + |
| 50 | +def module(*statements): |
| 51 | + return f'''(module |
| 52 | +{newline.join(map(indent, statements))} |
| 53 | +)''' |
| 54 | + |
| 55 | + |
| 56 | +def module_binary(bin): |
| 57 | + return f'''(module binary "{''.join(f'{backslash}{byte:02x}' for byte in bin)}")''' |
| 58 | + |
| 59 | + |
| 60 | +def assert_invalid(module, reason): |
| 61 | + return f'''(assert_invalid {module} "{reason}")''' |
| 62 | + |
| 63 | + |
| 64 | +def generate_atomic_spec_test(): |
| 65 | + # Declare two memories so we have control over whether the memory immediate is printed |
| 66 | + # A memory immediate of 0 is allowed to be omitted. |
| 67 | + return module( |
| 68 | + "(memory 1 1 shared)", |
| 69 | + "(memory 1 1 shared)", |
| 70 | + "", |
| 71 | + "\n\n".join([f'{func(memid, ordering)}' for memid in [None, "1"] for ordering in [None, "acqrel", "seqcst"]])) |
| 72 | + |
| 73 | + |
| 74 | +def to_binary(wasm_as, wat: str) -> bytes: |
| 75 | + with tempfile.NamedTemporaryFile(mode="w+") as input, tempfile.NamedTemporaryFile(mode="rb") as output: |
| 76 | + input.write(wat) |
| 77 | + input.seek(0) |
| 78 | + |
| 79 | + proc = subprocess.run([wasm_as, "--enable-multimemory", "--enable-threads", "--enable-relaxed-atomics", input.name, "-o", output.name], capture_output=True) |
| 80 | + try: |
| 81 | + proc.check_returncode() |
| 82 | + except Exception: |
| 83 | + print(proc.stderr.decode('utf-8'), end="", file=sys.stderr) |
| 84 | + raise |
| 85 | + |
| 86 | + return output.read() |
| 87 | + |
| 88 | + |
| 89 | +def findall(bytes, byte): |
| 90 | + ix = -1 |
| 91 | + while ((ix := bytes.find(byte, ix + 1)) != -1): |
| 92 | + yield ix |
| 93 | + |
| 94 | + |
| 95 | +def read_unsigned_leb(bytes, start): |
| 96 | + """Returns (bytes read, value)""" |
| 97 | + ret = 0 |
| 98 | + for i, byte in enumerate(bytes[start:]): |
| 99 | + ret |= (byte & ~(1 << 7)) << (7 * i) |
| 100 | + if not byte & (1 << 7): |
| 101 | + return i + 1, ret |
| 102 | + raise ValueError("Unexpected end of input, continuation bit was set for the last byte.") |
| 103 | + |
| 104 | + |
| 105 | +def to_unsigned_leb(num): |
| 106 | + ret = bytearray() |
| 107 | + |
| 108 | + if num == 0: |
| 109 | + ret = bytearray() |
| 110 | + ret.append(0) |
| 111 | + return ret |
| 112 | + ret = bytearray() |
| 113 | + while num > 0: |
| 114 | + rem = num >> 7 |
| 115 | + ret.append((num & 0x7F) | (bool(rem) << 7)) |
| 116 | + |
| 117 | + num = rem |
| 118 | + return ret |
| 119 | + |
| 120 | + |
| 121 | +def unsigned_leb_add(bytes: bytearray, start, add) -> int: |
| 122 | + """Returns number of bytes added""" |
| 123 | + l, decoded = read_unsigned_leb(bytes, start) |
| 124 | + added = to_unsigned_leb(decoded + add) |
| 125 | + |
| 126 | + bytes[start:start + l] = added[:l] |
| 127 | + |
| 128 | + if len(added) > l: |
| 129 | + for i, b in enumerate(added[l:], start=l): |
| 130 | + bytes.insert(i, b) |
| 131 | + |
| 132 | + return len(added) - l |
| 133 | + |
| 134 | + |
| 135 | +def unsigned_leb_subtract(bytes, start, sub): |
| 136 | + l, decoded = read_unsigned_leb(bytes, start) |
| 137 | + subbed = to_unsigned_leb(decoded - sub) |
| 138 | + |
| 139 | + bytes[start:start + len(subbed)] = subbed |
| 140 | + |
| 141 | + diff = l - len(subbed) |
| 142 | + for _ in range(diff): |
| 143 | + bytes.pop(start + len(subbed)) |
| 144 | + |
| 145 | + return -diff |
| 146 | + |
| 147 | + |
| 148 | +def iterate_sections(bytes) -> Iterator[bytearray]: |
| 149 | + bytes = bytes.removeprefix(b"\00asm\01\00\00\00") |
| 150 | + start = 0 |
| 151 | + while True: |
| 152 | + read, size = read_unsigned_leb(bytes, start + 1) |
| 153 | + |
| 154 | + # section op + section size + body |
| 155 | + yield bytearray(bytes[start:start + 1 + read + size]) |
| 156 | + start += 1 + read + size |
| 157 | + if start > len(bytes): |
| 158 | + raise ValueError("not expected", start, len(bytes)) |
| 159 | + elif start == len(bytes): |
| 160 | + return |
| 161 | + |
| 162 | + |
| 163 | +def iterate_functions(bytes) -> Iterator[bytearray]: |
| 164 | + read, size = read_unsigned_leb(bytes, 1) |
| 165 | + read2, size2 = read_unsigned_leb(bytes, 1 + read) |
| 166 | + section_body = bytes[1 + read + read2:] |
| 167 | + |
| 168 | + start = 0 |
| 169 | + while True: |
| 170 | + read, size = read_unsigned_leb(section_body, start) |
| 171 | + yield bytearray(section_body[start:start + read + size]) |
| 172 | + start += read + size |
| 173 | + if start > len(section_body): |
| 174 | + raise ValueError("not expected", start, len(section_body)) |
| 175 | + elif start == len(section_body): |
| 176 | + return |
| 177 | + |
| 178 | + |
| 179 | +def binary_tests(b: bytes) -> bytes: |
| 180 | + updated_tests = [b"\00asm\01\00\00\00"] |
| 181 | + |
| 182 | + for section in iterate_sections(b): |
| 183 | + if section[0] != 0x0a: |
| 184 | + updated_tests.append(section) |
| 185 | + continue |
| 186 | + |
| 187 | + bytes_read, size = read_unsigned_leb(section, 1) |
| 188 | + _, func_count = read_unsigned_leb(section, 1 + bytes_read) |
| 189 | + |
| 190 | + updated_code_section = bytearray() |
| 191 | + updated_code_section.append(0x0a) |
| 192 | + updated_code_section += to_unsigned_leb(size) |
| 193 | + |
| 194 | + updated_code_section += to_unsigned_leb(func_count) |
| 195 | + |
| 196 | + section_bytes_added = 0 |
| 197 | + for i, func in enumerate(iterate_functions(section)): |
| 198 | + # TODO: this is wrong if the function size is 0xfe |
| 199 | + ix = func.find(0xfe) |
| 200 | + if ix == -1: |
| 201 | + raise ValueError("Didn't find atomic operation") |
| 202 | + if i not in (2, 5): |
| 203 | + updated_code_section += func |
| 204 | + continue |
| 205 | + if func[ix + 2] & (1 << 5): |
| 206 | + raise ValueError("Memory immediate was already set.") |
| 207 | + func_bytes_added = 0 |
| 208 | + for i in findall(func, 0xfe): |
| 209 | + func[i + 2] |= (1 << 5) |
| 210 | + |
| 211 | + # ordering comes after mem idx |
| 212 | + has_mem_idx = bool(func[i + 2] & (1 << 6)) |
| 213 | + func.insert(i + 3 + has_mem_idx, 0x00) |
| 214 | + |
| 215 | + func_bytes_added += 1 |
| 216 | + |
| 217 | + # adding to the func byte size might have added a byte |
| 218 | + section_bytes_added += unsigned_leb_add(func, 0, func_bytes_added) |
| 219 | + section_bytes_added += func_bytes_added |
| 220 | + |
| 221 | + updated_code_section += func |
| 222 | + |
| 223 | + _ = unsigned_leb_add(updated_code_section, 1, section_bytes_added) |
| 224 | + updated_tests.append(updated_code_section) |
| 225 | + |
| 226 | + return b''.join(updated_tests) |
| 227 | + |
| 228 | + |
| 229 | +def failing_test(instruction, arg, /, memidx, drop): |
| 230 | + """Module assertion that sets a memory ordering immediate for a non-atomic instruction""" |
| 231 | + |
| 232 | + func = f"(func ${''.join(filter(str.isalnum, instruction))} {atomic_instruction(instruction, memidx, 'acqrel', arg, drop=drop)})" |
| 233 | + return assert_invalid(module("(memory 1 1 shared)", "", func), f"Can't set memory ordering for non-atomic {instruction}") |
| 234 | + |
| 235 | + |
| 236 | +def failing_tests(): |
| 237 | + text_tests = "\n\n".join( |
| 238 | + failing_test(op, arg, memidx=None, drop=should_drop) |
| 239 | + for op, arg, should_drop in non_atomic_ops |
| 240 | + ) |
| 241 | + |
| 242 | + return text_tests |
| 243 | + |
| 244 | + |
| 245 | +def main(): |
| 246 | + parser = ArgumentParser() |
| 247 | + parser.add_argument("--wasm-as", default=Path("bin/wasm-as"), type=Path) |
| 248 | + |
| 249 | + args = parser.parse_args() |
| 250 | + |
| 251 | + wat = generate_atomic_spec_test() |
| 252 | + bin = binary_tests(to_binary(args.wasm_as, wat)) |
| 253 | + print(wat) |
| 254 | + print(module_binary(bin)) |
| 255 | + print() |
| 256 | + print(failing_tests()) |
| 257 | + print() |
| 258 | + |
| 259 | + |
| 260 | +if __name__ == "__main__": |
| 261 | + main() |
0 commit comments