Skip to content
4 changes: 2 additions & 2 deletions notebooks/dev_tutorials/extend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@
"# can't handle other inputs\n",
"def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:\n",
" def is_cpu(x: Number | TensorProxy) -> bool:\n",
" if isinstance(a, TensorProxy):\n",
" return a.device.devicetype == DeviceType.CPU\n",
" if isinstance(x, TensorProxy):\n",
" return x.device.devicetype == DeviceType.CPU\n",
" return True\n",
"\n",
" return all(is_cpu(x) for x in (a, b))"
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def mul_to_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorP

def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
def is_cpu(x: Number | TensorProxy) -> bool:
if isinstance(a, TensorProxy):
return a.device.devicetype == devices.DeviceType.CPU
if isinstance(x, TensorProxy):
return x.device.devicetype == devices.DeviceType.CPU
return True

return all(is_cpu(x) for x in (a, b))
Expand Down
Loading