 from thunder.tests.make_tensor import make_tensor, make_tensor_like
 from thunder.tests.framework import (
     instantiate,
+    nvFuserExecutor,
     ops,
     NOTHING,
     TorchExecutor,
     TorchCompileExecutor,
-    nvFuserExecutor,
     requiresCUDA,
 )
 from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -166,14 +166,15 @@ def g(x, _):
 
 @instantiate(
     dtypes=NOTHING,
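+    # Run each case with both a compute-style (mul_) and an overwrite-style (copy_) in-place op.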
+    decorators=(pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),),
 )
-def test_inplace_on_view(executor, device, dtype):
+def test_inplace_on_view(executor, device, dtype, inplace_op):
     def h(x, y):
         c = torch.exp(x)
         d = torch.tanh(y)
         e = c.view(-1)
 
-        d.div_(x)
+        inplace_op(d, x)
         e += d.flatten()
 
         return c, d, e
@@ -184,14 +185,14 @@ def i(x, y):
         e = c.view(-1)
 
         e += d.flatten()
-        d.div_(x)
+        inplace_op(d, x)
 
         return c, d, e
 
     def j(x, _):
         a = x.view(-1)
         b = x.view(-1)
-        x.add_(1)
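+        # a and b are live views of x, so this in-place write must become visible through both.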
+        inplace_op(x, x.cos())
         aa = a + 1
         bb = b + 1
         return aa, bb
@@ -260,15 +261,18 @@ def h(x):
 )
 def test_chained_inplace(executor, device, dtype):
     def f(x):
-        x.add_(1).sin_().mul_(5)
-        return x
+        y = x.add_(1)
+        z = y.sin_()
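+        # In-place torch ops return self, so y, z, and the result below all alias x.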
+        w = z.mul_(y.copy_(z.cos()))
+        return w
 
     def g(x):
         x.add_(1).sin().mul_(5)
         return x
 
     def h(x):
         x.exp_()
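+        # Overwrite x in between the two surrounding in-place ops.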
+        x.copy_(x.tan())
         x.sin_()
         y = x.cos()
         return y
@@ -332,14 +336,16 @@ def g(a, b):
 )
 def test_aliased_input(executor, device, dtype, cache):
     def f(x, y, z):
-        return y.exp_().add(x) + z.exp()
+        s = y.exp_().add(x) + z.exp()
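+        # At the call site below, x and z view the same storage, so this line reads and writes overlapping memory.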
+        t = x.copy_(z.exp_().view(x.shape)) + z.cos().reshape(x.shape)
+        return s, t
 
     a = make_tensor((2, 1, 2), dtype=torch.float32, device=device)
     b = a.clone()
     c = a.view(1, 2, 2)
     a_ = a.clone().detach()
     b_ = b.clone().detach()
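+    # c_ must alias a_ exactly as c aliases a; a detached clone of c would drop the aliasing from the reference.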
-    c_ = c.clone().detach()
+    c_ = a_.view(1, 2, 2)
     jfn = executor.make_callable(f, cache=cache)
     actual = jfn(a, b, c)
     expected = f(a_, b_, c_)
@@ -351,22 +357,25 @@ def f(x, y, z):
 
 @instantiate(
     dtypes=NOTHING,
-    decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
+    decorators=(
+        pytest.mark.parametrize("cache", ("constant values", "symbolic values")),
+        pytest.mark.parametrize("inplace_op", [torch.Tensor.mul_, torch.Tensor.copy_]),
+    ),
 )
-def test_write_to_intermediate_result(executor, device, dtype, cache):
-    if executor == nvFuserExecutor:
-        pytest.xfail("nvFuser does not support writing to intermediate results")
-
-    def fn(x):
+def test_write_to_intermediate_result(executor, device, dtype, cache, inplace_op):
+    def fn(x, z):
         y = x.view(-1)
-        y.add_(1)
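+        # y is an intermediate (a view of the input x), so the write must also land in x.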
+        inplace_op(y, z)
         return y
 
-    a = make_tensor((2, 3), dtype=torch.float32, device=device)
+    x = make_tensor((2, 3), dtype=torch.float32, device=device)
+    x_ref = x.clone().detach()
+    z = make_tensor(6, dtype=torch.float32, device=device)
     jfn = executor.make_callable(fn, cache=cache)
-    actual = jfn(a)
-    expected = fn(a)
+    actual = jfn(x, z)
+    expected = fn(x_ref, z)
     torch.testing.assert_close(actual, expected)
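+    # The input itself must reflect the write made through its view y.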
+    torch.testing.assert_close(x, x_ref)
 
 
 @instantiate(
@@ -469,6 +478,58 @@ def f(a):
     torch.testing.assert_close(out, out_expected)
 
 
+@instantiate(
+    dtypes=(dtypes.float32,),
+)
+def test_batch_norm_update_aliases(executor, device, dtype):
+    if executor is nvFuserExecutor:
+        pytest.xfail("update_aliases is not aware of mutation by batch_norm")
+
+    torch_dtype = dtypes.to_torch_dtype(dtype)
+    num_features = 4
+
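+    # batch_norm with training=True updates running_mean and running_var in place.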
+    def f(x, running_mean, running_var, weight, bias):
+        out = torch.nn.functional.batch_norm(
+            x,
+            running_mean,
+            running_var,
+            weight,
+            bias,
+            training=True,
+            momentum=0.1,
+            eps=1e-5,
+        )
+        return out, x, running_mean.sin(), running_var.cos()
+
+    input_tensor = make_tensor((3, num_features, 5, 5), device=device, dtype=torch_dtype)
+    running_mean = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    running_var = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    weight = make_tensor((num_features,), device=device, dtype=torch_dtype)
+    bias = make_tensor((num_features,), device=device, dtype=torch_dtype)
+
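+    # Detached clones give an eager-mode reference run that mutates its own copies.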
+    input_ref = input_tensor.clone().detach()
+    running_mean_ref = running_mean.clone().detach()
+    running_var_ref = running_var.clone().detach()
+    weight_ref = weight.clone().detach()
+    bias_ref = bias.clone().detach()
+
+    jitted_f = executor.make_callable(f)
+    out_jitted, x_jitted, running_mean_jitted, running_var_jitted = jitted_f(
+        input_tensor, running_mean, running_var, weight, bias
+    )
+    out_ref, x_ref, running_mean_ref_out, running_var_ref_out = f(
+        input_ref, running_mean_ref, running_var_ref, weight_ref, bias_ref
+    )
+
+    torch.testing.assert_close(out_jitted, out_ref)
+    torch.testing.assert_close(x_jitted, x_ref)
+    torch.testing.assert_close(running_mean_jitted, running_mean_ref_out)
+    torch.testing.assert_close(running_var_jitted, running_var_ref_out)
+    torch.testing.assert_close(input_tensor, input_ref)
+    torch.testing.assert_close(running_mean, running_mean_ref)
+    torch.testing.assert_close(running_var, running_var_ref)
+
+
 @instantiate(
     dtypes=(dtypes.float32,),
 )
@@ -491,7 +552,7 @@ def backward(ctx, g):
             return y
 
     def foo(x):
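+        # Assumption: multiplying by x exercises an output that merely uses, rather than is, the custom Function's result.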
-        return Sin.apply(x)
+        return Sin.apply(x) * x
 
     a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
     b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)