Skip to content

Commit 3d9be2f

Browse files
committed
WIP add test
1 parent ee050f2 commit 3d9be2f

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

thunder/tests/opinfos.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4473,6 +4473,28 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
44734473
shape_ops.append(getitem_opinfo)
44744474

44754475

4476+
def setitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
    """Yields SampleInputs for __setitem__ by reusing the getitem samples.

    Every (tensor, key) pair from the getitem generator produces two setitem
    samples: one whose value matches the indexed shape exactly, and one whose
    value must broadcast into place.
    """
    for getitem_sample in getitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
        base, index = getitem_sample.args

        # Reading back through the same index tells us what shape of value a
        # setitem at that index accepts.
        selected = base[index]
        exact_value = make_tensor(selected.shape, device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(base, index, exact_value)

        # Exercise broadcasting on assignment: each dimension of the value is
        # randomly kept or squeezed to 1.
        squeezed_shape = tuple(random.choice((dim, 1)) for dim in selected.shape)
        broadcast_value = make_tensor(squeezed_shape, device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(base, index, broadcast_value)


setitem_opinfo = OpInfo(
    operator.setitem,
    sample_input_generator=setitem_sample_generator,
    torch_reference=operator.setitem,
    numpy_reference=operator.setitem,
)
shape_ops.append(setitem_opinfo)
4496+
4497+
44764498
def movedim_sample_generator(op, device, dtype, requires_grad, **kwargs):
44774499
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
44784500

thunder/tests/test_grad.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"index_select",
4343
# Finite difference approximation doesn't work for this function
4444
"embedding",
45+
"setitem",
4546
"index_put",
4647
"batch_norm",
4748
"instance_norm",
@@ -689,6 +690,34 @@ def test_vjp_correctness_embedding_manual(op, device, dtype, executor, comp):
689690
comp(actual_out, out)
690691

691692

693+
@ops((get_opinfo("setitem"),), supported_dtypes=(dtypes.float64,))
def test_vjp_correctness_setitem_manual(op, device, dtype, executor, comp):
    """Checks the setitem VJP against torch.autograd.grad on every sample."""
    for sample in op.sample_inputs(device, dtype, requires_grad=True):

        def torch_reference(tensor, idx, value):
            # `tensor * 1` makes a differentiable non-leaf copy so the
            # in-place setitem is legal under autograd.
            result = tensor * 1
            op.torch_reference(result, idx, value)
            return result

        def op_fn(tensor, idx, value):
            result = tensor * 1
            op.op(result, idx, value)
            return result

        expected_out = torch_reference(*sample.args, **sample.kwargs)
        cotangent = make_tensor_like(expected_out)
        # Grads w.r.t. the indexed tensor (args[0]) and the assigned value (args[2]).
        expected_grads = torch.autograd.grad(expected_out, (sample.args[0], sample.args[2]), cotangent)

        # Compute the vjp result using Thunder.
        flat_op, flat_args, spec = flatten_func(op_fn, sample.args, sample.kwargs)
        initial_trace = thunder.trace()(vjp(flat_op), flat_args, (cotangent,))
        jfn = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
        actual_out, actual_grad = jfn(flat_args, (cotangent,))

        comp(actual_out, expected_out)
        comp(actual_grad[0], expected_grads[0])
        comp(actual_grad[-1], expected_grads[1])
719+
720+
692721
@ops((op for op in opinfos if op.name == "type_as"), supported_dtypes=(dtypes.float64,))
693722
def test_vjp_correctness_type_as_manual(op, device, dtype, executor, comp):
694723
for sample in op.sample_inputs(device, dtype, requires_grad=True):

thunder/torch/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,13 @@ def _copy_(a, b, /):
269269
return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False)
270270

271271

272+
def _clone_via_copy(t: TensorProxy) -> TensorProxy:
    """Returns a functional clone of ``t`` built from prims.empty + prims.copy_.

    Used where an explicit copy is preferred over prims.clone; the copy's
    grad-enabled flag mirrors the current compile data (defaults to False when
    no compile data is available).
    """
    compile_data = get_compile_data()
    grad_enabled = False if compile_data is None else compile_data.is_grad_enabled
    dest = prims.empty(t.shape, device=t.device, dtype=t.dtype)
    return prims.copy_(t, dest, grad_enabled=grad_enabled)
277+
278+
272279
@torchsymbol(torch.Tensor.copy_, is_method=True)  # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
    """Copies the values of tensor ``b`` into tensor ``a`` in place (torch.Tensor.copy_).

    Thin wrapper delegating to ``_copy_``, which issues ``prims.copy_`` with
    the current grad-enabled state.
    """
    return _copy_(a, b)
@@ -2609,6 +2616,20 @@ def mod_(a, b):
26092616

26102617
@torchsymbol(torch.mul, is_method=True)
def mul(a, b, /):
    """Elementwise multiplication with a fast path for multiplication by one.

    ``tensor * 1`` is emitted as a functional clone (via ``_clone_via_copy``)
    instead of a real multiplication. The shortcut is only taken when it
    cannot change the result dtype under type promotion:

    * the scalar must be a plain ``int`` equal to 1 — ``1.0`` would promote an
      integer tensor to floating point, and ``True`` is a bool with its own
      promotion rules, so both must go through ``clang.mul``;
    * the tensor must not be boolean — ``bool_tensor * 1`` promotes to an
      integer dtype, so a dtype-preserving clone would be wrong.
    """
    # Local import keeps the fast path self-contained; the module belongs to
    # the same package.
    import thunder.core.dtypes as dtypes

    def _is_int_one(x):
        # Unwrap proxies; a proxy without a concrete value cannot be proven
        # to be 1, so fall back to the generic multiplication.
        if isinstance(x, NumberProxy):
            try:
                x = pyval(x)
            except Exception:
                return False
        # `type(x) is int` deliberately rejects bool (a subclass of int) and
        # float, both of which can trigger dtype promotion.
        return type(x) is int and x == 1

    if isinstance(a, TensorProxy) and _is_int_one(b) and not dtypes.is_boolean_dtype(a.dtype):
        return _clone_via_copy(a)
    if isinstance(b, TensorProxy) and _is_int_one(a) and not dtypes.is_boolean_dtype(b.dtype):
        return _clone_via_copy(b)
    return clang.mul(a, b)
26132634

26142635

0 commit comments

Comments
 (0)