34 changes: 24 additions & 10 deletions thunder/core/update_aliases.py
@@ -144,9 +144,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
return computation_trace

swap_map = dict()
bsyms = []

# First pass: identify inputs which are views of each other and swap them out with a default,
# reshaping if necessary.
computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
@@ -173,10 +170,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
viewed = set(reduce(set.union, view_groups, set()))

swap_map = dict()
swap_map_by_update_aliases = dict()
bsyms = []

# Third pass: insert alias updates
for bsym in computation_trace.bound_symbols:
bsym = bsym.from_bsym_swap_proxies(swap_map)
in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
unswapped_in_tensors = _unswap(swap_map, in_tensors)
# We do not unswap out_tensor of an inplace bsym into in_tensor, because functional dependency is already
# captured by that reference to out_tensor
unswapped_in_tensors = _unswap(swap_map_by_update_aliases, in_tensors)
if (
_is_inplace_op(bsym)
or _is_view_creation_op(bsym)
@@ -189,10 +193,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
in_tensors = set(in_tensors)
out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
encountered.update(in_tensors)
group = set().union(*filter(lambda g: g.intersection(unswapped_in_tensors), view_groups))
if not group or not (views_encountered := group.intersection(encountered)):
# If group is empty, this is a view creation with operands that are not involved in any inplace ops.
bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
Collaborator:
qq: wouldn't this call g.intersect len(view_groups) times?

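As an aside on the question above: the previous `set().union(*filter(...))` form (removed a few lines up) also evaluated `g.intersection(unswapped_in_tensors)` once per group, so the number of intersection calls is unchanged; the list comprehension just keeps the matching groups around for reuse. A minimal, self-contained sketch with made-up group contents (plain strings standing in for tensor proxies):

```python
# Illustrative only; names mirror the diff above, not the real thunder data structures.
view_groups = [{"a", "b"}, {"c"}, {"d", "e"}]
unswapped_in_tensors = {"b", "d"}

# Previous form: filter calls g.intersection once per group.
old = set().union(*filter(lambda g: g.intersection(unswapped_in_tensors), view_groups))

# New form: the list comprehension also calls g.intersection once per group,
# and additionally keeps the matching groups for later use.
involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
new = set().union(*involved_view_groups)

assert old == new == {"a", "b", "d", "e"}
```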
involved_views = set().union(*involved_view_groups)
views_encountered = tuple(involved_views.intersection(encountered))

if _is_inplace_op(bsym):
# This is a hack to insert fusion break because nvFuser doesn't support mutation on intermediates
views_encountered = tuple(unswapped_in_tensors.union(views_encountered))

if not views_encountered:
# This is a view creation with operands that are not involved in any inplace ops.
bsyms.append(bsym)
continue

new_aliases = _get_new_aliases(views_encountered, computation_trace)
@@ -202,14 +213,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if has_tags(bsym, {BoundSymbolTag.BACKWARD}):
update_bsym.tags.add(BoundSymbolTag.BACKWARD)
bsyms.append(update_bsym)
encountered.update(out_tensors)
encountered.update(out_tensors, map(variableify, new_aliases))
bsyms.append(new_bsym)
if _is_inplace_op(bsym) and len(out_tensors) == 1 and len(in_tensors) == 1:
# This relies on these being one element sets (ltorch.setitem_ yields no outs).
swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop()))

for alias, new_alias in zip(views_encountered, new_aliases):
_update_swap_map(swap_map_by_update_aliases, alias, new_alias)

else:
bsyms.append(bsym.from_bsym_swap_proxies(swap_map))
bsyms.append(bsym)

alias_updated_trace = from_trace(computation_trace)
alias_updated_trace.set_provenance(TraceProvenance("Update aliases for in-place ops"))
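For context (not part of this diff): a minimal user-level sketch of the pattern the alias-update pass targets — an in-place write through a view of a function input — loosely adapted from the tests added further down; shapes and values are illustrative.

```python
import torch
import thunder

def fn(x):
    y = x.view(-1)  # y aliases x
    y.add_(1)       # the in-place write through the view must also be visible through x
    return x.sum()

jfn = thunder.jit(fn)
out = jfn(torch.ones(2, 3))
ref = fn(torch.ones(2, 3))
# The "Update aliases for in-place ops" trace pass inserts update_aliases bound symbols
# so the functionalized trace preserves this aliasing behavior.
torch.testing.assert_close(out, ref)
```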
11 changes: 8 additions & 3 deletions thunder/executors/nvfuserex_impl.py
@@ -817,10 +817,15 @@ def map_redundant(x: Any) -> Any:
new_symbols = [new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols]
cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols))

return_bsym = cse_trace.bound_symbols[-1]
assert return_bsym.sym.id == prims.PrimIDs.RETURN
return_bsym = None
for idx, bsym in enumerate(cse_trace.bound_symbols):
if bsym.sym.id == prims.PrimIDs.RETURN:
return_bsym = cse_trace.bound_symbols.pop(idx)
break
assert return_bsym is not None
Collaborator commented on lines 821 to 828:
I looked at the code changes first before looking at the discussion and this was very alarming to me. Could you add a TODO comment about this being removed?

Collaborator Author replied:
Sure, I appreciate that kind of feedback. Yes, it's a rough solution indeed...


trace_output = tree_map(map_redundant, return_bsym.args)
cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None)
cse_trace.bound_symbols.append(prims.python_return.bind(*trace_output, output=None))

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
58 changes: 54 additions & 4 deletions thunder/tests/test_update_aliases.py
@@ -19,7 +19,6 @@
NOTHING,
TorchExecutor,
TorchCompileExecutor,
nvFuserExecutor,
requiresCUDA,
)
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -358,9 +357,6 @@ def f(x, y, z):
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
)
def test_write_to_intermediate_result(executor, device, dtype, cache):
if executor == nvFuserExecutor:
pytest.xfail("nvFuser does not support writing to intermediate results")

def fn(x):
y = x.view(-1)
y.add_(1)
@@ -373,6 +369,24 @@ def fn(x):
torch.testing.assert_close(actual, expected)


@instantiate(
dtypes=NOTHING,
decorators=(pytest.mark.parametrize("requires_grad", (False, True)),),
)
def test_write_to_viewed_intermediate(executor, device, dtype, requires_grad):
def fn(a):
b = a * 2
c = b[:]
c.tanh_()
return a * b

a = make_tensor((2, 3), dtype=torch.float32, device=device, requires_grad=requires_grad)
jfn = executor.make_callable(fn, fusion_type="dataflow")
actual = jfn(a)
expected = fn(a)
torch.testing.assert_close(actual, expected)


@instantiate(
dtypes=(dtypes.float32,),
)
@@ -546,3 +560,39 @@ def f(x, y, z):
torch.testing.assert_close(a, a_)
torch.testing.assert_close(b, b_)
torch.testing.assert_close(c, c_)


@instantiate(
dtypes=(dtypes.float32,),
)
def test_update_aliases_count(executor, device, dtype):
def f(x):
x.sin_()
return x * x * x * x

def g(x):
x.sin_()
x.cos_()
return x * x * x * x

def h(x):
y = x[:]
y.sin_()
return x * x * x * x

expected_num_update_aliases = {
f: 1, # before sin_
g: 2, # before sin_ and cos_; latter is a hack to cause fusion break
h: 5, # before sin_ and every mul
}

for fn in [f, g]:
a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
a_ = a.clone().detach()
jfn = executor.make_callable(fn)
actual = jfn(a)
expected = fn(a_)
torch.testing.assert_close(actual, expected)
extrace = thunder.last_traces(jfn)[-1]
actual_num_update_aliases = len([bsym for bsym in extrace.bound_symbols if bsym.sym.name == "update_aliases"])
assert actual_num_update_aliases == expected_num_update_aliases[fn]
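As a possible follow-up check (not part of this PR), and assuming nvFuser fusion regions appear in the execution trace as bound symbols whose names start with `nvFusion`, the fusion break that the `g` case is meant to force could be counted directly:

```python
# Sketch only: reuses extrace from the loop above and assumes fused regions show up as
# bound symbols named "nvFusion0", "nvFusion1", ... in the execution trace.
num_fusions = len([bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion")])
# For g under the nvFuser executor, more than one fusion region would confirm that the
# update_aliases inserted before cos_ actually broke the fusion.
```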