diff --git a/thunder/__init__.py b/thunder/__init__.py
index c0483f8c3a..dcd1a03585 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -530,6 +530,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
         prologue_trc,
         executors_list=(pythonex,),
         use_del_last_used=False,
+        alias_tensor_indices=alias_tensor_indices,
     )
     prologue_trc = prologue_traces[-1]
     pro = prologue_trc.python_callable(include_decorators=False)
@@ -553,7 +554,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
     if requires_grad:
         from thunder.transforms.autodiff import grad_transform_on_trace

-        computation_trc = grad_transform_on_trace(computation_trc)
+        computation_trc = grad_transform_on_trace(computation_trc, alias_tensor_indices)
         computation_traces.append(computation_trc)

     from thunder.executors.passes import _transform_for_operator_executor_execution
@@ -569,6 +570,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
         computation_trc,
         executors_list=cd.executors_list,
         use_del_last_used=False,
+        alias_tensor_indices=alias_tensor_indices,
     )
     computation_trc = extraces[-1]

diff --git a/thunder/common.py b/thunder/common.py
index 487cb28e07..02718a2f02 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -619,7 +619,7 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy:
         # TODO Stop calling this here and make it a separate trace in the sequence
         #   of traces
         if use_dce:
-            trace = dce(trace)
+            trace = dce(trace, keep_inplace_ops=True)

     finally:
         # Resets contexts
@@ -644,6 +644,7 @@ def transform_for_execution(
     *,
     only_execute_prims=False,
     use_del_last_used=True,
+    alias_tensor_indices: list[list[int]] | None = None,
 ) -> list[TraceCtx]:
     traces: list[TraceCtx] = []

@@ -652,7 +653,7 @@ def transform_for_execution(
     # cse_trace = cse(dce_trace)
     # traces.append(cse_trace)

-    extrace = executors.passes.transform_for_execution(trace, executors_list)
+    extrace = executors.passes.transform_for_execution(trace, executors_list, alias_tensor_indices)

     traces.append(extrace)
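For context on the `alias_tensor_indices` argument now threaded through these signatures: it groups the positions of flat tensor arguments that alias one another. The snippet below is a minimal sketch (a hypothetical helper written for illustration, not part of this patch and not Thunder's actual prologue logic) of how such groups could be derived from storage identity.

    # Hypothetical helper: indices i and j land in the same inner list when
    # args[i] and args[j] share storage, e.g. [[0, 1], [2]] for (a, a.view(...), b).
    from collections import defaultdict

    import torch


    def alias_groups(args):
        groups = defaultdict(list)
        for i, a in enumerate(args):
            if isinstance(a, torch.Tensor):
                groups[a.untyped_storage().data_ptr()].append(i)
        return list(groups.values())


    a = torch.randn(4)
    print(alias_groups((a, a.view(2, 2), torch.randn(4))))  # [[0, 1], [2]]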
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 6e8f52d45c..1fc26b0817 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -835,6 +835,10 @@ def core_of_forward(*args, **kwargs):

     from thunder.core.update_aliases import insert_alias_updates

+    # Copy attributes needed for TensorProxy name construction
+    trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr
+    trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names)
+
     alias_tensor_indices = [[i] for i in range(len(trace_of_augmented_fwd.args))]
     aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd, alias_tensor_indices)

@@ -869,6 +873,10 @@ def core_of_forward(*args, **kwargs):
     )
     bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)

+    # Copy attributes needed for TensorProxy name construction
+    bwd_trace_impl.name_ctr = get_jit_ctx().computation_trace.name_ctr
+    bwd_trace_impl.names = set(get_jit_ctx().computation_trace.names)
+
     alias_tensor_indices = [[i] for i in range(len(bwd_trace_impl.args))]
     aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl, alias_tensor_indices)

@@ -951,6 +959,10 @@ def _generate_random_str_id() -> str:

     from thunder.core.update_aliases import insert_alias_updates

+    # Copy attributes needed for TensorProxy name construction
+    aug_fwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr
+    aug_fwd_trace.names = set(get_jit_ctx().computation_trace.names)
+
     alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))]
     aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices)

@@ -988,7 +1000,9 @@ def forward(*args, **kwargs):
             ]
             bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols

-            from thunder.core.update_aliases import insert_alias_updates
+            # Copy attributes needed for TensorProxy name construction
+            bwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr
+            bwd_trace.names = set(get_jit_ctx().computation_trace.names)

             alias_tensor_indices = [[i] for i in range(len(bwd_trace.args))]
             aliased_bwd_trace = insert_alias_updates(bwd_trace, alias_tensor_indices)
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index 90e1f3689b..185b331b37 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -4333,7 +4333,7 @@ def copy__meta(
     return TensorProxy(like=copy_to)


-copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,))
+copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.IN_PLACE,))


 def bitcast_meta(
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index bc14a671f2..63dfff56a1 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -142,7 +142,7 @@ def keep_or_swap(p):
 # that only produce non-proxy objects
 # NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return
 # all the needed proxies of the input trace
-def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
+def dce(trace: Trace, needed_proxies: None | set[Variable] = None, keep_inplace_ops: bool = False) -> Trace:
     start_time_ns = time.perf_counter_ns()

     producer_map: ProxyDict = producers(trace)
@@ -159,6 +159,8 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
         # Preserves symbols that should never be collected
         if has_tags(bsym, {prims.OpTags.DONT_DCE}):
             needed = True
+        elif keep_inplace_ops and has_tags(bsym, {prims.OpTags.IN_PLACE}):
+            needed = True
         else:
             needed = False
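To make the new `keep_inplace_ops` flag concrete, here is a toy, self-contained illustration of the same idea (a sketch, not Thunder's `dce`): a symbol tagged `IN_PLACE` mutates one of its inputs, so it is kept even when nothing consumes its outputs, while ordinary dead symbols are still dropped.

    from dataclasses import dataclass


    @dataclass
    class Sym:
        name: str
        outs: frozenset
        ins: frozenset
        tags: frozenset = frozenset()


    def toy_dce(symbols, live_outputs, keep_inplace_ops=False):
        live = set(live_outputs)
        kept = []
        for s in reversed(symbols):
            needed = bool(s.outs & live) or "DONT_DCE" in s.tags
            if keep_inplace_ops and "IN_PLACE" in s.tags:
                needed = True
            if needed:
                kept.append(s)
                live |= s.ins
        return list(reversed(kept))


    trace = [
        Sym("t0 = exp(x)", frozenset({"t0"}), frozenset({"x"})),
        Sym("t1 = copy_(t0, x)", frozenset({"t1"}), frozenset({"t0", "x"}), frozenset({"IN_PLACE"})),
        Sym("t2 = cos(x)", frozenset({"t2"}), frozenset({"x"})),  # dead: t2 is never consumed
        Sym("t3 = sin(x)", frozenset({"t3"}), frozenset({"x"})),
    ]
    # copy_'s output t1 is unused, but the mutation of x must survive; cos is dropped.
    print([s.name for s in toy_dce(trace, {"t3"}, keep_inplace_ops=True)])
    # -> ['t0 = exp(x)', 't1 = copy_(t0, x)', 't3 = sin(x)']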
diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py
index c6c7271e5f..4c2425f054 100644
--- a/thunder/core/update_aliases.py
+++ b/thunder/core/update_aliases.py
@@ -39,6 +39,7 @@ def _get_new_aliases(aliases, trace):


 def _is_inplace_op(bsym):
+    # TODO: Handle higher order bsyms containing inplace ops
     return (bsym.sym.tags and prims.OpTags.IN_PLACE in bsym.sym.tags) or (
         bsym.subsymbols and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_
     )
@@ -67,7 +68,7 @@ def _can_be_reshaped(arg, arg_to_replace):

 def replace_args_with_alias_map(
     computation_trace: Trace,
-    alias_tensor_indices: list[list[int]],
+    alias_tensor_indices: list[list[int]] | None = None,
 ) -> tuple[Trace, list[set[VariableInterface]]]:
     if not alias_tensor_indices:
         return computation_trace, []
@@ -140,7 +141,7 @@ def _helper(alias):
     return list(map(_helper, aliases))


-def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
+def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]] | None = None) -> Trace:
     if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
         return computation_trace

@@ -157,7 +158,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
     for bsym in computation_trace.bound_symbols:
         if _is_inplace_op(bsym) or _is_view_creation_op(bsym):
             # only interested in the input which is modified by the inplace op
-            in_tensor = variableify(bsym.flat_proxy_args[0])
+            mutated_or_aliased_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0
+            in_tensor = variableify(bsym.flat_proxy_args[mutated_or_aliased_index])
             out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
             if _is_inplace_op(bsym):
                 inplace_inputs.add(in_tensor)
@@ -180,10 +182,11 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
         if (
             _is_inplace_op(bsym)
             or _is_view_creation_op(bsym)
-            or (bsym.sym.id != prims.PrimIDs.RETURN and _involves_viewed_args(set(unswapped_in_tensors), viewed))
+            or _involves_viewed_args(set(unswapped_in_tensors), viewed)
         ):
             if _is_inplace_op(bsym) and in_tensors:
-                in_tensors = {in_tensors[0]}
+                mutated_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0
+                in_tensors = {in_tensors[mutated_index]}
                 unswapped_in_tensors = {unswapped_in_tensors[0]}
             else:
                 in_tensors = set(in_tensors)
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 8c2d0f6324..6986dcb3af 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -646,7 +646,7 @@ def __init__(self):
         super().__init__("nvfuser", version=nvfuser.version())

         # TODO: Replace this with a query to a compile option
-        self._use_rematerialization = True
+        self._use_rematerialization = False

         fuel_str = os.getenv("NVFUSER_OPTIMIZATION_FUEL")
         if fuel_str:
diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py
index 96274a588f..d8d5c48400 100644
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -10,6 +10,7 @@
 from thunder.core.trace import from_trace, TraceProvenance
 from thunder.core.trace_interpreter import TraceSubstitutionProcessor
 from thunder.core.transform_common import dce
+from thunder.core.update_aliases import insert_alias_updates
 from thunder.core.utils import ProxyDict
 from thunder.executors.pythonex import clear_mutable_collection
 from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor
@@ -104,7 +105,9 @@ def process_bsym(self, bsym: BoundSymbol) -> None:
     return extrace


-def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
+def transform_for_execution(
+    trace: TraceCtx, executors_list: Sequence[Executor], alias_tensor_indices: list[list[int]] | None = None
+) -> TraceCtx:
     import torch

     start_time_ns = time.perf_counter_ns()
@@ -122,7 +125,11 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor])
     # Step 1 Performs execution transforms
     #
     extrace = _transform_for_operator_executor_execution(trace, executors_list)
+    # Insert alias updates before DCE for bsyms exposed by decomposition
+    # Inserted prims.update_aliases will be handled in Step 3
+    extrace = insert_alias_updates(extrace, alias_tensor_indices)
     extrace = dce(extrace)
+
     #
     # Step 2 Fusion executors can transform the trace
     #
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index c67aae82ae..6f06104419 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -2371,7 +2371,7 @@ def _copy__impl(copy_from, copy_to, grad_enabled):


 copy_ = ex.register_operator(
-    "copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor
+    "copy_", meta=prims.copy_, tags=(prims.OpTags.IN_PLACE,), fn=_copy__impl, module=torch.Tensor
 )
 _register_implementation(prims.copy_, copy_, checker=_always_executable)
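As a reminder of the behaviour these passes protect, here is a small end-to-end sketch (assuming the public `thunder.jit` entry point; it is not part of this patch): an in-place update made through one alias must be observable through the other alias and in the caller's tensors, which is what threading `alias_tensor_indices` into `transform_for_execution` preserves.

    import torch
    import thunder


    def f(x, y):
        y.exp_()           # mutates the storage shared with x
        return x.sum()     # must observe the mutation through the alias


    a = torch.randn(2, 2)
    a_ref = a.clone()
    jf = thunder.jit(f)
    actual = jf(a, a.view(4))
    expected = f(a_ref, a_ref.view(4))
    torch.testing.assert_close(actual, expected)
    torch.testing.assert_close(a, a_ref)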
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index 7cb3e705fd..dc962ff9be 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -42,16 +42,16 @@ def test_prim_inplace_copy_bwd(executor, device, dtype):
     def torch_foo(x, y):
         z = x * y
         z = z * x
-        x.copy_(z)
+        o = x.copy_(z)
         p = y * y
-        return p
+        return p, o

     def foo(x, y):
         z = x * y
         z = z * x
-        thunder.core.prims.copy_(z, x, grad_enabled=True)
+        o = thunder.core.prims.copy_(z, x, grad_enabled=True)
         p = y * y
-        return p
+        return p, o

     traced_nvfuser_foo = executor.make_callable(foo)

@@ -72,11 +72,11 @@ def foo(x, y):
     )
     custom_comparator(a, a1)

-    g = torch.ones_like(thunder_result)
-    thunder_result.backward(g)
+    g = torch.ones_like(thunder_result[0])
+    thunder_result[0].backward(g)

-    g1 = torch.ones_like(torch_result)
-    torch_result.backward(g1)
+    g1 = torch.ones_like(torch_result[0])
+    torch_result[0].backward(g1)

     assert_close(g, g1)
     assert_close(b.grad, b1.grad)
@@ -131,7 +131,7 @@ def func2(x, y):
         return y, o1, o2

     for foo in (func1, func2):
-        traced_foo = executor.make_callable(foo)
+        traced_foo = executor.make_callable(foo, skip_inplace_alias_updates=True)

         tdtype = ttorch.to_torch_dtype(dtype)
         a = make_tensor((4, 4), device=device, dtype=tdtype)
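For reference (plain eager PyTorch, not part of this patch): an in-place op such as `Tensor.copy_` returns the mutated tensor itself, which appears to be why the updated test binds the result (`o = x.copy_(z)`) and returns it, making the mutated value an explicit output of the traced function rather than relying on the side effect alone.

    import torch

    x = torch.zeros(3)
    z = torch.arange(3.0)
    o = x.copy_(z)
    print(o is x)                      # True: copy_ returns its input tensor
    torch.testing.assert_close(o, z)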
diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py
index 429ad8977d..8bb52ff586 100644
--- a/thunder/tests/test_update_aliases.py
+++ b/thunder/tests/test_update_aliases.py
@@ -15,11 +15,11 @@
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
 from thunder.tests.framework import (
     instantiate,
+    nvFuserExecutor,
     ops,
     NOTHING,
     TorchExecutor,
     TorchCompileExecutor,
-    nvFuserExecutor,
     requiresCUDA,
 )
 from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -166,14 +166,15 @@ def g(x, _):

 @instantiate(
     dtypes=NOTHING,
+    decorators=(pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),),
 )
-def test_inplace_on_view(executor, device, dtype):
+def test_inplace_on_intermediate(executor, device, dtype, inplace_op):
     def h(x, y):
         c = torch.exp(x)
         d = torch.tanh(y)
         e = c.view(-1)

-        d.div_(x)
+        inplace_op(d, x)
         e += d.flatten()
         return c, d, e

@@ -184,7 +185,7 @@ def i(x, y):
         e = c.view(-1)
         e += d.flatten()

-        d.div_(x)
+        inplace_op(d, x)
         return c, d, e


@@ -264,8 +265,10 @@ def h(x):
 )
 def test_chained_inplace(executor, device, dtype):
     def f(x):
-        x.add_(1).sin_().mul_(5)
-        return x
+        y = x.add_(1)
+        z = y.sin_()
+        w = z.mul_(y.copy_(z.cos()))
+        return w

     def g(x):
         x.add_(1).sin().mul_(5)
@@ -273,6 +276,7 @@ def g(x):

     def h(x):
         x.exp_()
+        x.copy_(x.tan())
         x.sin_()
         y = x.cos()
         return y
@@ -336,14 +340,16 @@ def g(a, b):
 )
 def test_aliased_input(executor, device, dtype, cache):
     def f(x, y, z):
-        return y.exp_().add(x) + z.exp()
+        s = y.exp_().add(x) + z.exp()
+        t = x.copy_(z.exp_().view(x.shape)) + z.cos().reshape(x.shape)
+        return s, t

     a = make_tensor((2, 1, 2), dtype=torch.float32, device=device)
     b = a.clone()
     c = a.view(1, 2, 2)
     a_ = a.clone().detach()
     b_ = b.clone().detach()
-    c_ = c.clone().detach()
+    c_ = a_.view(1, 2, 2)
     jfn = executor.make_callable(f, cache=cache)
     actual = jfn(a, b, c)
     expected = f(a_, b_, c_)
@@ -355,22 +361,34 @@ def f(x, y, z):

 @instantiate(
     dtypes=NOTHING,
-    decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
+    decorators=(
+        pytest.mark.parametrize("cache", ("constant values", "symbolic values")),
+        pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),
+    ),
 )
-def test_write_to_intermediate_result(executor, device, dtype, cache):
-    if executor == nvFuserExecutor:
-        pytest.xfail("nvFuser does not support writing to intermediate results")
-
-    def fn(x):
+def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op):
+    def f(x, z):
         y = x.view(-1)
-        y.add_(1)
+        inplace_op(y, z.view(-1))
         return y

-    a = make_tensor((2, 3), dtype=torch.float32, device=device)
-    jfn = executor.make_callable(fn, cache=cache)
-    actual = jfn(a)
-    expected = fn(a)
-    torch.testing.assert_close(actual, expected)
+    def g(x, z):
+        a = x.view(-1)
+        b = x.view(-1)
+        inplace_op(x, z)
+        aa = a + 1
+        bb = b + 1
+        return aa, bb
+
+    for fn in [f, g]:
+        x = make_tensor((2, 3), dtype=torch.float32, device=device)
+        x_ref = x.clone().detach()
+        z = make_tensor((2, 3), dtype=torch.float32, device=device)
+        jfn = executor.make_callable(fn, cache=cache)
+        actual = jfn(x, z)
+        expected = fn(x_ref, z)
+        torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(x, x_ref)


 @instantiate(
     dtypes=NOTHING,
 )
@@ -476,7 +494,87 @@ def f(a):
 @instantiate(
     dtypes=(dtypes.float32,),
 )
-def test_higher_order_inplace_alias_update(executor, device, dtype):
+def test_batch_norm_update_aliases(executor, device, dtype):
+    if executor is nvFuserExecutor:
+        pytest.xfail("update_aliases is not aware of mutation by batch_norm")
+
+    torch_dtype = dtypes.to_torch_dtype(dtype)
+    num_features = 4
+
+    def f(x, running_mean, running_var, weight, bias):
+        out = torch.nn.functional.batch_norm(
+            x,
+            running_mean,
+            running_var,
+            weight,
+            bias,
+            training=True,
+            momentum=0.1,
+            eps=1e-5,
+        )
+        return out, x, running_mean.sin(), running_var.cos()
+
+    input_tensor = make_tensor((3, num_features, 5, 5), device=device, dtype=torch_dtype)
+    running_mean = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    running_var = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    weight = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    bias = make_tensor((num_features,), device=device, dtype=torch_dtype)
+
+    input_ref = input_tensor.clone().detach()
+    running_mean_ref = running_mean.clone().detach()
+    running_var_ref = running_var.clone().detach()
+    weight_ref = weight.clone().detach()
+    bias_ref = bias.clone().detach()
+
+    jitted_f = executor.make_callable(f)
+    out_jitted, x_jitted, running_mean_jitted, running_var_jitted = jitted_f(
+        input_tensor, running_mean, running_var, weight, bias
+    )
+    out_ref, x_ref, running_mean_ref_out, running_var_ref_out = f(
+        input_ref, running_mean_ref, running_var_ref, weight_ref, bias_ref
+    )
+
+    torch.testing.assert_close(out_jitted, out_ref)
+    torch.testing.assert_close(x_jitted, x_ref)
+    torch.testing.assert_close(running_mean_jitted, running_mean_ref_out)
+    torch.testing.assert_close(running_var_jitted, running_var_ref_out)
+    torch.testing.assert_close(input_tensor, input_ref)
+    torch.testing.assert_close(running_mean, running_mean_ref)
+    torch.testing.assert_close(running_var, running_var_ref)
+
+
+@instantiate(
+    dtypes=(dtypes.float32,),
+)
+def test_no_update_aliases_in_backward(executor, device, dtype):
+    torch_dtype = dtypes.to_torch_dtype(dtype)
+
+    def f(x):
+        y = x.sin()
+        y.exp_()
+        return y
+
+    x = make_tensor((2, 3), device=device, dtype=torch_dtype, requires_grad=True)
+    jf = executor.make_callable(f)
+    actual = jf(x)
+    expected = f(x)
+    torch.testing.assert_close(actual, expected)
+
+    g = torch.randn_like(actual)
+
+    actual_grad = torch.autograd.grad(actual, x, g)
+    expected_grad = torch.autograd.grad(expected, x, g)
+    torch.testing.assert_close(actual_grad, expected_grad)
+
+    backward_trace = thunder.last_backward_traces(jf)[-1]
+    assert all(bsym.sym.name != "update_aliases" for bsym in backward_trace.bound_symbols)
+
+
+@instantiate(
+    dtypes=(dtypes.float32,),
+    decorators=(pytest.mark.parametrize("requires_grad", (True, False)),),
+)
+def test_higher_order_inplace_alias_update(executor, device, dtype, requires_grad):
     torch_dtype = dtypes.to_torch_dtype(dtype)

     class Sin(torch.autograd.Function):
@@ -497,9 +595,9 @@ def backward(ctx, g):
     def foo(x):
         return Sin.apply(x)

-    a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
-    b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
-    c = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
+    a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad)
+    b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad)
+    c = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=requires_grad)

     g = torch.rand_like(a)

@@ -517,12 +615,16 @@ def foo(x):
     torch.testing.assert_close(actual_jit, expected)
     torch.testing.assert_close(actual_fx, expected)

-    actual_grad_jit = torch.autograd.grad(actual_jit, a, g)
-    actual_grad_fx = torch.autograd.grad(actual_fx, b, g)
+    if requires_grad:
+        actual_grad_jit = torch.autograd.grad(actual_jit, a, g)
+        actual_grad_fx = torch.autograd.grad(actual_fx, b, g)
+
+        expected_grad = torch.autograd.grad(expected, c, g)
+        torch.testing.assert_close(actual_grad_fx, expected_grad)
+        torch.testing.assert_close(actual_grad_jit, expected_grad)

-    expected_grad = torch.autograd.grad(expected, c, g)
-    torch.testing.assert_close(actual_grad_fx, expected_grad)
-    torch.testing.assert_close(actual_grad_jit, expected_grad)
+    backward_trace = thunder.last_backward_traces(jfoo)[-1]
+    assert all(bsym.sym.name != "update_aliases" for bsym in backward_trace.bound_symbols)


 @instantiate(
diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py
index 2d4261cdf6..63935f1a42 100644
--- a/thunder/transforms/autodiff.py
+++ b/thunder/transforms/autodiff.py
@@ -14,6 +14,7 @@
     augmented_forward_impls,
     backward_impls,
 )
+from thunder.core.update_aliases import insert_alias_updates

 import thunder.torch as ltorch

@@ -25,7 +26,7 @@ def _should_recompute_bsym_in_backward(bsym):

 # Transforms a trace by determining which grad transforms to call given the list of executors in priority order
 # This pass tries to preserve the original trace and proxies.
-def grad_transform_on_trace(trace, /, *args, **kwargs):
+def grad_transform_on_trace(trace, alias_tensor_indices: list[list[int]], /, *args, **kwargs):
     # This processes the bsyms to map symbols to operator executors:
     # - in the order of the executor list
     # - if the executor defines a grad transform, call that to
@@ -437,6 +438,9 @@ def process_bsym(self, bsym: BoundSymbol) -> None:
     joint_trace, _ = AugmentedForwardProcessor(trace)()
     joint_trace, _ = InsertRecomputationsProcessor(joint_trace)()

+    # Insert prims.update_aliases before DCE for bsyms exposed by decomposition
+    joint_trace = insert_alias_updates(joint_trace, alias_tensor_indices)
+
     # run through DCE in case some of the gradients of intermediates are not needed.
     joint_trace = dce(joint_trace)
     # group get_grad symbols together for torch compile fusions and to make clear boundary for cse
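Finally, a quick way to see where the pass lands (a sketch using the public `thunder.jit` and `thunder.last_traces` APIs, mirroring the assertions added above): `prims.update_aliases` should appear in the execution trace of a function that mutates an aliased input, while `test_no_update_aliases_in_backward` checks that it does not appear in backward traces.

    import torch
    import thunder


    def f(x, y):
        y.mul_(2)          # in-place update through an alias of x
        return x + 1


    jf = thunder.jit(f)
    a = torch.randn(4)
    jf(a, a.view(2, 2))

    exec_trace = thunder.last_traces(jf)[-1]
    # Expect at least one update_aliases bound symbol in the execution trace.
    print(any(bsym.sym.name == "update_aliases" for bsym in exec_trace.bound_symbols))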