|
14 | 14 | except ImportError: |
15 | 15 | from cuda.core.experimental import Device |
16 | 16 |
|
17 | | -from ._utils.protocols import get_dtype, get_shape |
| 17 | +from ._utils.protocols import get_dtype, get_shape, is_device_array |
18 | 18 | from .typing import DeviceArrayLike, GpuStruct |
19 | 19 |
|
20 | 20 | # Registry that maps type -> key function for extracting cache key |
@@ -45,9 +45,16 @@ def _key_for(value: Any) -> Hashable: |
45 | 45 | if value_type in _KEY_FUNCTIONS: |
46 | 46 | return _KEY_FUNCTIONS[value_type](value) |
47 | 47 |
|
| 48 | + # DeviceArrayLike is not a runtime-checkable protocol, so |
| 49 | + # we cannot isinstance() with it. |
| 50 | + if is_device_array(value): |
| 51 | + return _KEY_FUNCTIONS[DeviceArrayLike](value) |
| 52 | + |
48 | 53 | # Check for instance match (handles inheritance) |
49 | 54 | for registered_type, keyer in _KEY_FUNCTIONS.items(): |
50 | | - if isinstance(value, registered_type): |
| 55 | + if registered_type is not DeviceArrayLike and isinstance( |
| 56 | + value, registered_type |
| 57 | + ): |
51 | 58 | return keyer(value) |
52 | 59 |
|
53 | 60 | # Fallback: use value directly (assumes it's hashable) |
@@ -138,12 +145,10 @@ def register(self, type_: type, key_function: Callable[[Any], Hashable]) -> None |
138 | 145 |
|
139 | 146 |
|
140 | 147 | def _make_hashable(value): |
141 | | - from .typing import DeviceArrayLike |
142 | | - |
143 | 148 | # duck-type check for numba.cuda.CUDADispatcher: |
144 | 149 | if hasattr(value, "py_func") and callable(value.py_func): |
145 | 150 | return CachableFunction(value.py_func) |
146 | | - elif isinstance(value, DeviceArrayLike): |
| 151 | + elif is_device_array(value): |
147 | 152 | # Ops with device arrays in globals/closures will be handled |
148 | 153 | # by stateful op machinery, which enables updating the state |
149 | 154 | # (pointers). Thus, we only cache on the dtype and shape of |
|
0 commit comments