3434 "zeros" : lambda : (torch .zeros (10 ),),
3535}
3636
37+ sub_test_data_fp16 = {
38+ "rand_2D_fp16" : lambda : (torch .rand (4 , 4 , dtype = torch .float16 ),),
39+ }
40+
41+ sub_test_data_bf16 = {
42+ "rand_2D_bf16" : lambda : (torch .rand (4 , 4 , dtype = torch .bfloat16 ),),
43+ }
44+
3745# Two-input subtraction (x - y)
3846sub2_test_data = {
3947 "rand_2D_4x4" : lambda : (torch .rand (4 , 4 ), torch .rand (4 , 4 )),
5260 "rand_3d_Scalar" : lambda : (torch .rand (1 , 6 , 2 ), 1 ),
5361}
5462
55- sub_test_data_bf16 = {
56- "rand_2D_bf16" : lambda : (torch .rand (4 , 4 , dtype = torch .bfloat16 ),),
63+ sub2_test_data_fp16 = {
64+ "rand_2D_pair_fp16" : lambda : (
65+ torch .rand (2 , 3 , dtype = torch .float16 ),
66+ torch .rand (2 , 3 , dtype = torch .float16 ),
67+ ),
5768}
5869
5970sub2_test_data_bf16 = {
@@ -101,7 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
101112input_t2 = Tuple [torch .Tensor , torch .Tensor ] # Input x, y
102113
103114
104- @common .parametrize ("test_data" , sub_test_data | sub_test_data_bf16 )
115+ @common .parametrize (
116+ "test_data" , sub_test_data | sub_test_data_fp16 | sub_test_data_bf16
117+ )
105118def test_sub_tensor_tosa_FP (test_data ):
106119 """Test Subtraction (TOSA FP)"""
107120 pipeline = TosaPipelineFP [input_t1 ](
@@ -114,7 +127,9 @@ def test_sub_tensor_tosa_FP(test_data):
114127 pipeline .run ()
115128
116129
117- @common .parametrize ("test_data" , sub2_test_data | sub2_test_data_bf16 )
130+ @common .parametrize (
131+ "test_data" , sub2_test_data | sub2_test_data_fp16 | sub2_test_data_bf16
132+ )
118133def test_sub_tensor_tosa_FP_2 (test_data : Tuple [torch .Tensor , torch .Tensor ]):
119134 """Test Two-Operand Subtraction (TOSA FP)"""
120135 pipeline = TosaPipelineFP [input_t2 ](
@@ -123,7 +138,9 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
123138 pipeline .run ()
124139
125140
126- @common .parametrize ("test_data" , sub_tan_test_data | sub2_test_data_bf16 )
141+ @common .parametrize (
142+ "test_data" , sub_tan_test_data | sub2_test_data_fp16 | sub2_test_data_bf16
143+ )
127144def test_sub_tensor_tosa_FP_alpha (test_data : Tuple [torch .Tensor , torch .Tensor ]):
128145 """Test Two-Operand Subtraction with alpha (TOSA FP)"""
129146 pipeline = TosaPipelineFP [input_t2 ](
@@ -217,7 +234,7 @@ def test_sub_tensor_u85_INT(test_data: Tuple[torch.Tensor, torch.Tensor]):
217234 pipeline .run ()
218235
219236
220- @common .parametrize ("test_data" , sub_test_data )
237+ @common .parametrize ("test_data" , sub_test_data | sub_test_data_fp16 )
221238@common .SkipIfNoModelConverter
222239def test_sub_tensor_vgf_no_quant (test_data : Tuple [torch .Tensor ]):
223240 """Test Subtraction (VGF FP)"""
@@ -231,7 +248,7 @@ def test_sub_tensor_vgf_no_quant(test_data: Tuple[torch.Tensor]):
231248 pipeline .run ()
232249
233250
234- @common .parametrize ("test_data" , sub2_test_data )
251+ @common .parametrize ("test_data" , sub2_test_data | sub2_test_data_fp16 )
235252@common .SkipIfNoModelConverter
236253def test_sub_tensor_vgf_no_quant_2 (test_data : Tuple [torch .Tensor , torch .Tensor ]):
237254 """Test Two-Operand Subtraction (VGF FP)"""
0 commit comments