
Commit a094bde

Caching improvements
1 parent bbfe28b commit a094bde

5 files changed: 27 additions, 21 deletions

python/cuda_cccl/cuda/compute/_caching.py

Lines changed: 10 additions & 5 deletions

@@ -14,7 +14,7 @@
 except ImportError:
     from cuda.core.experimental import Device

-from ._utils.protocols import get_dtype, get_shape
+from ._utils.protocols import get_dtype, get_shape, is_device_array
 from .typing import DeviceArrayLike, GpuStruct

 # Registry thet maps type -> key function for extracting cache key
@@ -45,9 +45,16 @@ def _key_for(value: Any) -> Hashable:
     if value_type in _KEY_FUNCTIONS:
         return _KEY_FUNCTIONS[value_type](value)

+    # DeviceArrayLike is not a runtime-checkable protocol, so
+    # we cannot isinstance() with it.
+    if is_device_array(value):
+        return _KEY_FUNCTIONS[DeviceArrayLike](value)
+
     # Check for instance match (handles inheritance)
     for registered_type, keyer in _KEY_FUNCTIONS.items():
-        if isinstance(value, registered_type):
+        if registered_type is not DeviceArrayLike and isinstance(
+            value, registered_type
+        ):
             return keyer(value)

     # Fallback: use value directly (assumes it's hashable)
@@ -138,12 +145,10 @@ def register(self, type_: type, key_function: Callable[[Any], Hashable]) -> None


 def _make_hashable(value):
-    from .typing import DeviceArrayLike
-
     # duck-type check for numba.cuda.CUDADispatcher:
     if hasattr(value, "py_func") and callable(value.py_func):
         return CachableFunction(value.py_func)
-    elif isinstance(value, DeviceArrayLike):
+    elif is_device_array(value):
         # Ops with device arrays in globals/closures will be handled
         # by stateful op machinery, which enables updating the state
         # (pointers). Thus, we only cache on the dtype and shape of
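The key change here is that cache-key extraction no longer relies on isinstance() against the DeviceArrayLike protocol. The standalone sketch below (not part of the commit; names are illustrative) shows why: isinstance() raises TypeError for a Protocol that is not decorated with @runtime_checkable, whereas a hasattr() duck-type check on __cuda_array_interface__ works for any device array.

# Sketch only: illustrates the TypeError and the hasattr()-based check.
from typing import Protocol


class DeviceArrayLike(Protocol):
    __cuda_array_interface__: dict


class FakeDeviceArray:
    # Any object exposing the CUDA Array Interface qualifies as a device array.
    __cuda_array_interface__ = {"shape": (4,), "typestr": "<f4", "version": 3}


def is_device_array(obj: object) -> bool:
    return hasattr(obj, "__cuda_array_interface__")


try:
    isinstance(FakeDeviceArray(), DeviceArrayLike)
except TypeError as exc:
    print("isinstance() rejected:", exc)

print(is_device_array(FakeDeviceArray()))  # True
print(is_device_array([1, 2, 3]))          # False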

python/cuda_cccl/cuda/compute/_jit.py

Lines changed: 3 additions & 2 deletions

@@ -34,6 +34,7 @@
 from numba.extending import lower_builtin, lower_cast

 from ._caching import CachableFunction, cache_with_registered_key_functions
+from ._utils.protocols import is_device_array
 from .op import Op, OpAdapter
 from .typing import DeviceArrayLike

@@ -768,7 +769,7 @@ def _detect_device_array_globals(func: Callable) -> List[Tuple[str, object]]:

     for name in code.co_names:
         val = func.__globals__.get(name)
-        if val is not None and isinstance(val, DeviceArrayLike):
+        if val is not None and hasattr(val, "__cuda_array_interface__"):
             state_arrays.append((name, val))

     return state_arrays
@@ -795,7 +796,7 @@ def _detect_device_array_closures(func: Callable) -> List[Tuple[str, object]]:
     for name, cell in zip(code.co_freevars, closure):
         try:
             val = cell.cell_contents
-            if isinstance(val, DeviceArrayLike):
+            if is_device_array(val):
                 state_arrays.append((name, val))
         except ValueError:
             # Cell is empty
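These two helpers find device arrays that an op references as state. The simplified, self-contained sketch below (an assumption about the shape of the logic, not the library's exact code) shows how globals and closure cells can be scanned with the same duck-type check:

# Sketch only: collect (name, array) pairs referenced by a function.
def detect_device_array_state(func):
    state_arrays = []
    code = func.__code__

    # Globals referenced by name in the function's bytecode.
    for name in code.co_names:
        val = func.__globals__.get(name)
        if val is not None and hasattr(val, "__cuda_array_interface__"):
            state_arrays.append((name, val))

    # Free variables captured in closure cells.
    for name, cell in zip(code.co_freevars, func.__closure__ or ()):
        try:
            val = cell.cell_contents
        except ValueError:
            continue  # empty cell
        if hasattr(val, "__cuda_array_interface__"):
            state_arrays.append((name, val))

    return state_arrays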

python/cuda_cccl/cuda/compute/_utils/protocols.py

Lines changed: 5 additions & 0 deletions

@@ -14,6 +14,11 @@
 from ..typing import DeviceArrayLike, GpuStruct


+def is_device_array(obj: object) -> bool:
+    """Check if an object implements the `__cuda_array_interface__` protocol."""
+    return hasattr(obj, "__cuda_array_interface__")
+
+
 def get_data_pointer(arr: DeviceArrayLike) -> int:
     # TODO: these are fast paths for CuPy and PyTorch until
     # we have a more general solution.
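A quick usage sketch of the new helper (assuming CuPy is installed; the import path is the package-internal module shown above, so treat it as illustrative): any producer of the CUDA Array Interface passes the check, while host-memory objects do not.

import numpy as np

from cuda.compute._utils.protocols import is_device_array

try:
    import cupy as cp

    print(is_device_array(cp.zeros(8)))  # True: CuPy arrays expose __cuda_array_interface__
except ImportError:
    pass

print(is_device_array(np.zeros(8)))  # False: NumPy arrays live in host memory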

python/cuda_cccl/cuda/compute/algorithms/_scan.py

Lines changed: 6 additions & 2 deletions

@@ -16,7 +16,11 @@
     set_cccl_iterator_state,
     to_cccl_value_state,
 )
-from .._utils.protocols import get_data_pointer, validate_and_get_stream
+from .._utils.protocols import (
+    get_data_pointer,
+    is_device_array,
+    validate_and_get_stream,
+)
 from .._utils.temp_storage_buffer import TempStorageBuffer
 from ..iterators._iterators import IteratorBase
 from ..op import OpAdapter, OpKind, make_op_adapter
@@ -29,7 +33,7 @@ def get_init_kind(
     match init_value:
         case None:
             return _bindings.InitKind.NO_INIT
-        case _ if isinstance(init_value, DeviceArrayLike):
+        case _ if is_device_array(init_value):
            return _bindings.InitKind.FUTURE_VALUE_INIT
         case _:
             return _bindings.InitKind.VALUE_INIT
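The dispatch in get_init_kind distinguishes three cases: no initial value, an initial value passed as a device array (mapped to FUTURE_VALUE_INIT), and an ordinary host value. The standalone sketch below uses a stand-in enum, since _bindings.InitKind appears to come from the compiled bindings:

# Sketch only: stand-in enum and dispatch mirroring the logic above.
from enum import Enum, auto


class InitKind(Enum):
    NO_INIT = auto()
    VALUE_INIT = auto()
    FUTURE_VALUE_INIT = auto()


def get_init_kind(init_value):
    match init_value:
        case None:
            return InitKind.NO_INIT
        case _ if hasattr(init_value, "__cuda_array_interface__"):
            return InitKind.FUTURE_VALUE_INIT
        case _:
            return InitKind.VALUE_INIT


print(get_init_kind(None))  # InitKind.NO_INIT
print(get_init_kind(3.0))   # InitKind.VALUE_INIT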

python/cuda_cccl/cuda/compute/typing.py

Lines changed: 3 additions & 12 deletions

@@ -3,12 +3,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Protocol, runtime_checkable
+from typing import Protocol

-import numpy as np
+from .struct import _Struct


-@runtime_checkable
 class DeviceArrayLike(Protocol):
     """
     Objects representing a device array, having a `.__cuda_array_interface__`
@@ -26,12 +25,4 @@ class StreamLike(Protocol):
     def __cuda_stream__(self) -> tuple[int, int]: ...


-@runtime_checkable
-class GpuStruct(Protocol):
-    """
-    Type of instances of structs created with gpu_struct().
-    """
-
-    _data: np.ndarray
-    __array_interface__: dict
-    dtype: np.dtype
+GpuStruct = _Struct
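With @runtime_checkable removed, DeviceArrayLike remains a purely static protocol: it still works as a type annotation for mypy/pyright, while runtime checks go through is_device_array(). A hedged sketch of the intended split (module paths are the package-internal ones shown above):

# Sketch only: static annotation plus runtime duck-type check.
from cuda.compute.typing import DeviceArrayLike
from cuda.compute._utils.protocols import is_device_array


def scale_inplace(arr: DeviceArrayLike, factor: float) -> None:
    # Static checkers verify the annotation; at runtime we duck-type.
    if not is_device_array(arr):
        raise TypeError("expected an object exposing __cuda_array_interface__")
    ...  # launch a kernel on `arr`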
