
Commit ff54eb2

NXP backend: Add test for Linear+BatchNorm fusing

1 parent: cf04cc9
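
For context: Linear+BatchNorm fusing folds the BatchNorm affine transform into the Linear layer's weight and bias, so a single Linear computes the same function in eval mode. A minimal sketch of the standard folding arithmetic follows; the helper below is illustrative and not code from this commit.

import torch

def fold_bn_into_linear(
    linear: torch.nn.Linear, bn: torch.nn.BatchNorm1d
) -> torch.nn.Linear:
    # Per-output-feature scale: gamma / sqrt(running_var + eps).
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Linear(linear.in_features, linear.out_features, bias=True)
    with torch.no_grad():
        # Fused weight: each output row of W is scaled by the BN factor.
        fused.weight.copy_(linear.weight * scale[:, None])
        bias = (
            linear.bias
            if linear.bias is not None
            else torch.zeros(linear.out_features)
        )
        # Fused bias: (b - running_mean) * scale + beta.
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused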

File tree: 2 files changed, +224 −0 lines
Lines changed: 188 additions & 0 deletions

@@ -0,0 +1,188 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.nxp.tests.models as models
import pytest
import torch
from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import (
    AddSimulatedLinearBatchNormFusionQATPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import (
    RemoveSimulatedLinearBatchNormFusionQATPass,
)

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec
from torch.export import export
from torch.fx import Node
from torchao.quantization.pt2e.prepare import _is_activation_post_process_node
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
def test_add_simulated_linear_bn_fusing(input_shape, linear_bias):
    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
    )
    model.train()
    raw_output = model(random_input)

    exported_model = export(model, (random_input,), strict=True)
    prepared_model = prepare_qat_pt2e(
        exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
    )
    prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module

    graph_nodes = list(prepared_model.graph.nodes)
    named_modules = dict(prepared_model.named_modules(remove_duplicate=False))
    fake_quantize_output = prepared_model(random_input)

    expected_number_of_nodes = 23 if linear_bias else 18
    linear_node = next(
        (
            n
            for n in graph_nodes
            if hasattr(n, "target") and n.target == torch.ops.aten.linear.default
        ),
        None,
    )

    assert len(graph_nodes) == expected_number_of_nodes

    # Assert the Linear weight is quantized and "normalized".
    assert linear_node is not None
    assert all(
        _is_activation_post_process_node(n, named_modules) for n in linear_node.args
    )
    assert linear_node.args[1].args[0].target == torch.ops.aten.mul.Tensor

    # Assert the BatchNorm input is "denormalized".
    assert graph_nodes[-3].target == torch.ops.aten.batch_norm.default
    if linear_bias:
        assert graph_nodes[-3].args[0].target == torch.ops.aten.add.Tensor
        add_arg_targets = (
            n.target for n in graph_nodes[-3].args[0].args if hasattr(n, "target")
        )
        assert torch.ops.aten.div.Tensor in add_arg_targets
    else:
        assert graph_nodes[-3].args[0].target == torch.ops.aten.div.Tensor

    assert raw_output.shape == fake_quantize_output.shape
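
The assertions above pin down the "simulated fusion" subgraph: the Linear weight is multiplied by the BatchNorm scale before fake-quantization (aten.mul.Tensor), the Linear output is divided back (aten.div.Tensor), the bias, if present, is re-added (aten.add.Tensor), and BatchNorm itself stays in the graph during QAT. A schematic sketch of the equivalent float computation, inferred from these assertions, with illustrative function and argument names:

import torch

def simulated_linear_bn(x, weight, bias, bn, fake_quant):
    # Scale the weight by the BN factor so QAT observes the fused weight range.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    y = torch.nn.functional.linear(x, fake_quant(weight * scale[:, None]))
    y = y / scale  # "Denormalize" so the still-present BatchNorm sees unfused statistics.
    if bias is not None:
        y = y + bias  # The bias is kept outside the scaled weight.
    return bn(y)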


@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
def test_full_linear_bn_fusing(input_shape, linear_bias):
    # TODO: Add a pass for quantizing the bias node when Linear has bias=False.
    if not linear_bias:
        pytest.skip(
            "Linear with bias=False is not yet supported. "
            "The graph currently produces a Linear layer without a quantized bias, "
            "which is incorrect."
        )

    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
    )
    model.train()
    raw_output = model(random_input)

    exported_model = export(model, (random_input,), strict=True)
    prepared_model = prepare_qat_pt2e(
        exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True)
    )

    prepared_model = AddSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module
    prepared_model(random_input)  # Run once so the QAT observers record statistics.
    prepared_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
        prepared_model
    ).graph_module
    prepared_model = FuseBatchNormWithLinearPass()(prepared_model).graph_module
    converted_model = convert_pt2e(prepared_model)

    quantized_output = converted_model(random_input)
    graph_nodes = list(converted_model.graph.nodes)
    linear_node = graph_nodes[-4]

    def _is_bn(node_: Node) -> bool:
        return (
            hasattr(node_, "target")
            and node_.target == torch.ops.aten.batch_norm.default
        )

    assert len(graph_nodes) == 11

    assert not any(_is_bn(node) for node in graph_nodes)

    # Assert the Linear inputs are quantized.
    assert linear_node.target == torch.ops.aten.linear.default
    assert (
        linear_node.args[0].target
        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
    )
    assert (
        linear_node.args[1].target
        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
    )

    # Assert the Linear output is quantized.
    assert len(linear_node.users) == 1
    assert (
        list(linear_node.users.keys())[0].target
        == torch.ops.quantized_decomposed.quantize_per_tensor.default
    )

    assert raw_output.shape == quantized_output.shape
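
After convert_pt2e, the test checks the quantize/dequantize sandwich around the fused Linear node by hand. The same check could be expressed as a small helper; the function below is a hypothetical utility, not part of the test suite, built only from the ops the assertions already reference.

import torch
from torch.fx import Node

DQ = torch.ops.quantized_decomposed.dequantize_per_tensor.default
Q = torch.ops.quantized_decomposed.quantize_per_tensor.default

def is_qdq_sandwiched(linear_node: Node) -> bool:
    # Both the activation and the weight must come from dequantize nodes.
    inputs_dequantized = all(
        isinstance(arg, Node) and arg.target == DQ for arg in linear_node.args[:2]
    )
    # The single user of the Linear output must be a quantize node.
    output_quantized = (
        len(linear_node.users) == 1
        and next(iter(linear_node.users)).target == Q
    )
    return inputs_dequantized and output_quantized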


@pytest.mark.parametrize("input_shape", [(1, 5, 5), (1, 5, 5, 5)])
@pytest.mark.parametrize("linear_bias", [True, False])
@pytest.mark.parametrize("bn_eps", [1e-5, 1e-6])
def test_input_output_graph_equivalence(input_shape, linear_bias, bn_eps):
    # TODO: Add a pass for quantizing the bias node when Linear has bias=False.
    if not linear_bias:
        pytest.skip(
            "Linear with bias=False is not yet supported. "
            "The graph currently produces a Linear layer without a quantized bias, "
            "which is incorrect."
        )

    random_input = torch.randn(*input_shape)
    model = models.LinearBNModule(
        input_shape=input_shape,
        out_features=5,
        linear_bias=linear_bias,
        bn_eps=bn_eps,
    )
    model.eval()

    original_model = export(model, (random_input,), strict=True).module()

    processed_model = export(model, (random_input,), strict=True).module()
    processed_model = AddSimulatedLinearBatchNormFusionQATPass()(
        processed_model
    ).graph_module

    assert list(processed_model.graph.nodes)[8].args[1] == bn_eps

    processed_model = RemoveSimulatedLinearBatchNormFusionQATPass()(
        processed_model
    ).graph_module

    assert list(processed_model.graph.nodes)[-2].args[7] == bn_eps
    assert torch.equal(original_model(random_input), processed_model(random_input))
    assert len(original_model.graph.nodes) == len(processed_model.graph.nodes)

backends/nxp/tests/models.py

Lines changed: 36 additions & 0 deletions

@@ -481,6 +481,42 @@ def forward(self, x):
        return self.bn(x)


class LinearBNModule(torch.nn.Module):
    def __init__(
        self,
        input_shape: tuple[int, ...],
        out_features: int,
        linear_bias: bool,
        bn_eps: float = 1e-5,
        act: nn.Module | None = None,
    ):
        super().__init__()

        self.linear = torch.nn.Linear(
            in_features=input_shape[-1], out_features=out_features, bias=linear_bias
        )

        num_dims = len(input_shape)
        if num_dims == 3:
            self.bn = torch.nn.BatchNorm1d(out_features, eps=bn_eps)
        elif num_dims == 4:
            self.bn = torch.nn.BatchNorm2d(out_features, eps=bn_eps)
        elif num_dims == 5:
            self.bn = torch.nn.BatchNorm3d(out_features, eps=bn_eps)
        else:
            raise ValueError(
                f"Unsupported dimension {num_dims} of the input_shape "
                + f"({input_shape}). Only 3, 4 and 5 are supported."
            )

        self.act = act

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return self.act(x) if self.act is not None else x
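
A minimal usage sketch of LinearBNModule as the tests above exercise it; the shape follows the (1, 5, 5) parametrization, which selects the BatchNorm1d branch:

import torch
from executorch.backends.nxp.tests.models import LinearBNModule

model = LinearBNModule(input_shape=(1, 5, 5), out_features=5, linear_bias=True)
model.train()
output = model(torch.randn(1, 5, 5))  # Linear over the last dim, then BatchNorm1d.
assert output.shape == (1, 5, 5)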


class MulTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
