@@ -42,6 +42,7 @@
     "index_select",
     # Finite difference approximation doesn't work for this function
     "embedding",
+    "setitem",
     "index_put",
     "batch_norm",
     "instance_norm",
@@ -689,6 +690,36 @@ def test_vjp_correctness_embedding_manual(op, device, dtype, executor, comp):
     comp(actual_out, out)
 
 
+@ops((get_opinfo("setitem"),), supported_dtypes=(dtypes.float64,))
+def test_vjp_correctness_setitem_manual(op, device, dtype, executor, comp):
+    for sample in op.sample_inputs(device, dtype, requires_grad=True):
+
+        def torch_reference(tensor, idx, value):
+            # Multiply by one to get a differentiable copy that is safe to mutate in place
+            cloned = tensor * 1
+            op.torch_reference(cloned, idx, value)
+            return cloned
+
+        def op_fn(tensor, idx, value):
+            cloned = tensor * 1
+            op.op(cloned, idx, value)
+            return cloned
+
+        # Compute the expected output and gradients with PyTorch autograd
+        out = torch_reference(*sample.args, **sample.kwargs)
+        v = make_tensor_like(out)
+        expected = torch.autograd.grad(out, (sample.args[0], sample.args[2]), v)
+
+        # Compute vjp result using Thunder
+        flat_op, flat_args, spec = flatten_func(op_fn, sample.args, sample.kwargs)
+        initial_trace = thunder.trace()(vjp(flat_op), flat_args, (v,))
+        jfn = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
+        actual_out, actual_grad = jfn(flat_args, (v,))
+        comp(actual_out, out)
+        comp(actual_grad[0], expected[0])
+        comp(actual_grad[-1], expected[1])
+
+
 @ops((op for op in opinfos if op.name == "type_as"), supported_dtypes=(dtypes.float64,))
 def test_vjp_correctness_type_as_manual(op, device, dtype, executor, comp):
     for sample in op.sample_inputs(device, dtype, requires_grad=True):
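
For context on what the new test relies on: PyTorch autograd can differentiate through an in-place `__setitem__` as long as the mutation targets an out-of-place copy rather than a leaf tensor, which is why both wrappers clone via `tensor * 1` before writing. Below is a minimal standalone sketch of that pattern in plain PyTorch; the shapes and the row index are illustrative assumptions, not values taken from the suite's OpInfo samples.

import torch

# Leaf tensors that require gradients (shapes chosen for illustration only)
tensor = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
value = torch.randn(3, dtype=torch.float64, requires_grad=True)

# Writing into `tensor` directly would raise, since it is a leaf requiring
# grad; multiplying by one yields a differentiable copy that may be mutated.
cloned = tensor * 1
cloned[1] = value  # in-place __setitem__, tracked by autograd

# Pull a cotangent back through the computation (a vector-Jacobian product)
v = torch.ones_like(cloned)
grad_tensor, grad_value = torch.autograd.grad(cloned, (tensor, value), v)

# Row 1 of `tensor` was overwritten, so its gradient there is zero, while
# `value` receives exactly the cotangent of the overwritten row.
assert torch.equal(grad_tensor[1], torch.zeros_like(grad_tensor[1]))
assert torch.equal(grad_value, v[1])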