Commit 8d41a7a

NXP backend: Add test for Linear+BatchNorm fusing
1 parent 0600683 commit 8d41a7a

File tree

2 files changed: +210 −0 lines changed

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import executorch.backends.nxp.tests.models as models
import pytest
import torch
from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import (
    AddSimulatedLinearBatchNormFusionQATPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import (
    RemoveSimulatedLinearBatchNormFusionQATPass,
)

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec
from torch.export import export
from torch.fx import Node
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
def test_add_simulated_linear_bn_fusing(input_shape, linear_bias):
    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
    )
    model.train()
    raw_output = model(random_input)

    exported_model = export(model, (random_input,), strict=True)
    prepared_model = prepare_qat_pt2e(
        exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
    )
    prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module

    graph_nodes = list(prepared_model.graph.nodes)
    fake_quantize_output = prepared_model(random_input)

    expected_number_of_nodes = 23 if linear_bias else 18
    linear_node = next(
        (
            n
            for n in graph_nodes
            if hasattr(n, "target") and n.target == torch.ops.aten.linear.default
        ),
        None,
    )

    assert len(graph_nodes) == expected_number_of_nodes

    # Assert Linear weight being quantized and "normalized"
    assert linear_node is not None
    # "activation_post_process" is not a standard operator so we need to check by name.
    assert all(n.target.startswith("activation_post_process") for n in linear_node.args)
    assert linear_node.args[1].args[0].target == torch.ops.aten.mul.Tensor

    # Assert BatchNorm input being "denormalized"
    assert graph_nodes[-3].target == torch.ops.aten.batch_norm.default
    if linear_bias:
        assert graph_nodes[-3].args[0].target == torch.ops.aten.add.Tensor
        add_arg_targets = (
            n.target for n in graph_nodes[-3].args[0].args if hasattr(n, "target")
        )
        assert torch.ops.aten.div.Tensor in add_arg_targets
    else:
        assert graph_nodes[-3].args[0].target == torch.ops.aten.div.Tensor

    assert raw_output.shape == fake_quantize_output.shape

@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
def test_full_linear_bn_fusing(input_shape, linear_bias):
    # TODO: Add pass for quantizing bias node when Linear has bias=False
    if not linear_bias:
        pytest.skip(
            "Linear with bias=False is not yet supported. "
            "The graph currently produces a Linear layer without a quantized bias, which is incorrect."
        )

    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
    )
    model.train()
    raw_output = model(random_input)

    exported_model = export(model, (random_input,), strict=True)
    prepared_model = prepare_qat_pt2e(
        exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
    )

    prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module
    # Calibration: run sample inputs through the fake-quantized model.
    for data in [(random_input,)]:
        prepared_model(*data)
    prepared_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module
    prepared_model = FuseBatchNormWithLinearPass()(prepared_model).graph_module
    converted_model = convert_pt2e(prepared_model)

    quantized_output = converted_model(random_input)
    graph_nodes = list(converted_model.graph.nodes)
    linear_node = graph_nodes[-4]

    def _is_bn(node_: Node) -> bool:
        return (
            hasattr(node_, "target")
            and node_.target == torch.ops.aten.batch_norm.default
        )

    assert len(graph_nodes) == 11

    assert not any(_is_bn(node) for node in graph_nodes)

    # Assert linear inputs being quantized
    assert linear_node.target == torch.ops.aten.linear.default
    assert (
        linear_node.args[0].target
        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
    )
    assert (
        linear_node.args[1].target
        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
    )

    # Assert linear outputs being quantized
    assert len(linear_node.users) == 1
    assert (
        list(linear_node.users.keys())[0].target
        == torch.ops.quantized_decomposed.quantize_per_tensor.default
    )

    assert raw_output.shape == quantized_output.shape

@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
def test_input_output_graph_equivalence(input_shape, linear_bias):
    # TODO: Add pass for quantizing bias node when Linear has bias=False
    if not linear_bias:
        pytest.skip(
            "Linear with bias=False is not yet supported. "
            "The graph currently produces a Linear layer without a quantized bias, which is incorrect."
        )

    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
    )
    model.eval()

    original_model = export(model, (random_input,), strict=True).module()

    processed_model = export(model, (random_input,), strict=True).module()
    processed_model = AddSimulatedLinearBatchNormFusionQATPass()(
        processed_model
    ).graph_module
    processed_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
        processed_model
    ).graph_module

    assert torch.equal(original_model(random_input), processed_model(random_input))
    assert len(original_model.graph.nodes) == len(processed_model.graph.nodes)
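
For orientation, here is a minimal sketch of the simulated-fusion arithmetic that test_add_simulated_linear_bn_fusing checks for in the graph: the Linear weight is scaled ("normalized") by the BatchNorm scale before fake-quantization (the aten.mul feeding the weight observer), and the Linear output is divided by the same scale ("denormalized") before the real BatchNorm (the aten.div, plus an aten.add when a bias is present). This is an illustration of the standard QAT batch-norm folding scheme, not the pass's actual implementation, and it assumes the BatchNorm normalizes the Linear output features:

# Illustrative sketch only (not the NXP pass implementation); assumes an
# affine BatchNorm over the Linear output features.
import torch

def simulated_linear_bn(x, weight, bias, bn, fake_quant):
    # "Normalize" the weight with the BN scale, then fake-quantize it,
    # so the observer sees the fused weight statistics (the aten.mul).
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    w_fused = fake_quant(weight * scale.reshape(-1, 1))
    y = torch.nn.functional.linear(x, w_fused)
    # "Denormalize" the output before the real BatchNorm (the aten.div),
    # adding the bias afterwards when the Linear has one (the aten.add).
    y = y / scale
    if bias is not None:
        y = y + bias
    return bn(y)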

backends/nxp/tests/models.py

Lines changed: 35 additions & 0 deletions
@@ -481,6 +481,41 @@ def forward(self, x):
        return self.bn(x)


class LinearBNModule(torch.nn.Module):
    def __init__(
        self,
        input_shape: tuple[int, ...],
        out_features: int,
        linear_bias: bool,
        act: nn.Module | None = None,
    ):
        super().__init__()

        self.linear = torch.nn.Linear(
            in_features=input_shape[-1], out_features=out_features, bias=linear_bias
        )

        num_dims = len(input_shape)
        bn_features = input_shape[1]
        if num_dims == 3:
            self.bn = torch.nn.BatchNorm1d(bn_features)
        elif num_dims == 4:
            self.bn = torch.nn.BatchNorm2d(bn_features)
        elif num_dims == 5:
            self.bn = torch.nn.BatchNorm3d(bn_features)
        else:
            raise ValueError(
                f"Unsupported input rank: {num_dims}, supported ranks are 3, 4 or 5."
            )

        self.act = act

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return self.act(x) if self.act is not None else x


class MulTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
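
A quick usage sketch of the new LinearBNModule helper, using the shapes the tests parametrize over (a hypothetical snippet, not part of the commit):

import torch
import executorch.backends.nxp.tests.models as models

# A 3D input (N, C, L) selects BatchNorm1d; a 4D input would select BatchNorm2d.
m = models.LinearBNModule(input_shape=(1, 5, 5), out_features=5, linear_bias=True)
print(m(torch.randn(1, 5, 5)).shape)  # torch.Size([1, 5, 5])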
