Skip to content

Commit 05da1af

Browse files
authored
Arm backend: Add FP16 support to operators pt.7 (#17330)
Add FP16 support for operators: sin, sub, tanh, where. Update op tests to cover the new datatype. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 647d771 commit 05da1af

File tree

7 files changed

+51
-15
lines changed

7 files changed

+51
-15
lines changed

backends/arm/operators/op_sin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def define_node(
4242
validate_valid_dtype(
4343
self.target,
4444
[*inputs, output],
45-
[ts.DType.FP32, ts.DType.BF16],
45+
[ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4646
self.tosa_spec,
4747
)
4848
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_sub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def define_node(
3737
validate_valid_dtype(
3838
self.target,
3939
[*inputs, output],
40-
[ts.DType.INT32, ts.DType.FP32, ts.DType.BF16],
40+
[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4141
self.tosa_spec,
4242
)
4343

backends/arm/operators/op_tanh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[*inputs, output],
47-
[ts.DType.FP32, ts.DType.BF16],
47+
[ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
4848
self.tosa_spec,
4949
)
5050
attr = ts.TosaSerializerAttribute()

backends/arm/test/ops/test_sin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
"ramp": torch.arange(-16, 16, 0.2),
3131
}
3232

33+
test_data_suite_fp16 = {
34+
"rand_fp16": torch.rand(10, 10, dtype=torch.float16),
35+
}
36+
3337
test_data_suite_bf16 = {
3438
"rand_bf16": torch.rand(3, 3, dtype=torch.bfloat16),
3539
}
@@ -41,7 +45,9 @@ def forward(self, x: torch.Tensor):
4145
return torch.sin(x)
4246

4347

44-
@common.parametrize("test_data", test_data_suite | test_data_suite_bf16)
48+
@common.parametrize(
49+
"test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16
50+
)
4551
def test_sin_tosa_FP(test_data: Tuple):
4652
pipeline = TosaPipelineFP[input_t1](
4753
Sin(),
@@ -88,7 +94,7 @@ def test_sin_u85_INT(test_data: Tuple):
8894
pipeline.run()
8995

9096

91-
@common.parametrize("test_data", test_data_suite)
97+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
9298
@common.SkipIfNoModelConverter
9399
def test_sin_vgf_no_quant(test_data: Tuple):
94100
pipeline = VgfPipeline[input_t1](

backends/arm/test/ops/test_sub.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
"zeros": lambda: (torch.zeros(10),),
3535
}
3636

37+
sub_test_data_fp16 = {
38+
"rand_2D_fp16": lambda: (torch.rand(4, 4, dtype=torch.float16),),
39+
}
40+
41+
sub_test_data_bf16 = {
42+
"rand_2D_bf16": lambda: (torch.rand(4, 4, dtype=torch.bfloat16),),
43+
}
44+
3745
# Two-input subtraction (x - y)
3846
sub2_test_data = {
3947
"rand_2D_4x4": lambda: (torch.rand(4, 4), torch.rand(4, 4)),
@@ -52,8 +60,11 @@
5260
"rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
5361
}
5462

55-
sub_test_data_bf16 = {
56-
"rand_2D_bf16": lambda: (torch.rand(4, 4, dtype=torch.bfloat16),),
63+
sub2_test_data_fp16 = {
64+
"rand_2D_pair_fp16": lambda: (
65+
torch.rand(2, 3, dtype=torch.float16),
66+
torch.rand(2, 3, dtype=torch.float16),
67+
),
5768
}
5869

5970
sub2_test_data_bf16 = {
@@ -101,7 +112,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
101112
input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y
102113

103114

104-
@common.parametrize("test_data", sub_test_data | sub_test_data_bf16)
115+
@common.parametrize(
116+
"test_data", sub_test_data | sub_test_data_fp16 | sub_test_data_bf16
117+
)
105118
def test_sub_tensor_tosa_FP(test_data):
106119
"""Test Subtraction (TOSA FP)"""
107120
pipeline = TosaPipelineFP[input_t1](
@@ -114,7 +127,9 @@ def test_sub_tensor_tosa_FP(test_data):
114127
pipeline.run()
115128

116129

117-
@common.parametrize("test_data", sub2_test_data | sub2_test_data_bf16)
130+
@common.parametrize(
131+
"test_data", sub2_test_data | sub2_test_data_fp16 | sub2_test_data_bf16
132+
)
118133
def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
119134
"""Test Two-Operand Subtraction (TOSA FP)"""
120135
pipeline = TosaPipelineFP[input_t2](
@@ -123,7 +138,9 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
123138
pipeline.run()
124139

125140

126-
@common.parametrize("test_data", sub_tan_test_data | sub2_test_data_bf16)
141+
@common.parametrize(
142+
"test_data", sub_tan_test_data | sub2_test_data_fp16 | sub2_test_data_bf16
143+
)
127144
def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
128145
"""Test Two-Operand Subtraction with alpha (TOSA FP)"""
129146
pipeline = TosaPipelineFP[input_t2](
@@ -217,7 +234,7 @@ def test_sub_tensor_u85_INT(test_data: Tuple[torch.Tensor, torch.Tensor]):
217234
pipeline.run()
218235

219236

220-
@common.parametrize("test_data", sub_test_data)
237+
@common.parametrize("test_data", sub_test_data | sub_test_data_fp16)
221238
@common.SkipIfNoModelConverter
222239
def test_sub_tensor_vgf_no_quant(test_data: Tuple[torch.Tensor]):
223240
"""Test Subtraction (VGF FP)"""
@@ -231,7 +248,7 @@ def test_sub_tensor_vgf_no_quant(test_data: Tuple[torch.Tensor]):
231248
pipeline.run()
232249

233250

234-
@common.parametrize("test_data", sub2_test_data)
251+
@common.parametrize("test_data", sub2_test_data | sub2_test_data_fp16)
235252
@common.SkipIfNoModelConverter
236253
def test_sub_tensor_vgf_no_quant_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
237254
"""Test Two-Operand Subtraction (VGF FP)"""

backends/arm/test/ops/test_tanh.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
"ramp": lambda: torch.arange(-16, 16, 0.2),
3131
}
3232

33+
test_data_suite_fp16 = {
34+
"rand_fp16": lambda: torch.rand(4, 4, 2, 2, 2, dtype=torch.float16) - 0.5,
35+
}
36+
3337
test_data_suite_bf16 = {
3438
"rand_bf16": lambda: torch.rand(4, 4, 2, 2, 2, dtype=torch.bfloat16) - 0.5,
3539
}
@@ -44,7 +48,9 @@ def forward(self, x):
4448
return self.tanh(x)
4549

4650

47-
@common.parametrize("test_data", test_data_suite | test_data_suite_bf16)
51+
@common.parametrize(
52+
"test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16
53+
)
4854
def test_tanh_tosa_FP(test_data: Tuple):
4955
pipeline = TosaPipelineFP[input_t1](
5056
Tanh(),
@@ -91,7 +97,7 @@ def test_tanh_u85_INT(test_data: Tuple):
9197
pipeline.run()
9298

9399

94-
@common.parametrize("test_data", test_data_suite)
100+
@common.parametrize("test_data", test_data_suite | test_data_suite_fp16)
95101
@common.SkipIfNoModelConverter
96102
def test_tanh_vgf_no_quant(test_data: Tuple):
97103
pipeline = VgfPipeline[input_t1](

backends/arm/test/ops/test_where.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_inputs(self):
4444
self.shape,
4545
dtype=self.dtype[i],
4646
)
47-
elif self.dtype[i] in [torch.float32, torch.bfloat16]:
47+
elif self.dtype[i] in [torch.float16, torch.float32, torch.bfloat16]:
4848
inputs[i] = torch.randn(*self.shape).to(self.dtype[i])
4949
elif self.dtype[i] is torch.bool:
5050
inputs[i] = torch.randint(0, 1, self.shape, dtype=torch.bool)
@@ -114,6 +114,12 @@ def scalar_condition(input: torch.Tensor):
114114
tensor_condition,
115115
)
116116

117+
float16_tensor_cond = Where(
118+
1,
119+
torch.float16,
120+
tensor_condition,
121+
)
122+
117123
float32_tensor_cond_tuple_dtype = Where(
118124
1,
119125
(torch.float32, torch.int8),
@@ -175,6 +181,7 @@ def scalar_condition(input: torch.Tensor):
175181
test_modules_FP = {
176182
**test_modules_common,
177183
"float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
184+
"float16_tensor_cond": lambda: float16_tensor_cond,
178185
}
179186

180187
test_modules_FP_bf16 = {

0 commit comments

Comments (0)