
Commit 6e49e50

NXP backend: Add support for fusing in quantized graph
1 parent 69218ba · commit 6e49e50

File tree: 1 file changed (+26 −2 lines)

backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py

Lines changed: 26 additions & 2 deletions
```diff
@@ -10,6 +10,21 @@
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 from torch.nn.parameter import Parameter
 from torch.nn.utils import fuse_linear_bn_weights
+from torchao.quantization.pt2e.prepare import _is_activation_post_process_node
+
+
+def _unwrap_if_fq(node: Node, named_modules: dict):
+    target_node = node
+
+    if _is_activation_post_process_node(node, named_modules):
+        if len(node.args) >= 1:
+            target_node = node.args[0]
+        else:
+            raise ValueError(
+                f"FakeQuantize node '{node}' should have at least one argument, but has {len(node.args)}."
+            )
+
+    return target_node
 
 
 class FuseBatchNormWithLinearPass(PassBase):
```
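Context for the helper above: after PT2E quantization preparation, the weight and bias inputs of a Linear node are routed through FakeQuantize (activation-post-process) `call_module` nodes, so the Linear node's argument is no longer the `get_attr` node that holds the real parameter. A minimal self-contained sketch of that pattern (the `Observer` stand-in and module names are illustrative, not from this commit):

```python
import torch
from torch import fx

class Observer(torch.nn.Module):
    """Stand-in for a FakeQuantize / activation-post-process module."""
    def forward(self, x):
        return x

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.weight_fq = Observer()

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight_fq(self.weight))

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.nn.functional.linear:
        weight_arg = node.args[1]
        # The weight argument is the observer call, not the parameter itself:
        print(weight_arg.op, weight_arg.target)  # call_module weight_fq
        print(weight_arg.args[0].op)             # get_attr (the real weight)
```

`_unwrap_if_fq` performs exactly this one-step look-through: if its input is an activation-post-process node, it returns that node's first argument so the existing constant lookup keeps working.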
```diff
@@ -76,6 +91,8 @@ def _is_linear(node_: Node):
                 graph_module, made_changes
             )  # No batch norm nodes in the model.
 
+        named_modules = dict(graph_module.named_modules(remove_duplicate=False))
+
         for node in graph_module.graph.nodes:
             if not _is_batch_norm(node):
                 continue  # Not BatchNorm.
```
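A brief note on `remove_duplicate=False` above: `_is_activation_post_process_node` resolves `call_module` targets through this dictionary, and the flag keeps a module that is registered under several names (e.g. a shared observer) reachable under each of them. Illustrative snippet (not from this commit):

```python
import torch

class Shared(torch.nn.Module):
    def __init__(self):
        super().__init__()
        obs = torch.nn.Identity()
        self.a = obs  # the same module object registered twice
        self.b = obs

m = Shared()
print(sorted(dict(m.named_modules())))                        # ['', 'a']
print(sorted(dict(m.named_modules(remove_duplicate=False))))  # ['', 'a', 'b']
```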
```diff
@@ -86,11 +103,18 @@ def _is_linear(node_: Node):
                 continue  # Something other than a Linear node comes before the BatchNorm.
 
             linear_node = bn_node.args[0]
-            linear_weight_node = linear_node.args[1]
-            linear_bias_node = (
+            linear_weight_node_or_fq = linear_node.args[1]
+            linear_bias_node_or_fq = (
                 linear_node.args[2] if len(linear_node.args) > 2 else None
             )
 
+            linear_weight_node = _unwrap_if_fq(
+                linear_weight_node_or_fq, named_modules=named_modules
+            )
+            linear_bias_node = _unwrap_if_fq(
+                linear_bias_node_or_fq, named_modules=named_modules
+            )
+
             linear_w = self._get_tensor_constant_from_node(
                 graph_module, linear_weight_node
             )
```
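For orientation, a minimal sketch of how the pass might be driven end to end. The import prefix is assumed from the file path above under the usual `executorch` package root, and the model, shapes, and use of `torch.export` are illustrative assumptions, since the commit itself does not show the calling flow:

```python
import torch
# Assumed import path, derived from the file location in this diff.
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)

class LinearBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 16)
        self.bn = torch.nn.BatchNorm1d(16)

    def forward(self, x):
        return self.bn(self.linear(x))

# Obtain an ATen-level FX GraphModule and run the pass. After this commit,
# the fusion should also apply when the graph was prepared for
# quantization-aware training and the Linear's weight/bias arguments are
# wrapped in FakeQuantize nodes.
gm = torch.export.export(LinearBN().eval(), (torch.randn(2, 8),)).module()
result = FuseBatchNormWithLinearPass()(gm)  # PassBase.__call__ -> PassResult
print(result.modified)
```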
