Skip to content

Commit b45a2ba

Browse files
committed
NXP backend: Add test for Linear+BatchNorm fusing
1 parent 34b5982 commit b45a2ba

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.nxp.tests.models as models
7+
import pytest
8+
import torch
9+
from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import (
10+
AddSimulatedLinearBatchNormFusionQATPass,
11+
)
12+
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
13+
FuseBatchNormWithLinearPass,
14+
)
15+
from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import (
16+
RemoveSimulatedLinearBatchNormFusionQATPass,
17+
)
18+
19+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
20+
from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec
21+
from torch.export import export
22+
from torch.fx import Node
23+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
24+
25+
26+
@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
27+
@pytest.mark.parametrize("linear_bias", [True, False])
28+
def test_add_simulated_linear_bn_fusing(input_shape, linear_bias):
29+
random_input = torch.randn(*input_shape)
30+
model = models.LinearBNModule(
31+
input_shape=input_shape,
32+
out_features=5,
33+
linear_bias=linear_bias,
34+
)
35+
model.train()
36+
raw_output = model(random_input)
37+
38+
exported_model = export(model, (random_input,), strict=True)
39+
prepared_model = prepare_qat_pt2e(
40+
exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
41+
)
42+
prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
43+
prepared_model
44+
).graph_module
45+
46+
graph_nodes = list(prepared_model.graph.nodes)
47+
fake_quantize_output = prepared_model(random_input)
48+
49+
expected_number_of_nodes = 23 if linear_bias else 18
50+
linear_node = next(
51+
(
52+
n
53+
for n in graph_nodes
54+
if hasattr(n, "target") and n.target == torch.ops.aten.linear.default
55+
),
56+
None,
57+
)
58+
59+
assert len(graph_nodes) == expected_number_of_nodes
60+
61+
# Assert Linear weight being quantized and "normalized"
62+
assert linear_node is not None
63+
# "activation_post_process" is not a standard operator so we need to check by name.
64+
assert all(n.target.startswith("activation_post_process") for n in linear_node.args)
65+
assert linear_node.args[1].args[0].target == torch.ops.aten.mul.Tensor
66+
67+
# Assert BatchNorm input being "denormalized"
68+
assert graph_nodes[-3].target == torch.ops.aten.batch_norm.default
69+
if linear_bias:
70+
assert graph_nodes[-3].args[0].target == torch.ops.aten.add.Tensor
71+
add_arg_targets = (
72+
n.target for n in graph_nodes[-3].args[0].args if hasattr(n, "target")
73+
)
74+
assert torch.ops.aten.div.Tensor in add_arg_targets
75+
else:
76+
assert graph_nodes[-3].args[0].target == torch.ops.aten.div.Tensor
77+
78+
assert raw_output.shape == fake_quantize_output.shape
79+
80+
81+
@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
82+
@pytest.mark.parametrize("linear_bias", [True, False])
83+
def test_full_linear_bn_fusing(input_shape, linear_bias):
84+
# TODO: Add pass for quantizing bias node when Linear has bias=False
85+
if not linear_bias:
86+
pytest.skip(
87+
"Linear with bias=False is not yet supported."
88+
"The graph currently produces Linear layer without quantized bias which is incorrect."
89+
)
90+
91+
random_input = torch.randn(*input_shape)
92+
model = models.LinearBNModule(
93+
input_shape=input_shape,
94+
out_features=5,
95+
linear_bias=linear_bias,
96+
)
97+
model.train()
98+
raw_output = model(random_input)
99+
100+
exported_model = export(model, (random_input,), strict=True)
101+
prepared_model = prepare_qat_pt2e(
102+
exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
103+
)
104+
105+
prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
106+
prepared_model
107+
).graph_module
108+
for data in (random_input,):
109+
prepared_model(*data)
110+
prepared_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
111+
prepared_model
112+
).graph_module
113+
prepared_model = FuseBatchNormWithLinearPass()(prepared_model).graph_module
114+
converted_model = convert_pt2e(prepared_model)
115+
116+
quantized_output = converted_model(random_input)
117+
graph_nodes = list(converted_model.graph.nodes)
118+
linear_node = graph_nodes[-4]
119+
120+
def _is_bn(node_: Node) -> bool:
121+
return (
122+
hasattr(node_, "target")
123+
and node_.target == torch.ops.aten.batch_norm.default
124+
)
125+
126+
assert len(graph_nodes) == 11
127+
128+
assert not any(_is_bn(node) for node in graph_nodes)
129+
130+
# Assert linear inputs being quantized
131+
assert linear_node.target == torch.ops.aten.linear.default
132+
assert (
133+
linear_node.args[0].target
134+
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
135+
)
136+
assert (
137+
linear_node.args[1].target
138+
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
139+
)
140+
141+
# Assert linear outputs being quantized
142+
assert len(linear_node.users) == 1
143+
assert (
144+
list(linear_node.users.keys())[0].target
145+
== torch.ops.quantized_decomposed.quantize_per_tensor.default
146+
)
147+
148+
assert raw_output.shape == quantized_output.shape
149+
150+
151+
@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
152+
@pytest.mark.parametrize("linear_bias", [True, False])
153+
def test_input_output_graph_equivalence(input_shape, linear_bias):
154+
# TODO: Add pass for quantizing bias node when Linear has bias=False
155+
if not linear_bias:
156+
pytest.skip(
157+
"Linear with bias=False is not yet supported."
158+
"The graph currently produces Linear layer without quantized bias which is incorrect."
159+
)
160+
161+
random_input = torch.randn(*input_shape)
162+
model = models.LinearBNModule(
163+
input_shape=input_shape,
164+
out_features=5,
165+
linear_bias=linear_bias,
166+
)
167+
model.eval()
168+
169+
original_model = export(model, (random_input,), strict=True).module()
170+
171+
processed_model = export(model, (random_input,), strict=True).module()
172+
processed_model = AddSimulatedLinearBatchNormFusionQATPass()(
173+
processed_model
174+
).graph_module
175+
processed_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
176+
processed_model
177+
).graph_module
178+
179+
assert torch.equal(original_model(random_input), processed_model(random_input))
180+
assert len(original_model.graph.nodes) == len(processed_model.graph.nodes)

backends/nxp/tests/models.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,41 @@ def forward(self, x):
481481
return self.bn(x)
482482

483483

484+
class LinearBNModule(torch.nn.Module):
485+
def __init__(
486+
self,
487+
input_shape: tuple[int],
488+
out_features: int,
489+
linear_bias: bool,
490+
act: nn.Module | None = None,
491+
):
492+
super().__init__()
493+
494+
self.linear = torch.nn.Linear(
495+
in_features=input_shape[-1], out_features=out_features, bias=linear_bias
496+
)
497+
498+
num_dims = len(input_shape)
499+
bn_features = input_shape[1]
500+
if num_dims == 3:
501+
self.bn = torch.nn.BatchNorm1d(bn_features)
502+
elif num_dims == 4:
503+
self.bn = torch.nn.BatchNorm2d(bn_features)
504+
elif num_dims == 5:
505+
self.bn = torch.nn.BatchNorm3d(bn_features)
506+
else:
507+
raise ValueError(
508+
f"Unknown input_dim: {len(input_shape)}, supported values are 1, 2 or 3."
509+
)
510+
511+
self.act = act
512+
513+
def forward(self, x):
514+
x = self.linear(x)
515+
x = self.bn(x)
516+
return self.act(x) if self.act is not None else x
517+
518+
484519
class MulTensorModule(torch.nn.Module):
485520
def __init__(self):
486521
super().__init__()

0 commit comments

Comments
 (0)