28 changes: 26 additions & 2 deletions backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py
@@ -10,6 +10,21 @@
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_linear_bn_weights
from torchao.quantization.pt2e.prepare import _is_activation_post_process_node


def _unwrap_if_fq(node: Node, named_modules: dict):
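    """Return the input of `node` if it is a FakeQuantize (activation post-process)
    node; otherwise return `node` unchanged."""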
target_node = node

if _is_activation_post_process_node(node, named_modules):
if len(node.args) >= 1:
target_node = node.args[0]
else:
raise ValueError(
f"FakeQuantize node '{node}' should have at least one argument, but has {len(node.args)}."
)

return target_node


class FuseBatchNormWithLinearPass(PassBase):
@@ -76,6 +91,8 @@ def _is_linear(node_: Node):
graph_module, made_changes
) # No batch norm nodes in the model.

named_modules = dict(graph_module.named_modules(remove_duplicate=False))

for node in graph_module.graph.nodes:
if not _is_batch_norm(node):
continue # Not BatchNorm.
@@ -86,11 +103,18 @@
continue # Something other than a Linear node comes before the BatchNorm.

linear_node = bn_node.args[0]
- linear_weight_node = linear_node.args[1]
- linear_bias_node = (
+ linear_weight_node_or_fq = linear_node.args[1]
+ linear_bias_node_or_fq = (
linear_node.args[2] if len(linear_node.args) > 2 else None
)

linear_weight_node = _unwrap_if_fq(
linear_weight_node_or_fq, named_modules=named_modules
)
linear_bias_node = _unwrap_if_fq(
linear_bias_node_or_fq, named_modules=named_modules
)

linear_w = self._get_tensor_constant_from_node(
graph_module, linear_weight_node
)
11 changes: 11 additions & 0 deletions backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py
@@ -0,0 +1,11 @@
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import (
AddSimulatedLinearBatchNormFusionQATPass,
)
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.remove_simulated_linear_bn_fusion_qat_pass import (
RemoveSimulatedLinearBatchNormFusionQATPass,
)

__all__ = [
"AddSimulatedLinearBatchNormFusionQATPass",
"RemoveSimulatedLinearBatchNormFusionQATPass",
]

backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py
Large diffs are not rendered by default.

190 changes: 190 additions & 0 deletions backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py
@@ -0,0 +1,190 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
_unwrap_if_fq,
)
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import (
_get_compute_scale_factor_pattern,
_get_linear_weight_preprocess_pattern,
)
from executorch.backends.nxp.backend.graph_utils import is_batch_norm, is_op_node
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern


_is_add = partial(is_op_node, target_op=torch.ops.aten.add.Tensor)
_is_div = partial(is_op_node, target_op=torch.ops.aten.div.Tensor)
_is_linear = partial(is_op_node, target_op=torch.ops.aten.linear.default)
_is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape)
_is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like)


def _is_denorm_pattern(node: Node) -> bool:
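    """Return True if `node` is a `div` node whose output feeds a batch norm node,
    i.e. the de-normalization step inserted by the simulated fusion."""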
if not _is_div(node):
return False

if not hasattr(node, "users"):
return False

div_users = node.users.keys()
if len(list(div_users)) < 1:
return False

if any(is_batch_norm(user) for user in div_users):
return True

return False


def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule):
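    """For every subgraph of `graph_module` matching `pattern`, replace all uses of the
    matched output node with the match's first placeholder input (the original weight)
    and erase the remaining matched nodes."""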
matcher = SubgraphMatcher(
pattern.graph,
match_output=False,
match_placeholder=False,
remove_overlapping_matches=True,
ignore_literals=True,
)
matches: list[InternalMatch] = matcher.match(graph_module.graph, node_name_match="")

for match in matches:
last_pattern_node = match.anchors[0]
last_matched_subgraph_node = match.nodes_map[last_pattern_node]
weight = match.placeholder_nodes[0]

last_matched_subgraph_node.replace_all_uses_with(weight)

for node in match.nodes_map.values():
if node not in match.placeholder_nodes:
graph_module.graph.erase_node(node)


def _remove_late_bias_pattern(graph_module: GraphModule, bias_node: Node):
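    """Undo the late-bias rewrite around `bias_node`: restore the original bias in place
    of the `zeros_like` placeholder and erase the trailing `add` (and optional `reshape`)
    that re-applied the bias after the batch norm."""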
linear_b_users = list(bias_node.users.keys())

if len(linear_b_users) != 2:
return

if _is_zeros_like(linear_b_users[0]):
zeros_node, maybe_reshape_node = linear_b_users
elif _is_zeros_like(linear_b_users[1]):
maybe_reshape_node, zeros_node = linear_b_users
else:
return

if _is_reshape(maybe_reshape_node):
reshape_node = maybe_reshape_node
reshape_users = list(reshape_node.users.keys())

if len(reshape_users) != 1:
return

add_node = reshape_users[0]
else:
# There is no reshape node when the bias is a scalar.
reshape_node = None
add_node = maybe_reshape_node

if not _is_add(add_node):
return

# Remove zeroed linear bias
zeros_node.replace_all_uses_with(bias_node)
graph_module.graph.erase_node(zeros_node)

# Remove late bias addition
add_node.replace_all_uses_with(add_node.args[0])
graph_module.graph.erase_node(add_node)

if reshape_node:
graph_module.graph.erase_node(reshape_node)


def _remove_denorm_and_late_bias(graph_module: GraphModule):
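    """For each linear node, remove the remaining simulated-fusion artifacts: the
    zeroed-bias/late-add pattern and the de-normalizing `div` between the linear
    node and the batch norm."""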
named_modules = dict(graph_module.named_modules(remove_duplicate=False))

for node in graph_module.graph.nodes:
if not _is_linear(node):
continue

linear_node = node

if len(linear_node.args) <= 2:
continue

linear_bias_fq_or_zeros = _unwrap_if_fq(
linear_node.args[2], named_modules=named_modules
)
has_late_bias = _is_zeros_like(linear_bias_fq_or_zeros)

if has_late_bias:
_remove_late_bias_pattern(
graph_module, bias_node=linear_bias_fq_or_zeros.args[0]
)

for user_node in linear_node.users:
if _is_denorm_pattern(user_node):
if any(is_batch_norm(user) for user in user_node.users.keys()):
user_node.replace_all_uses_with(node)
graph_module.graph.erase_node(user_node)
break


class RemoveSimulatedLinearBatchNormFusionQATPass(PassBase):
"""
For QAT to work correctly with fused linear + batch norm operators, a simulated
linear + batch norm fusion must first be added using AddSimulatedLinearBatchNormFusionQATPass.

After QAT training, and before the QDQ nodes are inserted, the nodes added by the simulated fusion must be removed.
This pass removes all artifacts created by AddSimulatedLinearBatchNormFusionQATPass and reverts
the graph to the layout it had before the simulated fusion was applied.
See `add_simulated_linear_bn_fusion_qat_pass.py` for more details.
"""

def call(self, graph_module: GraphModule) -> PassResult | None:
"""
Given a graph of decomposed aten ops, remove the nodes added by the simulated linear + batch norm fusion.
"""
is_cuda = False

graph_module.graph.eliminate_dead_code()
graph_module.recompile()

_scale_compute_example_inputs = (
torch.randn(1),
torch.randn(1),
)
_preprocess_example_inputs = (
torch.randn(1, 1),
torch.randn(1),
)

scale_pattern = _get_compute_scale_factor_pattern()
scale_match_pattern = _get_aten_graph_module_for_pattern(
pattern=scale_pattern,
example_inputs=_scale_compute_example_inputs,
is_cuda=is_cuda,
)

weight_preprocess_pattern = _get_linear_weight_preprocess_pattern()
weight_preprocess_pattern = _get_aten_graph_module_for_pattern(
pattern=weight_preprocess_pattern,
example_inputs=_preprocess_example_inputs,
is_cuda=is_cuda,
)

_remove_pattern_from_graph(graph_module, pattern=scale_match_pattern)
_remove_pattern_from_graph(graph_module, pattern=weight_preprocess_pattern)
_remove_denorm_and_late_bias(graph_module)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return PassResult(graph_module, True)
34 changes: 34 additions & 0 deletions backends/nxp/backend/graph_utils.py
@@ -0,0 +1,34 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch.fx import Node


batch_norm_target_ops = [
torch.ops.aten.batch_norm.default,
torch.ops.aten.native_batch_norm.default,
torch.ops.aten._native_batch_norm_legit_no_training.default,
]


def is_op_node(node: Node, target_op) -> bool:
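    """Return True if `node` is a `call_function` node whose target is `target_op`
    (or one of the targets, if `target_op` is a list)."""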
if isinstance(target_op, list):
target_ops = target_op
else:
target_ops = [target_op]

return (
node is not None
and hasattr(node, "op")
and node.op == "call_function"
and hasattr(node, "target")
and node.target in target_ops
)


def is_batch_norm(node: Node) -> bool:
return is_op_node(node, batch_norm_target_ops)
10 changes: 10 additions & 0 deletions backends/nxp/quantizer/patterns.py
@@ -691,6 +691,16 @@ def get_anchors(
output = []
activation.meta["quantization_annotation"].input_qspec_map = {}

# For QAT to be numerically correct, there must be no quantization between the
# linear node and the batch norm node.
if self.is_qat:
linear_users = linear_node.users
possibly_bn = (
list(linear_users.keys())[0] if len(linear_users) == 1 else None
)
if possibly_bn and _is_batch_norm(possibly_bn):
output = []

return PartitionAnchors(
inputs=[(linear_node, NodeArgsIdx(0))],
weights=[(linear_node, NodeArgsIdx(1))],
10 changes: 10 additions & 0 deletions backends/nxp/quantizer/utils.py
@@ -13,6 +13,11 @@
from typing import Any, Dict, List, Tuple, Type

import torch

from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
AddSimulatedLinearBatchNormFusionQATPass,
RemoveSimulatedLinearBatchNormFusionQATPass,
)
from torch import fx
from torch._ops import OpOverload
from torch.export import ExportedProgram
@@ -184,12 +189,17 @@ def calibrate_and_quantize(

if is_qat:
m = prepare_qat_pt2e(model, quantizer)
m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
m = move_exported_model_to_eval(m)
else:
m = prepare_pt2e(model, quantizer)

for data in calibration_inputs:
m(*data)

if is_qat:
m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module

m = convert_pt2e(m)

return m
@@ -60,6 +60,7 @@ def test_mm_conversion(self, _, use_qat: bool):
exported_program,
input_data,
tfl_model=tflite_flatbuffers_model,
atol=1.0,
)

@parameterized.expand([("QAT", True), ("PTQ", False)])