Skip to content

Commit 7011cc2

Browse files
committed
Update tests
1 parent 3d7fa9c commit 7011cc2

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

thunder/tests/opinfos.py

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4474,17 +4474,38 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
44744474

44754475

44764476
def setitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
4477-
for sample in getitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
4478-
tensor, key = sample.args
4477+
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
44794478

4480-
indexed_tensor = tensor[key]
4481-
# getitem already has lots of cases, and doubling it is too time-consuming
4482-
# value = make_tensor(indexed_tensor.shape, device=device, dtype=dtype, requires_grad=requires_grad)
4479+
def _make_setitem_sample(tensor, key):
4480+
indexed_shape = tensor[key].shape
4481+
4482+
# Tests for getitem are already slow, and doubling them is too time-consuming
4483+
# value = make_tensor(indexed_shape, device=device, dtype=dtype, requires_grad=requires_grad)
44834484
# yield SampleInput(tensor, key, value)
44844485

4485-
pre_broadcast_shape = tuple(random.choice((s, 1)) for s in indexed_tensor.shape)
4486-
value = make_tensor(pre_broadcast_shape, device=device, dtype=dtype, requires_grad=requires_grad)
4487-
yield SampleInput(tensor, key, value)
4486+
pre_broadcast_shape = tuple(random.choice((s, 1)) for s in indexed_shape)
4487+
pre_broadcast_value = make_tensor(pre_broadcast_shape, device=device, dtype=dtype, requires_grad=requires_grad)
4488+
return SampleInput(tensor, key, pre_broadcast_value)
4489+
4490+
for sample in getitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
4491+
tensor, key = sample.args
4492+
yield _make_setitem_sample(tensor, key)
4493+
4494+
# Boolean mask indexing
4495+
boolean_mask_cases = [
4496+
((6,), (torch.tensor([True, False, True, False, True, False]),)),
4497+
((2, 3), (torch.tensor([[True, False, True], [False, True, False]]),)),
4498+
((2, 3, 4), ([False, True], [False, True, False], slice(None))),
4499+
((2, 3, 4), (torch.tensor([True, False]), [1, 1], slice(None))),
4500+
((2, 3, 4), (torch.tensor([False, False]), [1, 1], slice(None))),
4501+
((2, 3, 4), (1, torch.tensor([True, False, True]), slice(None))),
4502+
((2, 3), (torch.tensor([True, False]), None, [0, 2])),
4503+
((4, 2, 3), (Ellipsis, [False, True, False])),
4504+
]
4505+
4506+
for shape, key in boolean_mask_cases:
4507+
tensor = make(shape)
4508+
yield _make_setitem_sample(tensor, key)
44884509

44894510

44904511
setitem_opinfo = OpInfo(

thunder/tests/test_grad.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -704,21 +704,32 @@ def op_fn(tensor, idx, value):
704704
op.op(cloned, idx, value)
705705
return cloned
706706

707-
args_ref = (sample.args[0].detach().clone().requires_grad_(True),) + sample.args[1:]
708-
out = torch_reference(*args_ref, **sample.kwargs)
707+
tensor, key, value = sample.args
708+
assert not sample.kwargs
709+
710+
tensor_ref = tensor.detach().clone().requires_grad_(True)
711+
out = torch_reference(tensor_ref, key, value)
709712
v = make_tensor_like(out)
710-
expected = torch.autograd.grad(out, (args_ref[0], args_ref[2]), v)
713+
expected = torch.autograd.grad(out, (tensor_ref, value), v)
714+
715+
flat_op, flat_args, spec = flatten_func(op_fn, (tensor, key, value), {})
716+
717+
t_key = key if isinstance(key, tuple) else (key,)
718+
if any(isinstance(k, (torch.Tensor, Sequence)) and torch.tensor(k).dtype == torch.bool for k in t_key):
719+
with pytest.raises(NotImplementedError):
720+
executor.make_callable(flat_op, disable_torch_autograd=True)(*flat_args)
721+
with pytest.raises(NotImplementedError):
722+
vjp(flat_op)(flat_args, (v,))
723+
continue
711724

712-
# Compute vjp result using Thunder
713-
flat_op, flat_args, spec = flatten_func(op_fn, sample.args, sample.kwargs)
714725
initial_trace = thunder.trace()(vjp(flat_op), flat_args, (v,))
715726
jfn = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
716727
actual_out, actual_grad = jfn(flat_args, (v,))
717728

718729
# With advanced indexing, an element may be assigned multiple times and the assignment order is not guaranteed.
719730
# comp(actual_out, out)
720731

721-
comp(sample.args[0], args_ref[0])
732+
comp(tensor, tensor_ref)
722733
comp(actual_grad[0], expected[0])
723734
comp(actual_grad[-1], expected[1])
724735

0 commit comments

Comments
 (0)