Skip to content

Commit baa6ede

Browse files
committed
Add tests
1 parent 5211930 commit baa6ede

File tree

2 files changed: +108 −21 lines changed

thunder/tests/test_inplace_copy.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def func2(x, y):
131131
return y, o1, o2
132132

133133
for foo in (func1, func2):
134-
traced_foo = executor.make_callable(foo)
134+
traced_foo = executor.make_callable(foo, skip_inplace_alias_updates=True)
135135

136136
tdtype = ttorch.to_torch_dtype(dtype)
137137
a = make_tensor((4, 4), device=device, dtype=tdtype)
@@ -192,3 +192,29 @@ def fn(x, y):
192192
a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False)
193193
b = make_tensor((4, 4), device=device, dtype=torch.float32, requires_grad=False)
194194
jitted_fn(a, b)
195+
196+
197+
@instantiate(dtypes=(thunder.float32,))
def test_inplace_on_intermediate(executor, device, dtype):
    """In-place ops on tensors created inside the traced function must not
    leak back into the function's input.

    Both callables allocate two fresh tensors, mutate them in place
    (``copy_`` then ``sin_``), and differ only in whether the mutated
    intermediates are returned; in either case the input ``x`` must be
    left untouched by the jitted execution.
    """

    def f(x):
        u = torch.randn_like(x)
        v = torch.randn_like(x)
        u.copy_(v)
        v.sin_()
        return (x,)

    def g(x):
        u = torch.randn_like(x)
        v = torch.randn_like(x)
        u.copy_(v)
        v.sin_()
        return x, u, v

    tdtype = ttorch.to_torch_dtype(dtype)
    for fn in (f, g):
        jitted_fn = executor.make_callable(fn)
        a = make_tensor((4, 4), device=device, dtype=tdtype)
        # Snapshot the input so we can detect accidental mutation.
        a_ref = a.clone()
        a_out, *_ = jitted_fn(a)
        # The input must be unchanged, and the returned ``x`` must equal it.
        assert_close(a, a_ref)
        assert_close(a_out, a_ref)

thunder/tests/test_update_aliases.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from thunder.tests.make_tensor import make_tensor, make_tensor_like
1616
from thunder.tests.framework import (
1717
instantiate,
18+
nvFuserExecutor,
1819
ops,
1920
NOTHING,
2021
TorchExecutor,
2122
TorchCompileExecutor,
22-
nvFuserExecutor,
2323
requiresCUDA,
2424
)
2525
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -166,14 +166,15 @@ def g(x, _):
166166

167167
@instantiate(
168168
dtypes=NOTHING,
169+
decorators=(pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),),
169170
)
170-
def test_inplace_on_view(executor, device, dtype):
171+
def test_inplace_on_view(executor, device, dtype, inplace_op):
171172
def h(x, y):
172173
c = torch.exp(x)
173174
d = torch.tanh(y)
174175
e = c.view(-1)
175176

176-
d.div_(x)
177+
inplace_op(d, x)
177178
e += d.flatten()
178179

179180
return c, d, e
@@ -184,14 +185,14 @@ def i(x, y):
184185
e = c.view(-1)
185186

186187
e += d.flatten()
187-
d.div_(x)
188+
inplace_op(d, x)
188189

189190
return c, d, e
190191

191192
def j(x, _):
192193
a = x.view(-1)
193194
b = x.view(-1)
194-
x.add_(1)
195+
inplace_op(x, x.cos())
195196
aa = a + 1
196197
bb = b + 1
197198
return aa, bb
@@ -260,15 +261,18 @@ def h(x):
260261
)
261262
def test_chained_inplace(executor, device, dtype):
262263
def f(x):
263-
x.add_(1).sin_().mul_(5)
264-
return x
264+
y = x.add_(1)
265+
z = y.sin_()
266+
w = z.mul_(y.copy_(z.cos()))
267+
return w
265268

266269
def g(x):
267270
x.add_(1).sin().mul_(5)
268271
return x
269272

270273
def h(x):
271274
x.exp_()
275+
x.copy_(x.tan())
272276
x.sin_()
273277
y = x.cos()
274278
return y
@@ -332,14 +336,16 @@ def g(a, b):
332336
)
333337
def test_aliased_input(executor, device, dtype, cache):
334338
def f(x, y, z):
335-
return y.exp_().add(x) + z.exp()
339+
s = y.exp_().add(x) + z.exp()
340+
t = x.copy_(z.exp_().view(x.shape)) + z.cos().reshape(x.shape)
341+
return s, t
336342

337343
a = make_tensor((2, 1, 2), dtype=torch.float32, device=device)
338344
b = a.clone()
339345
c = a.view(1, 2, 2)
340346
a_ = a.clone().detach()
341347
b_ = b.clone().detach()
342-
c_ = c.clone().detach()
348+
c_ = a_.view(1, 2, 2)
343349
jfn = executor.make_callable(f, cache=cache)
344350
actual = jfn(a, b, c)
345351
expected = f(a_, b_, c_)
@@ -351,22 +357,25 @@ def f(x, y, z):
351357

352358
@instantiate(
353359
dtypes=NOTHING,
354-
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
360+
decorators=(
361+
pytest.mark.parametrize("cache", ("constant values", "symbolic values")),
362+
pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),
363+
),
355364
)
356-
def test_write_to_intermediate_result(executor, device, dtype, cache):
357-
if executor == nvFuserExecutor:
358-
pytest.xfail("nvFuser does not support writing to intermediate results")
359-
360-
def fn(x):
365+
def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op):
366+
def fn(x, z):
361367
y = x.view(-1)
362-
y.add_(1)
368+
inplace_op(y, z)
363369
return y
364370

365-
a = make_tensor((2, 3), dtype=torch.float32, device=device)
371+
x = make_tensor((2, 3), dtype=torch.float32, device=device)
372+
x_ref = x.clone().detach()
373+
z = make_tensor(6, dtype=torch.float32, device=device)
366374
jfn = executor.make_callable(fn, cache=cache)
367-
actual = jfn(a)
368-
expected = fn(a)
375+
actual = jfn(x, z)
376+
expected = fn(x_ref, z)
369377
torch.testing.assert_close(actual, expected)
378+
torch.testing.assert_close(x, x_ref)
370379

371380

372381
@instantiate(
@@ -469,6 +478,58 @@ def f(a):
469478
torch.testing.assert_close(out, out_expected)
470479

471480

481+
@instantiate(
    dtypes=(dtypes.float32,),
)
def test_batch_norm_update_aliases(executor, device, dtype):
    """``batch_norm`` in training mode mutates its running statistics in
    place; check that jitted execution mirrors eager semantics for the
    output, the (unmutated) input, and the mutated running stats.
    """
    if executor is nvFuserExecutor:
        pytest.xfail("update_aliases is not aware of mutation by batch_norm")

    torch_dtype = dtypes.to_torch_dtype(dtype)
    num_features = 4

    def f(x, running_mean, running_var, weight, bias):
        out = torch.nn.functional.batch_norm(
            x,
            running_mean,
            running_var,
            weight,
            bias,
            training=True,
            momentum=0.1,
            eps=1e-5,
        )
        # Returning transforms of the running stats forces the trace to
        # observe their post-mutation values.
        return out, x, running_mean.sin(), running_var.cos()

    input_tensor = make_tensor((3, num_features, 5, 5), device=device, dtype=torch_dtype)
    running_mean = make_tensor((num_features,), device=device, dtype=torch_dtype)
    running_var = make_tensor((num_features,), device=device, dtype=torch_dtype)
    weight = make_tensor((num_features,), device=device, dtype=torch_dtype)
    bias = make_tensor((num_features,), device=device, dtype=torch_dtype)

    # Independent eager-mode copies: both paths mutate their own stats.
    input_ref = input_tensor.clone().detach()
    running_mean_ref = running_mean.clone().detach()
    running_var_ref = running_var.clone().detach()
    weight_ref = weight.clone().detach()
    bias_ref = bias.clone().detach()

    jitted_f = executor.make_callable(f)
    out_jitted, x_jitted, running_mean_jitted, running_var_jitted = jitted_f(
        input_tensor, running_mean, running_var, weight, bias
    )
    out_ref, x_ref, running_mean_ref_out, running_var_ref_out = f(
        input_ref, running_mean_ref, running_var_ref, weight_ref, bias_ref
    )

    # Returned values agree between jitted and eager execution.
    results = (
        (out_jitted, out_ref),
        (x_jitted, x_ref),
        (running_mean_jitted, running_mean_ref_out),
        (running_var_jitted, running_var_ref_out),
    )
    for actual, expected in results:
        torch.testing.assert_close(actual, expected)

    # Side effects on the arguments agree too: the input is untouched and
    # the running stats carry the same in-place updates as eager mode.
    torch.testing.assert_close(input_tensor, input_ref)
    torch.testing.assert_close(running_mean, running_mean_ref)
    torch.testing.assert_close(running_var, running_var_ref)

472533
@instantiate(
473534
dtypes=(dtypes.float32,),
474535
)
@@ -491,7 +552,7 @@ def backward(ctx, g):
491552
return y
492553

493554
def foo(x):
494-
return Sin.apply(x)
555+
return Sin.apply(x) * x
495556

496557
a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
497558
b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)

0 commit comments

Comments (0)