Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
efa77e4
Do not skip return stmt in update_aliases.py
shino16 Nov 27, 2025
6a4fbbb
Make copy_ DCE'd by default
shino16 Nov 27, 2025
edac5c5
Tag torchex.copy_ as IN_PLACE instead
shino16 Nov 27, 2025
6fefee2
Prepare TraceCtx.name/name_ctr for proxy name generation
shino16 Nov 27, 2025
d5ffd31
Minor fix on test
shino16 Nov 27, 2025
28bc094
Make update_aliases.py handle copy_
shino16 Nov 27, 2025
5211930
Apply update_aliases after decomposition in autodiff
shino16 Nov 27, 2025
baa6ede
Add tests
shino16 Nov 27, 2025
11396be
Improve test consistency
shino16 Nov 28, 2025
406d227
Fix test bug
shino16 Nov 28, 2025
95a6c14
Add xfail
shino16 Nov 28, 2025
29a1b6e
Handle skip_inplace_alias_updates inside insert_alias_updates
shino16 Nov 28, 2025
daab3bb
Access alias_tensor_indices only inside update_aliases
shino16 Nov 28, 2025
d943e3c
Apply update_aliases after first operator ex transform
shino16 Nov 28, 2025
837d799
Revert meaningless change
shino16 Nov 28, 2025
24da7bb
Subtle fix for notebook test
shino16 Nov 28, 2025
8071fba
Reduce cognitive burden
shino16 Dec 4, 2025
840a304
Fixup
shino16 Dec 4, 2025
c26e9ac
Add test TODO: make this pass
shino16 Dec 5, 2025
1add861
Merge branch 'main' of ssh://github.com/Lightning-AI/lightning-thunder
shino16 Dec 12, 2025
6b0f2f8
Revert "Handle skip_inplace_alias_updates inside insert_alias_updates"
shino16 Dec 12, 2025
f8569d0
Revert "Access alias_tensor_indices only inside update_aliases"
shino16 Dec 12, 2025
55e794a
Fixup
shino16 Dec 12, 2025
8978812
Temporarily skip rematerialization
shino16 Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
prologue_trc,
executors_list=(pythonex,),
use_del_last_used=False,
alias_tensor_indices=alias_tensor_indices,
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
Expand All @@ -553,7 +554,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
if requires_grad:
from thunder.transforms.autodiff import grad_transform_on_trace

computation_trc = grad_transform_on_trace(computation_trc)
computation_trc = grad_transform_on_trace(computation_trc, alias_tensor_indices)
computation_traces.append(computation_trc)

from thunder.executors.passes import _transform_for_operator_executor_execution
Expand All @@ -569,6 +570,7 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
computation_trc,
executors_list=cd.executors_list,
use_del_last_used=False,
alias_tensor_indices=alias_tensor_indices,
)
computation_trc = extraces[-1]

Expand Down
5 changes: 3 additions & 2 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy:
# TODO Stop calling this here and make it a separate trace in the sequence
# of traces
if use_dce:
trace = dce(trace)
trace = dce(trace, keep_inplace_ops=True)

finally:
# Resets contexts
Expand All @@ -644,6 +644,7 @@ def transform_for_execution(
*,
only_execute_prims=False,
use_del_last_used=True,
alias_tensor_indices: list[list[int]] | None = None,
) -> list[TraceCtx]:
traces: list[TraceCtx] = []

Expand All @@ -652,7 +653,7 @@ def transform_for_execution(
# cse_trace = cse(dce_trace)
# traces.append(cse_trace)

extrace = executors.passes.transform_for_execution(trace, executors_list)
extrace = executors.passes.transform_for_execution(trace, executors_list, alias_tensor_indices)

traces.append(extrace)

Expand Down
16 changes: 15 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,10 @@ def core_of_forward(*args, **kwargs):

from thunder.core.update_aliases import insert_alias_updates

# Copy attributes needed for TensorProxy name construction
trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr
trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names)

alias_tensor_indices = [[i] for i in range(len(trace_of_augmented_fwd.args))]
aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd, alias_tensor_indices)

Expand Down Expand Up @@ -869,6 +873,10 @@ def core_of_forward(*args, **kwargs):
)
bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)

# Copy attributes needed for TensorProxy name construction
bwd_trace_impl.name_ctr = get_jit_ctx().computation_trace.name_ctr
bwd_trace_impl.names = set(get_jit_ctx().computation_trace.names)

alias_tensor_indices = [[i] for i in range(len(bwd_trace_impl.args))]
aliased_bwd_trace_impl = insert_alias_updates(bwd_trace_impl, alias_tensor_indices)

Expand Down Expand Up @@ -951,6 +959,10 @@ def _generate_random_str_id() -> str:

from thunder.core.update_aliases import insert_alias_updates

# Copy attributes needed for TensorProxy name construction
aug_fwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr
aug_fwd_trace.names = set(get_jit_ctx().computation_trace.names)

alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))]
aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices)

Expand Down Expand Up @@ -988,7 +1000,9 @@ def forward(*args, **kwargs):
]
bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols

from thunder.core.update_aliases import insert_alias_updates
# Copy attributes needed for TensorProxy name construction
bwd_trace.name_ctr = get_jit_ctx().computation_trace.name_ctr
bwd_trace.names = set(get_jit_ctx().computation_trace.names)

alias_tensor_indices = [[i] for i in range(len(bwd_trace.args))]
aliased_bwd_trace = insert_alias_updates(bwd_trace, alias_tensor_indices)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -4333,7 +4333,7 @@ def copy__meta(
return TensorProxy(like=copy_to)


copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,))
copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.IN_PLACE,))


def bitcast_meta(
Expand Down
4 changes: 3 additions & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def keep_or_swap(p):
# that only produce non-proxy objects
# NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return
# all the needed proxies of the input trace
def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
def dce(trace: Trace, needed_proxies: None | set[Variable] = None, keep_inplace_ops: bool = False) -> Trace:
start_time_ns = time.perf_counter_ns()

producer_map: ProxyDict = producers(trace)
Expand All @@ -159,6 +159,8 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# Preserves symbols that should never be collected
if has_tags(bsym, {prims.OpTags.DONT_DCE}):
needed = True
elif keep_inplace_ops and has_tags(bsym, {prims.OpTags.IN_PLACE}):
needed = True
else:
needed = False

Expand Down
13 changes: 8 additions & 5 deletions thunder/core/update_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _get_new_aliases(aliases, trace):


def _is_inplace_op(bsym):
# TODO: Handle higher order bsyms containing inplace ops
return (bsym.sym.tags and prims.OpTags.IN_PLACE in bsym.sym.tags) or (
bsym.subsymbols and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_
)
Expand Down Expand Up @@ -67,7 +68,7 @@ def _can_be_reshaped(arg, arg_to_replace):

def replace_args_with_alias_map(
computation_trace: Trace,
alias_tensor_indices: list[list[int]],
alias_tensor_indices: list[list[int]] | None = None,
) -> tuple[Trace, list[set[VariableInterface]]]:
if not alias_tensor_indices:
return computation_trace, []
Expand Down Expand Up @@ -140,7 +141,7 @@ def _helper(alias):
return list(map(_helper, aliases))


def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]] | None = None) -> Trace:
if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
return computation_trace

Expand All @@ -157,7 +158,8 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
for bsym in computation_trace.bound_symbols:
if _is_inplace_op(bsym) or _is_view_creation_op(bsym):
# only interested in the input which is modified by the inplace op
in_tensor = variableify(bsym.flat_proxy_args[0])
mutated_or_aliased_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0
in_tensor = variableify(bsym.flat_proxy_args[mutated_or_aliased_index])
out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
if _is_inplace_op(bsym):
inplace_inputs.add(in_tensor)
Expand All @@ -180,10 +182,11 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if (
_is_inplace_op(bsym)
or _is_view_creation_op(bsym)
or (bsym.sym.id != prims.PrimIDs.RETURN and _involves_viewed_args(set(unswapped_in_tensors), viewed))
or _involves_viewed_args(set(unswapped_in_tensors), viewed)
):
if _is_inplace_op(bsym) and in_tensors:
in_tensors = {in_tensors[0]}
mutated_index = 1 if bsym.sym.id == prims.PrimIDs.COPY_ else 0
in_tensors = {in_tensors[mutated_index]}
unswapped_in_tensors = {unswapped_in_tensors[0]}
else:
in_tensors = set(in_tensors)
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def __init__(self):
super().__init__("nvfuser", version=nvfuser.version())

# TODO: Replace this with a query to a compile option
self._use_rematerialization = True
self._use_rematerialization = False

fuel_str = os.getenv("NVFUSER_OPTIMIZATION_FUEL")
if fuel_str:
Expand Down
9 changes: 8 additions & 1 deletion thunder/executors/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from thunder.core.trace import from_trace, TraceProvenance
from thunder.core.trace_interpreter import TraceSubstitutionProcessor
from thunder.core.transform_common import dce
from thunder.core.update_aliases import insert_alias_updates
from thunder.core.utils import ProxyDict
from thunder.executors.pythonex import clear_mutable_collection
from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor
Expand Down Expand Up @@ -104,7 +105,9 @@ def process_bsym(self, bsym: BoundSymbol) -> None:
return extrace


def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
def transform_for_execution(
trace: TraceCtx, executors_list: Sequence[Executor], alias_tensor_indices: list[list[int]] | None = None
) -> TraceCtx:
import torch

start_time_ns = time.perf_counter_ns()
Expand All @@ -122,7 +125,11 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor])
# Step 1 Performs execution transforms
#
extrace = _transform_for_operator_executor_execution(trace, executors_list)
# Insert alias updates before DCE for bsyms exposed by decomposition
# Inserted prims.update_aliases will be handled in Step 3
extrace = insert_alias_updates(extrace, alias_tensor_indices)
extrace = dce(extrace)

#
# Step 2 Fusion executors can transform the trace
#
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2371,7 +2371,7 @@ def _copy__impl(copy_from, copy_to, grad_enabled):


copy_ = ex.register_operator(
"copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor
"copy_", meta=prims.copy_, tags=(prims.OpTags.IN_PLACE,), fn=_copy__impl, module=torch.Tensor
)
_register_implementation(prims.copy_, copy_, checker=_always_executable)

Expand Down
18 changes: 9 additions & 9 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def test_prim_inplace_copy_bwd(executor, device, dtype):
def torch_foo(x, y):
z = x * y
z = z * x
x.copy_(z)
o = x.copy_(z)
p = y * y
return p
return p, o

def foo(x, y):
z = x * y
z = z * x
thunder.core.prims.copy_(z, x, grad_enabled=True)
o = thunder.core.prims.copy_(z, x, grad_enabled=True)
p = y * y
return p
return p, o

traced_nvfuser_foo = executor.make_callable(foo)

Expand All @@ -72,11 +72,11 @@ def foo(x, y):
)
custom_comparator(a, a1)

g = torch.ones_like(thunder_result)
thunder_result.backward(g)
g = torch.ones_like(thunder_result[0])
thunder_result[0].backward(g)

g1 = torch.ones_like(torch_result)
torch_result.backward(g1)
g1 = torch.ones_like(torch_result[0])
torch_result[0].backward(g1)
assert_close(g, g1)
assert_close(b.grad, b1.grad)

Expand Down Expand Up @@ -131,7 +131,7 @@ def func2(x, y):
return y, o1, o2

for foo in (func1, func2):
traced_foo = executor.make_callable(foo)
traced_foo = executor.make_callable(foo, skip_inplace_alias_updates=True)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype)
Expand Down
Loading
Loading