From efa77e4addaf98ea333fa12c0c2f3d3dccaa0388 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 09:33:10 -0800 Subject: [PATCH 01/23] Do not skip return stmt in update_aliases.py --- thunder/core/update_aliases.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index de8ec89604..37489cf70a 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -51,8 +51,6 @@ def _is_view_creation_op(bsym): def _involves_viewed_args(bsym, viewed): - if bsym.sym.id == prims.PrimIDs.RETURN: - return False return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args) From 6a4fbbbf0889da9f56026af3cc1a2c8a24220004 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 09:42:42 -0800 Subject: [PATCH 02/23] Make copy_ DCE'd by default --- thunder/common.py | 2 +- thunder/core/prims.py | 2 +- thunder/core/transform_common.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index 487cb28e07..8863c3ec42 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -619,7 +619,7 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy: # TODO Stop calling this here and make it a separate trace in the sequence # of traces if use_dce: - trace = dce(trace) + trace = dce(trace, keep_inplace_ops=True) finally: # Resets contexts diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 90e1f3689b..185b331b37 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4333,7 +4333,7 @@ def copy__meta( return TensorProxy(like=copy_to) -copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,)) +copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.IN_PLACE,)) def bitcast_meta( diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bc14a671f2..63dfff56a1 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -142,7 +142,7 @@ def keep_or_swap(p): # that only produce non-proxy objects # NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return # all the needed proxies of the input trace -def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: +def dce(trace: Trace, needed_proxies: None | set[Variable] = None, keep_inplace_ops: bool = False) -> Trace: start_time_ns = time.perf_counter_ns() producer_map: ProxyDict = producers(trace) @@ -159,6 +159,8 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: # Preserves symbols that should never be collected if has_tags(bsym, {prims.OpTags.DONT_DCE}): needed = True + elif keep_inplace_ops and has_tags(bsym, {prims.OpTags.IN_PLACE}): + needed = True else: needed = False From edac5c5c825eb5f619b17453a288674c51def8c0 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 10:51:30 -0800 Subject: [PATCH 03/23] Tag torchex.copy_ as IN_PLACE instead Or we could perhaps just remove tags on it --- thunder/executors/torchex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index d6fb0b31c1..4c9b5627c3 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2367,7 +2367,7 @@ def _copy__impl(copy_from, copy_to, grad_enabled): copy_ = ex.register_operator( - "copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor + "copy_", meta=prims.copy_, tags=(prims.OpTags.IN_PLACE,), fn=_copy__impl, module=torch.Tensor ) _register_implementation(prims.copy_, copy_, checker=_always_executable) From 6fefee215a33908c988833479c66cee80700fc57 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 11:33:50 -0800 Subject: [PATCH 04/23] Prepare TraceCtx.name/name_ctr for proxy name generation --- thunder/core/jit_ext.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 6e8f52d45c..1fc26b0817 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -835,6 +835,10 @@ def core_of_forward(*args, **kwargs): from thunder.core.update_aliases import insert_alias_updates + # Copy attributes needed for TensorProxy name construction + trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr + trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names) + alias_tensor_indices = [[i] for i in range(len(trace_of_augmented_fwd.args))] aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd, alias_tensor_indices) @@ -869,6 +873,10 @@ def core_of_forward(*args, **kwargs): ) bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads) + # Copy attributes needed for TensorProxy name construction + bwd_trace_impl.name_ctr = get_jit_ctx().computation_trace.name_ctr + bwd_trace_impl.names = set(get_jit_ctx().computation_trace.names) + alias_tensor_indices = [[i] for i in range(len(bwd_trace_impl.args))] aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl, alias_tensor_indices) @@ -951,6 +959,10 @@ def _generate_random_str_id() -> str: from thunder.core.update_aliases import insert_alias_updates + # Copy attributes needed for TensorProxy name construction + aug_fwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr + aug_fwd_trace.names = set(get_jit_ctx().computation_trace.names) + alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))] aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices) @@ -988,7 +1000,9 @@ def forward(*args, **kwargs): ] bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols - from thunder.core.update_aliases import insert_alias_updates + # Copy attributes needed for TensorProxy name construction + bwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr + bwd_trace.names = set(get_jit_ctx().computation_trace.names) alias_tensor_indices = [[i] for i in range(len(bwd_trace.args))] aliased_bwd_trace = insert_alias_updates(bwd_trace, alias_tensor_indices) From d5ffd3144ed9920269744288fb1ed2d67a27dd6f Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 11:37:36 -0800 Subject: [PATCH 05/23] Minor fix on test --- thunder/tests/test_inplace_copy.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 7cb3e705fd..d393e79ec7 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -42,16 +42,16 @@ def test_prim_inplace_copy_bwd(executor, device, dtype): def torch_foo(x, y): z = x * y z = z * x - x.copy_(z) + o = x.copy_(z) p = y * y - return p + return p, o def foo(x, y): z = x * y z = z * x - thunder.core.prims.copy_(z, x, grad_enabled=True) + o = thunder.core.prims.copy_(z, x, grad_enabled=True) p = y * y - return p + return p, o traced_nvfuser_foo = executor.make_callable(foo) @@ -72,11 +72,11 @@ def foo(x, y): ) custom_comparator(a, a1) - g = torch.ones_like(thunder_result) - thunder_result.backward(g) + g = torch.ones_like(thunder_result[0]) + thunder_result[0].backward(g) - g1 = torch.ones_like(torch_result) - torch_result.backward(g1) + g1 = torch.ones_like(torch_result[0]) + torch_result[0].backward(g1) assert_close(g, g1) assert_close(b.grad, b1.grad) From 28bc0942f8122590e5ca40dc77c0a76848c602a2 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 13:20:03 -0800 Subject: [PATCH 06/23] Make update_aliases.py handle copy_ --- thunder/core/update_aliases.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 37489cf70a..cc4cf6d799 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -39,6 +39,7 @@ def _get_new_aliases(aliases, trace): def _is_inplace_op(bsym): + # TODO: Handle higher order bsyms containing inplace ops return (bsym.sym.tags and prims.OpTags.IN_PLACE in bsym.sym.tags) or ( bsym.subsymbols and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_ ) @@ -146,7 +147,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li for bsym in computation_trace.bound_symbols: if _is_inplace_op(bsym) or _is_view_creation_op(bsym): # only interested in the input which is modified by the inplace op - in_tensor = variableify(bsym.flat_proxy_args[0]) + mutated_or_aliased_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0 + in_tensor = variableify(bsym.flat_proxy_args[mutated_or_aliased_index]) out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs))) if _is_inplace_op(bsym): inplace_inputs.add(in_tensor) @@ -167,7 +169,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed): in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))) if _is_inplace_op(bsym) and in_tensors: - in_tensors = {in_tensors[0]} + mutated_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0 + in_tensors = {in_tensors[mutated_index]} else: in_tensors = set(in_tensors) out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs))) From 52119300f6a8d0d217e67ce205b13b8b7366492b Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 14:13:41 -0800 Subject: [PATCH 07/23] Apply update_aliases after decomposition in autodiff --- thunder/__init__.py | 2 +- thunder/transforms/autodiff.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index c0483f8c3a..3bfb797dbf 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -553,7 +553,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com if requires_grad: from thunder.transforms.autodiff import grad_transform_on_trace - computation_trc = grad_transform_on_trace(computation_trc) + computation_trc = grad_transform_on_trace(computation_trc, alias_tensor_indices) computation_traces.append(computation_trc) from thunder.executors.passes import _transform_for_operator_executor_execution diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py index 2d4261cdf6..63935f1a42 100644 --- a/thunder/transforms/autodiff.py +++ b/thunder/transforms/autodiff.py @@ -14,6 +14,7 @@ augmented_forward_impls, backward_impls, ) +from thunder.core.update_aliases import insert_alias_updates import thunder.torch as ltorch @@ -25,7 +26,7 @@ def _should_recompute_bsym_in_backward(bsym): # Transforms a trace by determining which grad transforms to call given the list of executors in priority order # This pass tries to preserve the original trace and proxies. -def grad_transform_on_trace(trace, /, *args, **kwargs): +def grad_transform_on_trace(trace, alias_tensor_indices: list[list[int]], /, *args, **kwargs): # This processes the bsyms to map symbols to operator executors: # - in the order of the executor list # - if the executor defines a grad transform, call that to @@ -437,6 +438,9 @@ def process_bsym(self, bsym: BoundSymbol) -> None: joint_trace, _ = AugmentedForwardProcessor(trace)() joint_trace, _ = InsertRecomputationsProcessor(joint_trace)() + # Insert prims.update_aliases before DCE for bsyms exposed by decomposition + joint_trace = insert_alias_updates(joint_trace, alias_tensor_indices) + # run through DCE in case some of the gradients of intermediates are not needed. joint_trace = dce(joint_trace) # group get_grad symbols together for torch compile fusions and to make clear boundary for cse From baa6ede257eccb145ce72ef776694b3d6e95428e Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 15:17:28 -0800 Subject: [PATCH 08/23] Add tests --- thunder/tests/test_inplace_copy.py | 28 +++++++- thunder/tests/test_update_aliases.py | 101 +++++++++++++++++++++------ 2 files changed, 108 insertions(+), 21 deletions(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index d393e79ec7..26bbba45de 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -131,7 +131,7 @@ def func2(x, y): return y, o1, o2 for foo in (func1, func2): - traced_foo = executor.make_callable(foo) + traced_foo = executor.make_callable(foo, skip_inplace_alias_updates=True) tdtype = ttorch.to_torch_dtype(dtype) a = make_tensor((4, 4), device=device, dtype=tdtype) @@ -192,3 +192,29 @@ def fn(x, y): a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False) b = make_tensor((4, 4), device=device, dtype=torch.float32, requires_grad=False) jitted_fn(a, b) + + +@instantiate(dtypes=(thunder.float32,)) +def test_inplace_on_intermediate(executor, device, dtype): + def f(x): + a = torch.randn_like(x) + b = torch.randn_like(x) + a.copy_(b) + b.sin_() + return (x,) + + def g(x): + a = torch.randn_like(x) + b = torch.randn_like(x) + a.copy_(b) + b.sin_() + return x, a, b + + tdtype = ttorch.to_torch_dtype(dtype) + for fn in [f, g]: + jitted_fn = executor.make_callable(fn) + a = make_tensor((4, 4), device=device, dtype=tdtype) + a_ref = a.clone() + a_out, *_ = jitted_fn(a) + assert_close(a, a_ref) + assert_close(a_out, a_ref) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 4d293aae21..acce56cc64 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -15,11 +15,11 @@ from thunder.tests.make_tensor import make_tensor, make_tensor_like from thunder.tests.framework import ( instantiate, + nvFuserExecutor, ops, NOTHING, TorchExecutor, TorchCompileExecutor, - nvFuserExecutor, requiresCUDA, ) from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place @@ -166,14 +166,15 @@ def g(x, _): @instantiate( dtypes=NOTHING, + decorators=(pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),), ) -def test_inplace_on_view(executor, device, dtype): +def test_inplace_on_view(executor, device, dtype, inplace_op): def h(x, y): c = torch.exp(x) d = torch.tanh(y) e = c.view(-1) - d.div_(x) + inplace_op(d, x) e += d.flatten() return c, d, e @@ -184,14 +185,14 @@ def i(x, y): e = c.view(-1) e += d.flatten() - d.div_(x) + inplace_op(d, x) return c, d, e def j(x, _): a = x.view(-1) b = x.view(-1) - x.add_(1) + inplace_op(x, x.cos()) aa = a + 1 bb = b + 1 return aa, bb @@ -260,8 +261,10 @@ def h(x): ) def test_chained_inplace(executor, device, dtype): def f(x): - x.add_(1).sin_().mul_(5) - return x + y = x.add_(1) + z = y.sin_() + w = z.mul_(y.copy_(z.cos())) + return w def g(x): x.add_(1).sin().mul_(5) @@ -269,6 +272,7 @@ def g(x): def h(x): x.exp_() + x.copy_(x.tan()) x.sin_() y = x.cos() return y @@ -332,14 +336,16 @@ def g(a, b): ) def test_aliased_input(executor, device, dtype, cache): def f(x, y, z): - return y.exp_().add(x) + z.exp() + s = y.exp_().add(x) + z.exp() + t = x.copy_(z.exp_().view(x.shape)) + z.cos().reshape(x.shape) + return s, t a = make_tensor((2, 1, 2), dtype=torch.float32, device=device) b = a.clone() c = a.view(1, 2, 2) a_ = a.clone().detach() b_ = b.clone().detach() - c_ = c.clone().detach() + c_ = a_.view(1, 2, 2) jfn = executor.make_callable(f, cache=cache) actual = jfn(a, b, c) expected = f(a_, b_, c_) @@ -351,22 +357,25 @@ def f(x, y, z): @instantiate( dtypes=NOTHING, - decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),), + decorators=( + pytest.mark.parametrize("cache", ("constant values", "symbolic values")), + pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]), + ), ) -def test_write_to_intermediate_result(executor, device, dtype, cache): - if executor == nvFuserExecutor: - pytest.xfail("nvFuser does not support writing to intermediate results") - - def fn(x): +def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op): + def fn(x, z): y = x.view(-1) - y.add_(1) + inplace_op(y, z) return y - a = make_tensor((2, 3), dtype=torch.float32, device=device) + x = make_tensor((2, 3), dtype=torch.float32, device=device) + x_ref = x.clone().detach() + z = make_tensor(6, dtype=torch.float32, device=device) jfn = executor.make_callable(fn, cache=cache) - actual = jfn(a) - expected = fn(a) + actual = jfn(x, z) + expected = fn(x_ref, z) torch.testing.assert_close(actual, expected) + torch.testing.assert_close(x, x_ref) @instantiate( @@ -469,6 +478,58 @@ def f(a): torch.testing.assert_close(out, out_expected) +@instantiate( + dtypes=(dtypes.float32,), +) +def test_batch_norm_update_aliases(executor, device, dtype): + if executor is nvFuserExecutor: + pytest.xfail("update_aliases is not aware of mutation by batch_norm") + + torch_dtype = dtypes.to_torch_dtype(dtype) + num_features = 4 + + def f(x, running_mean, running_var, weight, bias): + out = torch.nn.functional.batch_norm( + x, + running_mean, + running_var, + weight, + bias, + training=True, + momentum=0.1, + eps=1e-5, + ) + return out, x, running_mean.sin(), running_var.cos() + + input_tensor = make_tensor((3, num_features, 5, 5), device=device, dtype=torch_dtype) + running_mean = make_tensor((num_features,), device=device, dtype=torch_dtype) + running_var = make_tensor((num_features,), device=device, dtype=torch_dtype) + weight = make_tensor((num_features,), device=device, dtype=torch_dtype) + bias = make_tensor((num_features,), device=device, dtype=torch_dtype) + + input_ref = input_tensor.clone().detach() + running_mean_ref = running_mean.clone().detach() + running_var_ref = running_var.clone().detach() + weight_ref = weight.clone().detach() + bias_ref = bias.clone().detach() + + jitted_f = executor.make_callable(f) + out_jitted, x_jitted, running_mean_jitted, running_var_jitted = jitted_f( + input_tensor, running_mean, running_var, weight, bias + ) + out_ref, x_ref, running_mean_ref_out, running_var_ref_out = f( + input_ref, running_mean_ref, running_var_ref, weight_ref, bias_ref + ) + + torch.testing.assert_close(out_jitted, out_ref) + torch.testing.assert_close(x_jitted, x_ref) + torch.testing.assert_close(running_mean_jitted, running_mean_ref_out) + torch.testing.assert_close(running_var_jitted, running_var_ref_out) + torch.testing.assert_close(input_tensor, input_ref) + torch.testing.assert_close(running_mean, running_mean_ref) + torch.testing.assert_close(running_var, running_var_ref) + + @instantiate( dtypes=(dtypes.float32,), ) @@ -491,7 +552,7 @@ def backward(ctx, g): return y def foo(x): - return Sin.apply(x) + return Sin.apply(x) * x a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True) b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True) From 11396be7e561d0e9d534b460b2e8a47f2bb0d084 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 27 Nov 2025 18:04:04 -0800 Subject: [PATCH 09/23] Improve test consistency --- thunder/tests/test_inplace_copy.py | 26 ------------------- thunder/tests/test_update_aliases.py | 39 ++++++++++++++-------------- 2 files changed, 20 insertions(+), 45 deletions(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 26bbba45de..dc962ff9be 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -192,29 +192,3 @@ def fn(x, y): a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False) b = make_tensor((4, 4), device=device, dtype=torch.float32, requires_grad=False) jitted_fn(a, b) - - -@instantiate(dtypes=(thunder.float32,)) -def test_inplace_on_intermediate(executor, device, dtype): - def f(x): - a = torch.randn_like(x) - b = torch.randn_like(x) - a.copy_(b) - b.sin_() - return (x,) - - def g(x): - a = torch.randn_like(x) - b = torch.randn_like(x) - a.copy_(b) - b.sin_() - return x, a, b - - tdtype = ttorch.to_torch_dtype(dtype) - for fn in [f, g]: - jitted_fn = executor.make_callable(fn) - a = make_tensor((4, 4), device=device, dtype=tdtype) - a_ref = a.clone() - a_out, *_ = jitted_fn(a) - assert_close(a, a_ref) - assert_close(a_out, a_ref) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index acce56cc64..3c401f7bc6 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -168,7 +168,7 @@ def g(x, _): dtypes=NOTHING, decorators=(pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),), ) -def test_inplace_on_view(executor, device, dtype, inplace_op): +def test_inplace_on_intermediate(executor, device, dtype, inplace_op): def h(x, y): c = torch.exp(x) d = torch.tanh(y) @@ -189,15 +189,7 @@ def i(x, y): return c, d, e - def j(x, _): - a = x.view(-1) - b = x.view(-1) - inplace_op(x, x.cos()) - aa = a + 1 - bb = b + 1 - return aa, bb - - for fn in [h, i, j]: + for fn in [h, i]: a = make_tensor((2, 3), dtype=torch.float32, device=device) b = make_tensor((2, 3), dtype=torch.float32, device=device) a_, b_ = a.clone().detach(), b.clone().detach() @@ -363,19 +355,28 @@ def f(x, y, z): ), ) def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op): - def fn(x, z): + def f(x, z): y = x.view(-1) inplace_op(y, z) return y - x = make_tensor((2, 3), dtype=torch.float32, device=device) - x_ref = x.clone().detach() - z = make_tensor(6, dtype=torch.float32, device=device) - jfn = executor.make_callable(fn, cache=cache) - actual = jfn(x, z) - expected = fn(x_ref, z) - torch.testing.assert_close(actual, expected) - torch.testing.assert_close(x, x_ref) + def g(x, z): + a = x.view(-1) + b = x.view(-1) + inplace_op(x, z) + aa = a + 1 + bb = b + 1 + return aa, bb + + for fn in [f, g]: + x = make_tensor((2, 3), dtype=torch.float32, device=device) + x_ref = x.clone().detach() + z = make_tensor(6, dtype=torch.float32, device=device) + jfn = executor.make_callable(fn, cache=cache) + actual = jfn(x, z) + expected = fn(x_ref, z) + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(x, x_ref) @instantiate( From 406d227ade753d33a39efdadda4d233bd3aadcbb Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 01:25:02 -0800 Subject: [PATCH 10/23] Fix test bug --- thunder/tests/test_update_aliases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 3c401f7bc6..057c83fe21 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -357,7 +357,7 @@ def f(x, y, z): def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op): def f(x, z): y = x.view(-1) - inplace_op(y, z) + inplace_op(y, z.view(-1)) return y def g(x, z): @@ -371,7 +371,7 @@ def g(x, z): for fn in [f, g]: x = make_tensor((2, 3), dtype=torch.float32, device=device) x_ref = x.clone().detach() - z = make_tensor(6, dtype=torch.float32, device=device) + z = make_tensor((2, 3), dtype=torch.float32, device=device) jfn = executor.make_callable(fn, cache=cache) actual = jfn(x, z) expected = fn(x_ref, z) From 95a6c144899cdc54ecc3a386531a4d20d710969e Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 06:05:15 -0800 Subject: [PATCH 11/23] Add xfail --- thunder/tests/test_update_aliases.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 057c83fe21..132cf42e6c 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -533,8 +533,12 @@ def f(x, running_mean, running_var, weight, bias): @instantiate( dtypes=(dtypes.float32,), + decorators=(pytest.mark.parametrize("requires_grad", (True, False)),), ) -def test_higher_order_inplace_alias_update(executor, device, dtype): +def test_higher_order_inplace_alias_update(executor, device, dtype, requires_grad): + if not requires_grad: + pytest.xfail("update_aliases is not aware of mutation in higher order functions") + torch_dtype = dtypes.to_torch_dtype(dtype) class Sin(torch.autograd.Function): @@ -555,9 +559,9 @@ def backward(ctx, g): def foo(x): return Sin.apply(x) * x - a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True) - b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True) - c = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True) + a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad) + b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad) + c = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad) g = torch.rand_like(a) @@ -575,12 +579,13 @@ def foo(x): torch.testing.assert_close(actual_jit, expected) torch.testing.assert_close(actual_fx, expected) - actual_grad_jit = torch.autograd.grad(actual_jit, a, g) - actual_grad_fx = torch.autograd.grad(actual_fx, b, g) + if requires_grad: + actual_grad_jit = torch.autograd.grad(actual_jit, a, g) + actual_grad_fx = torch.autograd.grad(actual_fx, b, g) - expected_grad = torch.autograd.grad(expected, c, g) - torch.testing.assert_close(actual_grad_fx, expected_grad) - torch.testing.assert_close(actual_grad_jit, expected_grad) + expected_grad = torch.autograd.grad(expected, c, g) + torch.testing.assert_close(actual_grad_fx, expected_grad) + torch.testing.assert_close(actual_grad_jit, expected_grad) @instantiate( From 29a1b6e5d4862cd7e4a7e3f44af890918874577c Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 09:06:13 -0800 Subject: [PATCH 12/23] Handle skip_inplace_alias_updates inside insert_alias_updates --- thunder/__init__.py | 9 ++++----- thunder/core/update_aliases.py | 6 +++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 3bfb797dbf..0dcababb64 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -493,11 +493,10 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com [int(i) for i in s.split(",")] for s in alias_tensor_indices_str.split("-") if s != "" ] - if not compile_options.get("skip_inplace_alias_updates", False): - aliased_trace = insert_alias_updates(computation_trc, alias_tensor_indices) - if aliased_trace is not computation_trc: - computation_traces.append(aliased_trace) - computation_trc = computation_traces[-1] + aliased_trace = insert_alias_updates(computation_trc, alias_tensor_indices) + if aliased_trace is not computation_trc: + computation_traces.append(aliased_trace) + computation_trc = computation_traces[-1] cs.last_trace_tracing_stop = time.perf_counter_ns() diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index cc4cf6d799..4c2861a704 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -130,7 +130,11 @@ def replace_args_with_alias_map( return no_implicit_alias_trace, view_groups -def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace: +def insert_alias_updates(computation_trace: Trace) -> Trace: + cd = get_compile_data() + if cd.compile_options.get("skip_inplace_alias_updates", False): + return computation_trace + if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace From daab3bb05862bb2f30c24c810ec4e7c37f48b7f6 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 09:13:51 -0800 Subject: [PATCH 13/23] Access alias_tensor_indices only inside update_aliases --- thunder/__init__.py | 44 ++++------------------------------ thunder/core/jit_ext.py | 12 ++++------ thunder/core/update_aliases.py | 7 +++++- thunder/core/utils.py | 24 +++++++++++++++++++ thunder/transforms/autodiff.py | 4 ++-- 5 files changed, 40 insertions(+), 51 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 0dcababb64..b811ff39a1 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -37,7 +37,6 @@ SHARP_EDGES_OPTIONS, ) from thunder.core.proxies import TensorProxy -from thunder.core.pytree import tree_flatten from thunder.core.recipe import Recipe, Plugin from thunder.core.symbol import has_tags from thunder.core.trace import ( @@ -56,6 +55,7 @@ wrap_return_value_together_with_arguments, ) from thunder.core.update_aliases import insert_alias_updates +from thunder.core.utils import encode_alias_tensor_indices from thunder.executors.torch_autograd import connect_to_autograd import thunder.extend as extend from thunder.extend import Executor, add_default_executor @@ -405,37 +405,6 @@ def jit( cs = CompileStats() weakref_cs = weakref.ref(cs) - def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: - flat_args, _ = tree_flatten((args, kwargs)) - data_ptr_to_tensor_group_index = {} - tensor_group_index_to_tensor_indices = defaultdict(list) - for idx, t in enumerate(flat_args): - # Using type(t) is pytorch.Tensor as TensorSubclasses don't support calling - # data_ptr(). - # Eg. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass) - # - # isinstance(t, pytorch.Tensor) or pytorch.is_tensor(t) will match all Tensor objects including - # subclasses. - if type(t) is pytorch.Tensor and t.layout is pytorch.strided: - data_ptr = t.untyped_storage().data_ptr() - if data_ptr not in data_ptr_to_tensor_group_index: - data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) - tgi = data_ptr_to_tensor_group_index[data_ptr] - tensor_group_index_to_tensor_indices[tgi].append(idx) - return tensor_group_index_to_tensor_indices - - def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str: - """If no aliases found, empty string, otherwise, aliases are comma separated, groups are hyphen separated.""" - - alias_indices = [] - for k, v in _alias_tensor_of_args_kwargs_dict(*args, **kwargs).items(): - if len(v) > 1: - s = ",".join(f"{i}" for i in v) - alias_indices.append(s) - if not alias_indices: - return "" - return "-".join(alias_indices) - def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor): with compile_data_and_stats(cd, cs): # Acquires the trace OR inlines the trace into an existing trace and @@ -488,12 +457,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com computation_trc = remove_context_manager_prims_from_trace(computation_trc) computation_traces.append(computation_trc) - alias_tensor_indices_str = cache_info.get("alias_tensor_indices", "") - alias_tensor_indices: list[list[int]] = [ - [int(i) for i in s.split(",")] for s in alias_tensor_indices_str.split("-") if s != "" - ] - - aliased_trace = insert_alias_updates(computation_trc, alias_tensor_indices) + aliased_trace = insert_alias_updates(computation_trc) if aliased_trace is not computation_trc: computation_traces.append(aliased_trace) computation_trc = computation_traces[-1] @@ -552,7 +516,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com if requires_grad: from thunder.transforms.autodiff import grad_transform_on_trace - computation_trc = grad_transform_on_trace(computation_trc, alias_tensor_indices) + computation_trc = grad_transform_on_trace(computation_trc) computation_traces.append(computation_trc) from thunder.executors.passes import _transform_for_operator_executor_execution @@ -673,7 +637,7 @@ def populate_cache_info(cache_info, *args, **kwargs): # It however would require the computation trace to interact with `cache_info`, # which seems to break the consistency of cache_info, leading to a failure in cache_info check. if not compile_options.get("skip_inplace_alias_updates", False): - cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) + cache_info["alias_tensor_indices"] = encode_alias_tensor_indices(*args, **kwargs) # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform # to treat certain Symbols as constant. diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 1fc26b0817..538b72a710 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -839,8 +839,7 @@ def core_of_forward(*args, **kwargs): trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names) - alias_tensor_indices = [[i] for i in range(len(trace_of_augmented_fwd.args))] - aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd, alias_tensor_indices) + aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd) # Backward definition custom_backward = custom_autograd_function_cls.backward @@ -877,8 +876,7 @@ def core_of_forward(*args, **kwargs): bwd_trace_impl.name_ctr = get_jit_ctx().computation_trace.name_ctr bwd_trace_impl.names = set(get_jit_ctx().computation_trace.names) - alias_tensor_indices = [[i] for i in range(len(bwd_trace_impl.args))] - aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl, alias_tensor_indices) + aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl) @wraps(bwd_trace_impl.python_callable()) def bwd_impl_callable(*args, **kwargs): @@ -963,8 +961,7 @@ def _generate_random_str_id() -> str: aug_fwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr aug_fwd_trace.names = set(get_jit_ctx().computation_trace.names) - alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))] - aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices) + aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace) trace_of_forward = from_trace(aliased_aug_fwd_trace) for bsym in aug_fwd_trace.bound_symbols: @@ -1004,8 +1001,7 @@ def forward(*args, **kwargs): bwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr bwd_trace.names = set(get_jit_ctx().computation_trace.names) - alias_tensor_indices = [[i] for i in range(len(bwd_trace.args))] - aliased_bwd_trace = insert_alias_updates(bwd_trace, alias_tensor_indices) + aliased_bwd_trace = insert_alias_updates(bwd_trace) @wraps(forward) def grad_transform(*args, **kwargs): diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 4c2861a704..16fe4a7fcd 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -1,11 +1,13 @@ from functools import reduce, partial -from thunder.core.compile_data import using_symbolic_values +import thunder +from thunder.core.compile_data import get_compile_data, using_symbolic_values import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, variableify, unvariableify from thunder.core.pytree import tree_flatten from thunder.core.symbol import BoundSymbol, BoundSymbolTag, has_tags from thunder.core.trace import from_trace, tracectx, TraceCtx as Trace, TraceProvenance, VariableInterface +from thunder.core.utils import parse_alias_tensor_indices def _update_swap_map(swap_map, old_alias, new_alias): @@ -138,6 +140,9 @@ def insert_alias_updates(computation_trace: Trace) -> Trace: if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace + alias_tensor_indices_str = thunder._get_cache_info().get("alias_tensor_indices", "") + alias_tensor_indices = parse_alias_tensor_indices(alias_tensor_indices_str) + swap_map = dict() bsyms = [] diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 40b3e0eb13..679a53229d 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1262,3 +1262,27 @@ def create_python_callable_from_bsym(bsym: BoundSymbolInterface) -> str: prims.python_return(bsym.output) return trace.python(include_decorators=False) + + +def parse_alias_tensor_indices(alias_tensor_indices_str: str) -> list[list[int]]: + return [[int(i) for i in s.split(",")] for s in alias_tensor_indices_str.split("-") if s != ""] + + +def encode_alias_tensor_indices(*args, **kwargs) -> str: + flat_args, _ = tree_flatten((args, kwargs)) + data_ptr_to_tensor_indices = defaultdict(list) + + for idx, t in enumerate(flat_args): + # Using type(t) is torch.Tensor as TensorSubclasses don't support calling data_ptr(). + # Eg. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass) + # + # isinstance(t, torch.Tensor) or torch.is_tensor(t) will match all Tensor objects including subclasses. + if type(t) is torch.Tensor and t.layout is torch.strided: + data_ptr = t.untyped_storage().data_ptr() + data_ptr_to_tensor_indices[data_ptr].append(idx) + + alias_indices = [] + for indices in data_ptr_to_tensor_indices.values(): + if len(indices) > 1: + alias_indices.append(",".join(str(idx) for idx in indices)) + return "-".join(alias_indices) diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py index 63935f1a42..e032c8e928 100644 --- a/thunder/transforms/autodiff.py +++ b/thunder/transforms/autodiff.py @@ -26,7 +26,7 @@ def _should_recompute_bsym_in_backward(bsym): # Transforms a trace by determining which grad transforms to call given the list of executors in priority order # This pass tries to preserve the original trace and proxies. -def grad_transform_on_trace(trace, alias_tensor_indices: list[list[int]], /, *args, **kwargs): +def grad_transform_on_trace(trace, /, *args, **kwargs): # This processes the bsyms to map symbols to operator executors: # - in the order of the executor list # - if the executor defines a grad transform, call that to @@ -439,7 +439,7 @@ def process_bsym(self, bsym: BoundSymbol) -> None: joint_trace, _ = InsertRecomputationsProcessor(joint_trace)() # Insert prims.update_aliases before DCE for bsyms exposed by decomposition - joint_trace = insert_alias_updates(joint_trace, alias_tensor_indices) + joint_trace = insert_alias_updates(joint_trace) # run through DCE in case some of the gradients of intermediates are not needed. joint_trace = dce(joint_trace) From d943e3c219fd5e96d1f20e25eb85036e75e94e02 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 09:18:29 -0800 Subject: [PATCH 14/23] Apply update_aliases after first operator ex transform --- thunder/executors/passes.py | 5 +++++ thunder/tests/test_update_aliases.py | 3 --- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 96274a588f..374ea7761e 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -10,6 +10,7 @@ from thunder.core.trace import from_trace, TraceProvenance from thunder.core.trace_interpreter import TraceSubstitutionProcessor from thunder.core.transform_common import dce +from thunder.core.update_aliases import insert_alias_updates from thunder.core.utils import ProxyDict from thunder.executors.pythonex import clear_mutable_collection from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor @@ -122,7 +123,11 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) # Step 1 Performs execution transforms # extrace = _transform_for_operator_executor_execution(trace, executors_list) + # Insert alias updates before DCE for bsyms exposed by decomposition + # Inserted prims.update_aliases will be handled in Step 3 + extrace = insert_alias_updates(extrace) extrace = dce(extrace) + # # Step 2 Fusion executors can transform the trace # diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 132cf42e6c..4e5793ebb0 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -536,9 +536,6 @@ def f(x, running_mean, running_var, weight, bias): decorators=(pytest.mark.parametrize("requires_grad", (True, False)),), ) def test_higher_order_inplace_alias_update(executor, device, dtype, requires_grad): - if not requires_grad: - pytest.xfail("update_aliases is not aware of mutation in higher order functions") - torch_dtype = dtypes.to_torch_dtype(dtype) class Sin(torch.autograd.Function): From 837d799b9ffa4c7d7e3a73181dfdea36b993051c Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 09:21:52 -0800 Subject: [PATCH 15/23] Revert meaningless change --- thunder/tests/test_update_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 4e5793ebb0..208c24a89d 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -554,7 +554,7 @@ def backward(ctx, g): return y def foo(x): - return Sin.apply(x) * x + return Sin.apply(x) a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad) b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad) From 24da7bb83e46709729bffad4b4e53488048bb9eb Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 28 Nov 2025 10:08:49 -0800 Subject: [PATCH 16/23] Subtle fix for notebook test --- thunder/core/update_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 16fe4a7fcd..0137775599 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -134,7 +134,7 @@ def replace_args_with_alias_map( def insert_alias_updates(computation_trace: Trace) -> Trace: cd = get_compile_data() - if cd.compile_options.get("skip_inplace_alias_updates", False): + if cd is not None and cd.compile_options.get("skip_inplace_alias_updates", False): return computation_trace if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): From 8071fba0143defdb3a9affbe16f6405d35f4aad1 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 3 Dec 2025 21:22:10 -0800 Subject: [PATCH 17/23] Reduce cognitive burden --- thunder/core/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 679a53229d..440403356f 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1265,7 +1265,11 @@ def create_python_callable_from_bsym(bsym: BoundSymbolInterface) -> str: def parse_alias_tensor_indices(alias_tensor_indices_str: str) -> list[list[int]]: - return [[int(i) for i in s.split(",")] for s in alias_tensor_indices_str.split("-") if s != ""] + indice_groups = [] + for s in alias_tensor_indices_str.split("-"): + indices = [int(i) for i in s.split(",")] + indice_groups.append(indices) + return indice_groups def encode_alias_tensor_indices(*args, **kwargs) -> str: @@ -1281,8 +1285,9 @@ def encode_alias_tensor_indices(*args, **kwargs) -> str: data_ptr = t.untyped_storage().data_ptr() data_ptr_to_tensor_indices[data_ptr].append(idx) - alias_indices = [] + encoded_indice_groups = [] for indices in data_ptr_to_tensor_indices.values(): if len(indices) > 1: - alias_indices.append(",".join(str(idx) for idx in indices)) - return "-".join(alias_indices) + encoded_indices = ",".join(str(idx) for idx in indices) + encoded_indice_groups.append(encoded_indices) + return "-".join(encoded_indice_groups) From 840a30412cfb87b20e3315e47ac39ef340f5ee98 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 4 Dec 2025 03:15:01 -0800 Subject: [PATCH 18/23] Fixup --- thunder/core/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 440403356f..001a4bcb1d 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1267,6 +1267,8 @@ def create_python_callable_from_bsym(bsym: BoundSymbolInterface) -> str: def parse_alias_tensor_indices(alias_tensor_indices_str: str) -> list[list[int]]: indice_groups = [] for s in alias_tensor_indices_str.split("-"): + if not s: + continue indices = [int(i) for i in s.split(",")] indice_groups.append(indices) return indice_groups From c26e9acabb45e43f06d92b3fd08bfcae7b11e0ce Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Fri, 5 Dec 2025 02:50:03 -0800 Subject: [PATCH 19/23] Add test TODO: make this pass --- thunder/tests/test_update_aliases.py | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py index 208c24a89d..9576e2fd09 100644 --- a/thunder/tests/test_update_aliases.py +++ b/thunder/tests/test_update_aliases.py @@ -531,6 +531,33 @@ def f(x, running_mean, running_var, weight, bias): torch.testing.assert_close(running_var, running_var_ref) +@instantiate( + dtypes=(dtypes.float32,), +) +def test_no_update_aliases_in_backward(executor, device, dtype): + torch_dtype = dtypes.to_torch_dtype(dtype) + + def f(x): + y = x.sin() + y.exp_() + return y + + x = make_tensor((2, 3), device=device, dtype=torch_dtype, requires_grad=True) + jf = executor.make_callable(f) + actual = jf(x) + expected = f(x) + torch.testing.assert_close(actual, expected) + + g = torch.randn_like(actual) + + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) + + backward_trace = thunder.last_backward_traces(jf)[-1] + assert all(bsym.sym.name != "update_aliases" for bsym in backward_trace.bound_symbols) + + @instantiate( dtypes=(dtypes.float32,), decorators=(pytest.mark.parametrize("requires_grad", (True, False)),), @@ -584,6 +611,9 @@ def foo(x): torch.testing.assert_close(actual_grad_fx, expected_grad) torch.testing.assert_close(actual_grad_jit, expected_grad) + backward_trace = thunder.last_backward_traces(jfoo)[-1] + assert all(bsym.sym.name != "update_aliases" for bsym in backward_trace.bound_symbols) + @instantiate( dtypes=(dtypes.float32,), From 6b0f2f8ac8792de588720f8dc93d843ccd7e0107 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 19:29:09 -0800 Subject: [PATCH 20/23] Revert "Handle skip_inplace_alias_updates inside insert_alias_updates" This reverts commit 29a1b6e5d4862cd7e4a7e3f44af890918874577c. --- thunder/__init__.py | 9 +++++---- thunder/core/update_aliases.py | 6 +----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index b811ff39a1..7cda77012c 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -457,10 +457,11 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com computation_trc = remove_context_manager_prims_from_trace(computation_trc) computation_traces.append(computation_trc) - aliased_trace = insert_alias_updates(computation_trc) - if aliased_trace is not computation_trc: - computation_traces.append(aliased_trace) - computation_trc = computation_traces[-1] + if not compile_options.get("skip_inplace_alias_updates", False): + aliased_trace = insert_alias_updates(computation_trc) + if aliased_trace is not computation_trc: + computation_traces.append(aliased_trace) + computation_trc = computation_traces[-1] cs.last_trace_tracing_stop = time.perf_counter_ns() diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 1c3079710b..3247fd48ad 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -1,7 +1,7 @@ from functools import reduce, partial import thunder -from thunder.core.compile_data import get_compile_data, using_symbolic_values +from thunder.core.compile_data import using_symbolic_values import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, variableify, unvariableify from thunder.core.pytree import tree_flatten @@ -144,10 +144,6 @@ def _helper(alias): def insert_alias_updates(computation_trace: Trace) -> Trace: - cd = get_compile_data() - if cd is not None and cd.compile_options.get("skip_inplace_alias_updates", False): - return computation_trace - if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace From f8569d003123ac542117840a8f6599b74e4b9517 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 19:31:10 -0800 Subject: [PATCH 21/23] Revert "Access alias_tensor_indices only inside update_aliases" This reverts commit daab3bb05862bb2f30c24c810ec4e7c37f48b7f6. --- thunder/__init__.py | 44 ++++++++++++++++++++++++++++++---- thunder/core/jit_ext.py | 12 ++++++---- thunder/core/update_aliases.py | 7 +----- thunder/core/utils.py | 31 ------------------------ thunder/transforms/autodiff.py | 4 ++-- 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 7cda77012c..3bfb797dbf 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -37,6 +37,7 @@ SHARP_EDGES_OPTIONS, ) from thunder.core.proxies import TensorProxy +from thunder.core.pytree import tree_flatten from thunder.core.recipe import Recipe, Plugin from thunder.core.symbol import has_tags from thunder.core.trace import ( @@ -55,7 +56,6 @@ wrap_return_value_together_with_arguments, ) from thunder.core.update_aliases import insert_alias_updates -from thunder.core.utils import encode_alias_tensor_indices from thunder.executors.torch_autograd import connect_to_autograd import thunder.extend as extend from thunder.extend import Executor, add_default_executor @@ -405,6 +405,37 @@ def jit( cs = CompileStats() weakref_cs = weakref.ref(cs) + def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: + flat_args, _ = tree_flatten((args, kwargs)) + data_ptr_to_tensor_group_index = {} + tensor_group_index_to_tensor_indices = defaultdict(list) + for idx, t in enumerate(flat_args): + # Using type(t) is pytorch.Tensor as TensorSubclasses don't support calling + # data_ptr(). + # Eg. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass) + # + # isinstance(t, pytorch.Tensor) or pytorch.is_tensor(t) will match all Tensor objects including + # subclasses. + if type(t) is pytorch.Tensor and t.layout is pytorch.strided: + data_ptr = t.untyped_storage().data_ptr() + if data_ptr not in data_ptr_to_tensor_group_index: + data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) + tgi = data_ptr_to_tensor_group_index[data_ptr] + tensor_group_index_to_tensor_indices[tgi].append(idx) + return tensor_group_index_to_tensor_indices + + def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str: + """If no aliases found, empty string, otherwise, aliases are comma separated, groups are hyphen separated.""" + + alias_indices = [] + for k, v in _alias_tensor_of_args_kwargs_dict(*args, **kwargs).items(): + if len(v) > 1: + s = ",".join(f"{i}" for i in v) + alias_indices.append(s) + if not alias_indices: + return "" + return "-".join(alias_indices) + def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor): with compile_data_and_stats(cd, cs): # Acquires the trace OR inlines the trace into an existing trace and @@ -457,8 +488,13 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com computation_trc = remove_context_manager_prims_from_trace(computation_trc) computation_traces.append(computation_trc) + alias_tensor_indices_str = cache_info.get("alias_tensor_indices", "") + alias_tensor_indices: list[list[int]] = [ + [int(i) for i in s.split(",")] for s in alias_tensor_indices_str.split("-") if s != "" + ] + if not compile_options.get("skip_inplace_alias_updates", False): - aliased_trace = insert_alias_updates(computation_trc) + aliased_trace = insert_alias_updates(computation_trc, alias_tensor_indices) if aliased_trace is not computation_trc: computation_traces.append(aliased_trace) computation_trc = computation_traces[-1] @@ -517,7 +553,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com if requires_grad: from thunder.transforms.autodiff import grad_transform_on_trace - computation_trc = grad_transform_on_trace(computation_trc) + computation_trc = grad_transform_on_trace(computation_trc, alias_tensor_indices) computation_traces.append(computation_trc) from thunder.executors.passes import _transform_for_operator_executor_execution @@ -638,7 +674,7 @@ def populate_cache_info(cache_info, *args, **kwargs): # It however would require the computation trace to interact with `cache_info`, # which seems to break the consistency of cache_info, leading to a failure in cache_info check. if not compile_options.get("skip_inplace_alias_updates", False): - cache_info["alias_tensor_indices"] = encode_alias_tensor_indices(*args, **kwargs) + cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform # to treat certain Symbols as constant. diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 538b72a710..1fc26b0817 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -839,7 +839,8 @@ def core_of_forward(*args, **kwargs): trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names) - aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd) + alias_tensor_indices = [[i] for i in range(len(trace_of_augmented_fwd.args))] + aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd, alias_tensor_indices) # Backward definition custom_backward = custom_autograd_function_cls.backward @@ -876,7 +877,8 @@ def core_of_forward(*args, **kwargs): bwd_trace_impl.name_ctr = get_jit_ctx().computation_trace.name_ctr bwd_trace_impl.names = set(get_jit_ctx().computation_trace.names) - aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl) + alias_tensor_indices = [[i] for i in range(len(bwd_trace_impl.args))] + aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl, alias_tensor_indices) @wraps(bwd_trace_impl.python_callable()) def bwd_impl_callable(*args, **kwargs): @@ -961,7 +963,8 @@ def _generate_random_str_id() -> str: aug_fwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr aug_fwd_trace.names = set(get_jit_ctx().computation_trace.names) - aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace) + alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))] + aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices) trace_of_forward = from_trace(aliased_aug_fwd_trace) for bsym in aug_fwd_trace.bound_symbols: @@ -1001,7 +1004,8 @@ def forward(*args, **kwargs): bwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr bwd_trace.names = set(get_jit_ctx().computation_trace.names) - aliased_bwd_trace = insert_alias_updates(bwd_trace) + alias_tensor_indices = [[i] for i in range(len(bwd_trace.args))] + aliased_bwd_trace = insert_alias_updates(bwd_trace, alias_tensor_indices) @wraps(forward) def grad_transform(*args, **kwargs): diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index 3247fd48ad..fb149311b0 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -1,13 +1,11 @@ from functools import reduce, partial -import thunder from thunder.core.compile_data import using_symbolic_values import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, variableify, unvariableify from thunder.core.pytree import tree_flatten from thunder.core.symbol import BoundSymbol, BoundSymbolTag, has_tags from thunder.core.trace import from_trace, tracectx, TraceCtx as Trace, TraceProvenance, VariableInterface -from thunder.core.utils import parse_alias_tensor_indices def _update_swap_map(swap_map, old_alias, new_alias): @@ -143,13 +141,10 @@ def _helper(alias): return list(map(_helper, aliases)) -def insert_alias_updates(computation_trace: Trace) -> Trace: +def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace: if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace - alias_tensor_indices_str = thunder._get_cache_info().get("alias_tensor_indices", "") - alias_tensor_indices = parse_alias_tensor_indices(alias_tensor_indices_str) - swap_map = dict() bsyms = [] diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 001a4bcb1d..40b3e0eb13 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1262,34 +1262,3 @@ def create_python_callable_from_bsym(bsym: BoundSymbolInterface) -> str: prims.python_return(bsym.output) return trace.python(include_decorators=False) - - -def parse_alias_tensor_indices(alias_tensor_indices_str: str) -> list[list[int]]: - indice_groups = [] - for s in alias_tensor_indices_str.split("-"): - if not s: - continue - indices = [int(i) for i in s.split(",")] - indice_groups.append(indices) - return indice_groups - - -def encode_alias_tensor_indices(*args, **kwargs) -> str: - flat_args, _ = tree_flatten((args, kwargs)) - data_ptr_to_tensor_indices = defaultdict(list) - - for idx, t in enumerate(flat_args): - # Using type(t) is torch.Tensor as TensorSubclasses don't support calling data_ptr(). - # Eg. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass) - # - # isinstance(t, torch.Tensor) or torch.is_tensor(t) will match all Tensor objects including subclasses. - if type(t) is torch.Tensor and t.layout is torch.strided: - data_ptr = t.untyped_storage().data_ptr() - data_ptr_to_tensor_indices[data_ptr].append(idx) - - encoded_indice_groups = [] - for indices in data_ptr_to_tensor_indices.values(): - if len(indices) > 1: - encoded_indices = ",".join(str(idx) for idx in indices) - encoded_indice_groups.append(encoded_indices) - return "-".join(encoded_indice_groups) diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py index e032c8e928..63935f1a42 100644 --- a/thunder/transforms/autodiff.py +++ b/thunder/transforms/autodiff.py @@ -26,7 +26,7 @@ def _should_recompute_bsym_in_backward(bsym): # Transforms a trace by determining which grad transforms to call given the list of executors in priority order # This pass tries to preserve the original trace and proxies. -def grad_transform_on_trace(trace, /, *args, **kwargs): +def grad_transform_on_trace(trace, alias_tensor_indices: list[list[int]], /, *args, **kwargs): # This processes the bsyms to map symbols to operator executors: # - in the order of the executor list # - if the executor defines a grad transform, call that to @@ -439,7 +439,7 @@ def process_bsym(self, bsym: BoundSymbol) -> None: joint_trace, _ = InsertRecomputationsProcessor(joint_trace)() # Insert prims.update_aliases before DCE for bsyms exposed by decomposition - joint_trace = insert_alias_updates(joint_trace) + joint_trace = insert_alias_updates(joint_trace, alias_tensor_indices) # run through DCE in case some of the gradients of intermediates are not needed. joint_trace = dce(joint_trace) From 55e794ad0380d5ea6b570385098e434b4913a5dd Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 21:04:27 -0800 Subject: [PATCH 22/23] Fixup --- thunder/__init__.py | 2 ++ thunder/common.py | 3 ++- thunder/core/update_aliases.py | 6 +++--- thunder/executors/passes.py | 6 ++++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 3bfb797dbf..dcd1a03585 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -530,6 +530,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com prologue_trc, executors_list=(pythonex,), use_del_last_used=False, + alias_tensor_indices=alias_tensor_indices, ) prologue_trc = prologue_traces[-1] pro = prologue_trc.python_callable(include_decorators=False) @@ -569,6 +570,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com computation_trc, executors_list=cd.executors_list, use_del_last_used=False, + alias_tensor_indices=alias_tensor_indices, ) computation_trc = extraces[-1] diff --git a/thunder/common.py b/thunder/common.py index 8863c3ec42..02718a2f02 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -644,6 +644,7 @@ def transform_for_execution( *, only_execute_prims=False, use_del_last_used=True, + alias_tensor_indices: list[list[int]] | None = None, ) -> list[TraceCtx]: traces: list[TraceCtx] = [] @@ -652,7 +653,7 @@ def transform_for_execution( # cse_trace = cse(dce_trace) # traces.append(cse_trace) - extrace = executors.passes.transform_for_execution(trace, executors_list) + extrace = executors.passes.transform_for_execution(trace, executors_list, alias_tensor_indices) traces.append(extrace) diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index fb149311b0..4c2425f054 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -68,7 +68,7 @@ def _can_be_reshaped(arg, arg_to_replace): def replace_args_with_alias_map( computation_trace: Trace, - alias_tensor_indices: list[list[int]], + alias_tensor_indices: list[list[int]] | None = None, ) -> tuple[Trace, list[set[VariableInterface]]]: if not alias_tensor_indices: return computation_trace, [] @@ -141,7 +141,7 @@ def _helper(alias): return list(map(_helper, aliases)) -def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace: +def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]] | None = None) -> Trace: if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols): return computation_trace @@ -182,7 +182,7 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li if ( _is_inplace_op(bsym) or _is_view_creation_op(bsym) - or (bsym.sym.id != prims.PrimIDs.RETURN and _involves_viewed_args(set(unswapped_in_tensors), viewed)) + or _involves_viewed_args(set(unswapped_in_tensors), viewed) ): if _is_inplace_op(bsym) and in_tensors: mutated_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0 diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 374ea7761e..d8d5c48400 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -105,7 +105,9 @@ def process_bsym(self, bsym: BoundSymbol) -> None: return extrace -def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: +def transform_for_execution( + trace: TraceCtx, executors_list: Sequence[Executor], alias_tensor_indices: list[list[int]] | None = None +) -> TraceCtx: import torch start_time_ns = time.perf_counter_ns() @@ -125,7 +127,7 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) extrace = _transform_for_operator_executor_execution(trace, executors_list) # Insert alias updates before DCE for bsyms exposed by decomposition # Inserted prims.update_aliases will be handled in Step 3 - extrace = insert_alias_updates(extrace) + extrace = insert_alias_updates(extrace, alias_tensor_indices) extrace = dce(extrace) # From 8978812aca54afd48061eac300563929c34f4953 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Thu, 11 Dec 2025 22:46:32 -0800 Subject: [PATCH 23/23] Temporarily skip rematerialization --- thunder/executors/nvfuserex_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 8c2d0f6324..6986dcb3af 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -646,7 +646,7 @@ def __init__(self): super().__init__("nvfuser", version=nvfuser.version()) # TODO: Replace this with a query to a compile option - self._use_rematerialization = True + self._use_rematerialization = False fuel_str = os.getenv("NVFUSER_OPTIMIZATION_FUEL") if fuel_str: