From 96b4c9e0d04ebd9c1a457cf5e8e87e3ae6f3ee92 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Thu, 11 Dec 2025 10:20:54 +0100 Subject: [PATCH 01/10] NXP backend: Add pass for inserting simulated Linear+BatchNorm fusion --- ...add_simulated_linear_bn_fusion_qat_pass.py | 410 ++++++++++++++++++ 1 file changed, 410 insertions(+) create mode 100644 backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py diff --git a/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py new file mode 100644 index 00000000000..08b83eb9f0b --- /dev/null +++ b/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py @@ -0,0 +1,410 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( + _unwrap_if_fq, +) +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torchao.quantization.pt2e.export_utils import WrapperModule +from torchao.quantization.pt2e.prepare import _is_activation_post_process_node +from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern + + +_bn_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten.native_batch_norm.default, + torch.ops.aten._native_batch_norm_legit_no_training.default, +] + + +def _get_linear_weight_preprocess_pattern() -> Callable: + def _preprocess( + weight: torch.Tensor, + scale_factor: torch.Tensor, + ) -> torch.Tensor: + weight_shape = [1] * len(weight.shape) + weight_in_channel_axis = 0 + weight_shape[weight_in_channel_axis] = -1 + return weight * scale_factor.reshape(weight_shape) + + return WrapperModule(_preprocess) + + +def _get_compute_scale_factor_pattern(bn_eps: float = 1e-5) -> Callable: + def _compute_scale( + bn_weight: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + return scale_factor + + return WrapperModule(_compute_scale) + + +def _is_batch_norm(node: Node) -> bool: + return ( + node is not None + and hasattr(node, "op") + and node.op == "call_function" + and hasattr(node, "target") + and node.target in _bn_ops + ) + + +def _is_linear(node: Node) -> bool: + return ( + node is not None + and hasattr(node, "op") + and node.op == "call_function" + and hasattr(node, "target") + and node.target == torch.ops.aten.linear.default + ) + + +def _get_input_nodes(graph_module: GraphModule) -> tuple[Node]: + input_nodes = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + input_nodes.append(node) + + return tuple(input_nodes) + + +def _get_tensor_shape(node: Node) -> torch.Size | None: + if hasattr(node, "meta") and "tensor_meta" in node.meta: + return node.meta["tensor_meta"].shape + + return None + + +def _get_bn_eps(bn_node: Node) -> float: + if "eps" in bn_node.kwargs: + return bn_node.kwargs["eps"] + elif len(bn_node.args) > 7: + return bn_node.args[7] + return 1e-5 + + +def _reinsert_node_before( + graph_module: GraphModule, node: Node, insertion_point: Node +) -> Node: + with graph_module.graph.inserting_before(insertion_point): + node_copy = graph_module.graph.node_copy(node) + node.replace_all_uses_with(node_copy) + graph_module.graph.erase_node(node) + + return node_copy + + +def 
_insert_linear_output_denorm(
+    graph_module: GraphModule, linear_node: Node, scale_factor: Node, bn_node: Node
+) -> Node:
+    """
+    Inserts a div node between the linear and batch norm layers.
+    This denormalizes the linear output by dividing it by the scale factor.
+    """
+    with graph_module.graph.inserting_before(bn_node):
+        div_node = graph_module.graph.create_node(
+            op="call_function",
+            target=torch.ops.aten.div.Tensor,
+            args=(linear_node, scale_factor),
+        )
+        bn_node.update_arg(0, div_node)
+
+    return div_node
+
+
+def _insert_disabled_linear_bias(
+    graph_module: GraphModule, linear_b_node: Node
+) -> Node:
+    """
+    Inserts a zeros_like node to disable the linear layer's bias application.
+    Expects the bias application to be re-added after the linear output denormalization.
+    """
+    with graph_module.graph.inserting_after(linear_b_node):
+        zeros_like = graph_module.graph.create_node(
+            op="call_function",
+            target=torch.ops.aten.zeros_like,
+            args=(),
+        )
+        linear_b_node.replace_all_uses_with(zeros_like)
+        zeros_like.insert_arg(0, linear_b_node)
+
+    return zeros_like
+
+
+def _insert_bias_add_after_denorm(
+    graph_module: GraphModule, denorm_output_node: Node
+) -> Node:
+    """
+    Inserts the linear bias application after the denormalization node.
+    """
+    with graph_module.graph.inserting_after(denorm_output_node):
+        bias_add_node = graph_module.graph.create_node(
+            op="call_function",
+            target=torch.ops.aten.add.Tensor,
+            args=(),
+        )
+        denorm_output_node.replace_all_uses_with(bias_add_node)
+
+    return bias_add_node
+
+
+def _insert_later_bias_application(
+    graph_module: GraphModule,
+    denorm_output_node: Node,
+    linear_node: Node,
+    linear_b_node: Node,
+):
+    """
+    Moves the bias application from the normalized space to after the simulated denormalization.
+    """
+    _ = _insert_disabled_linear_bias(graph_module, linear_b_node)
+
+    bias_add_node = _insert_bias_add_after_denorm(graph_module, denorm_output_node)
+
+    with graph_module.graph.inserting_before(bias_add_node):
+        linear_output_shape = linear_node.meta.get("tensor_meta").shape
+        reshape_target_shape = [1] * len(linear_output_shape)
+
+        bias_redirect_node = linear_b_node
+
+        if len(reshape_target_shape) > 1:
+            reshape_target_shape[1] = -1
+            bias_redirect_node = graph_module.graph.create_node(
+                op="call_function",
+                target=torch.ops.aten.reshape,
+                args=(linear_b_node, reshape_target_shape),
+            )
+
+        bias_add_node.insert_arg(0, denorm_output_node)
+        bias_add_node.insert_arg(1, bias_redirect_node)
+
+
+class AddSimulatedLinearBatchNormFusionQATPass(PassBase):
+    """
+    Batch norm computation can in some cases be fused with a preceding linear transformation such as conv or linear (see fuse_batch_norm_with_linear_pass.py).
+    We cannot do this before QAT training, because we would lose the ability to update the batch norm statistics during training.
+    This pass takes inspiration from the existing mechanics of simulated Conv+BN folding in the TorchAO implementation, and applies them to torch.nn.Linear.
+    That implementation can be found in the _fuse_conv_bn_qat function [2], which is used by prepare_qat_pt2e [1].
+
+    We simulate the fusion by:
+    1. Adding computation of the fused scale factor from batch norm parameters to the graph
+    2. Adding pre-processing of the linear weights with this scale factor and applying the inverse to its output
+    3. (Optionally) Moving the bias application after the output denormalization so it works in the same space it was originally pre-trained in
+    4. 
Keeping the batch norm operator in the graph for statistics tracking during QAT + + Linear weight scaling (necessary reshape nodes omitted for clarity): + + ┌───────────┐ ┌───────────┐ + │bn_run_var │ │ bn_weight │ + └─────┬─────┘ └─────┬─────┘ + │ ┌─────────────────────────────────┤ + ┌─────────────┐ ┌─────┐ │ ┌──────┴──────┐ ┌─────────────┐ │ + │linear_weight│ │ x │ ├──┤compute_scale│ │linear_weight│ │ + └──────┬──────┘ └──┬──┘ │ └──────┬──────┘ └──────┬──────┘ │ + │ │ │ │ │ │ + │ FQ┌──────────┐FQ │ │ │ ┌─────┐ │ ┌─────┐ │ + └──────┤ linear ├──────┘ │ ├─────┤ mul ├─────┘ │ x │ │ + └─────┬────┘ ──────► │ │ └──┬──┘ └──┬──┘ │ + │ │ │ │ │ │ + ┌───────────┐ │ ┌───────────┐ │ │ │ FQ┌──────────┐FQ │ │ + │ bn_weight │ │ │bn_run_var │ │ │ └───┤ linear ├───┘ │ + └─────┬─────┘ │ └─────┬─────┘ │ │ └─────┬────┘ │ + │ │ │ │ │ │ │ + │ ┌─────┴────┐ │ │ │ ┌──┴──┐ │ + └──────┤batch_norm├──────┘ │ └───────────────┤ div │ │ + └─────┬────┘ │ └──┬──┘ │ + │ │ │ │ + ▼ │ ┌─────┴────┐ │ + └──────────────────────┤batch_norm├─────────┘ + └─────┬────┘ + │ + ▼ + Note: compute_scale := (bn_weight / sqrt(bn_run_var + eps)) + + + Later bias application (necessary reshape nodes omitted for clarity): + + ┌─────────────┐ ┌─────┐ ┌────────────┐ ┌─────────────┐ ┌─────┐ ┌────────────┐ + │linear_weight│ │ x │ │linear_bias │ │linear_weight│ │ x │ │linear_bias │ + └──────┬──────┘ └──┬──┘ └──────┬─────┘ └──────┬──────┘ └──┬──┘ └──────┬─────┘ + │ │ │ │ │ │ + │ FQ┌─────┴────┐FQ │ │ FQ┌─────┴────┐FQ │ + └──────────┤ linear ├─────────┘ └──────────┤ linear ├─ ─ ─ ─ ─┤ + └─────┬────┘ └─────┬────┘ ZEROS │ + │ ───────► │ │ + ┌───────────┐ │ ┌───────────┐ ┌───────────┐ ┌──┴──┐ │ ┌───────────┐ + │ bn_weight │ │ │bn_run_var │ │ bn_weight │ │ add ├───────────┘ │bn_run_var │ + └─────┬─────┘ │ └─────┬─────┘ └─────┬─────┘ └──┬──┘ └─────┬─────┘ + │ │ │ │ │ │ + │ ┌─────┴────┐ │ │ ┌─────┴────┐ │ + └──────────┤batch_norm├─────────┘ └────────┤batch_norm├───────────────────┘ + └─────┬────┘ └─────┬────┘ + │ │ + ▼ ▼ + + [1] https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/quantize_pt2e.py + [2] https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/qat_utils.py + """ + + def call(self, graph_module: GraphModule) -> PassResult | None: + """ + Given a graph of decomposed aten ops, adds linear weights normalization and linear output denormalization based on batch norm stats. + Optionally, applies later Linear bias application if Linear has bias=True selected. + + The normalization follows this equation: + linear_w_fused = linear_w * (gamma / sqrt(var + eps)) + + while `gamma` being the batch norm weights. + + The denormalization is done by dividing the linear layer output by the scale factor: + y_denorm = y / (gamma / sqrt(var + eps)) + + Normalization and denormalization operators should be removed by the + RemoveSimulatedLinearBatchNormFusionQATPass after the QAT training is complete. 
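+
+        A typical QAT flow around this pass might look like the following sketch
+        (names as used in the tests; the exact wiring may differ per pipeline):
+
+            m = prepare_qat_pt2e(exported_model.module(), quantizer)
+            m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
+            # ... QAT training / calibration ...
+            m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
+            m = convert_pt2e(m)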
+ """ + modified = False + + named_modules = dict(graph_module.named_modules(remove_duplicate=False)) + + for node in graph_module.graph.nodes: + if not _is_batch_norm(node): + continue + + bn_node = node + bn_in = bn_node.args[0] + + if not _is_linear(bn_in): + continue + + linear_node = bn_in + + linear_w_node_or_fq = linear_node.args[1] + linear_b_node_or_fq = ( + linear_node.args[2] if len(linear_node.args) >= 3 else None + ) + linear_w_node = _unwrap_if_fq( + linear_w_node_or_fq, named_modules=named_modules + ) + linear_b_node = _unwrap_if_fq( + linear_b_node_or_fq, named_modules=named_modules + ) + linear_w_is_quantized = linear_w_node_or_fq != linear_w_node + + bn_w_node = bn_node.args[1] + bn_var_node = bn_node.args[4] + + # BatchNorm(affine=False) + if bn_w_node is None: + continue + + # BatchNorm should not have quantized inputs + if _is_activation_post_process_node( + bn_w_node, named_modules=named_modules + ) or _is_activation_post_process_node( + bn_var_node, named_modules=named_modules + ): + continue + + bn_w_shape = _get_tensor_shape(bn_w_node) + bn_var_shape = _get_tensor_shape(bn_var_node) + linear_w_shape = _get_tensor_shape(linear_w_node) + + if None in (bn_w_shape, bn_var_shape, linear_w_shape): + continue + + scale_factor_example_inputs = ( + torch.randn(bn_w_shape), + torch.randn(bn_var_shape), + ) + norm_linear_w_example_inputs = ( + torch.randn(linear_w_shape), + torch.randn(bn_var_shape), + ) + + # Replacement patterns generation + bn_eps = _get_bn_eps(node) + bn_scale_factor_fn = _get_compute_scale_factor_pattern(bn_eps=bn_eps) + norm_linear_w_fn = _get_linear_weight_preprocess_pattern() + is_cuda = False + bn_scale_factor_fn = _get_aten_graph_module_for_pattern( + pattern=bn_scale_factor_fn, + example_inputs=scale_factor_example_inputs, + is_cuda=is_cuda, + ) + normalize_linear_w_fn = _get_aten_graph_module_for_pattern( + pattern=norm_linear_w_fn, + example_inputs=norm_linear_w_example_inputs, + is_cuda=is_cuda, + ) + bn_w_replacement_node, bn_var_replacement_node = _get_input_nodes( + bn_scale_factor_fn + ) + linear_w_replacement_node, scale_factor_out_replacement_node = ( + _get_input_nodes(normalize_linear_w_fn) + ) + + insertion_point = ( + linear_w_node_or_fq if linear_w_is_quantized else linear_node + ) + + # BN var and W node definition needs to be moved before the linear layer to avoid + # using them before definition. 
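+            # The pattern modules built above are stitched in below via graph_copy();
+            # each val_map maps the pattern's placeholder nodes onto the real BN
+            # weight/variance and linear weight nodes (and onto the scale factor
+            # produced by the first copied pattern), so the copied subgraphs read
+            # from them directly.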
+ bn_var_node = _reinsert_node_before( + graph_module, node=bn_var_node, insertion_point=insertion_point + ) + bn_w_node = _reinsert_node_before( + graph_module, node=bn_w_node, insertion_point=insertion_point + ) + + with graph_module.graph.inserting_before(insertion_point): + mapping = { + bn_var_replacement_node: bn_var_node, + bn_w_replacement_node: bn_w_node, + } + scale_factor_out_nodes = graph_module.graph.graph_copy( + bn_scale_factor_fn.graph, val_map=mapping + ) + + mapping = { + linear_w_replacement_node: linear_w_node, + scale_factor_out_replacement_node: scale_factor_out_nodes[0], + } + output_node = graph_module.graph.graph_copy( + normalize_linear_w_fn.graph, val_map=mapping + ) + + if linear_w_is_quantized: + linear_w_node_or_fq.update_arg(0, output_node[0]) + else: + linear_node.update_arg(1, output_node[0]) + + div_node = _insert_linear_output_denorm( + graph_module, + linear_node=linear_node, + scale_factor=scale_factor_out_nodes[0], + bn_node=bn_node, + ) + + if linear_b_node is not None: + _insert_later_bias_application( + graph_module, + denorm_output_node=div_node, + linear_node=linear_node, + linear_b_node=linear_b_node, + ) + + modified = True + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) From 16be8263499fc9dc1750cc3515f5d3a0903f67a4 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Thu, 11 Dec 2025 11:08:05 +0100 Subject: [PATCH 02/10] NXP backend: Add pass for removing simulated Linear+BatchNorm fusion --- ...ove_simulated_linear_bn_fusion_qat_pass.py | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py diff --git a/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py new file mode 100644 index 00000000000..1f7e5409e1a --- /dev/null +++ b/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py @@ -0,0 +1,203 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
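+
+# Removes the graph artifacts created by AddSimulatedLinearBatchNormFusionQATPass:
+# the scale-factor and weight-preprocess subgraphs are located with SubgraphMatcher
+# using the same pattern modules that were used to insert them, and the denorm div
+# and late-bias add nodes are unwound afterwards.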
+ +from functools import partial + +import torch +from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import ( + _get_compute_scale_factor_pattern, + _get_linear_weight_preprocess_pattern, + _is_linear, +) +from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( + _unwrap_if_fq, +) +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher +from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern + + +def _is_op_node(node: Node, op) -> bool: + return ( + node is not None + and hasattr(node, "op") + and node.op == "call_function" + and hasattr(node, "target") + and node.target == op + ) + + +_is_div = partial(_is_op_node, op=torch.ops.aten.div.Tensor) +_is_add = partial(_is_op_node, op=torch.ops.aten.add.Tensor) +_is_zeros_like = partial(_is_op_node, op=torch.ops.aten.zeros_like) +_is_reshape = partial(_is_op_node, op=torch.ops.aten.reshape) + + +def _is_denorm_pattern(node: Node) -> bool: + if not _is_div(node): + return False + + if not hasattr(node, "users"): + return False + + div_user_ops = [ + user.target for user in node.users.keys() if hasattr(user, "target") + ] + if len(list(div_user_ops)) < 1: + return False + + if torch.ops.aten.batch_norm.default in div_user_ops: + return True + + return False + + +def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule): + matcher = SubgraphMatcher( + pattern.graph, + match_output=False, + match_placeholder=False, + remove_overlapping_matches=True, + ignore_literals=True, + ) + matches: list[InternalMatch] = matcher.match(graph_module.graph, node_name_match="") + + for match in matches: + last_pattern_node = match.anchors[0] + last_matched_subgraph_node = match.nodes_map[last_pattern_node] + weight = match.placeholder_nodes[0] + + last_matched_subgraph_node.replace_all_uses_with(weight) + + for node in match.nodes_map.values(): + if node not in match.placeholder_nodes: + graph_module.graph.erase_node(node) + + +def _remove_late_bias_pattern(graph_module: GraphModule, bias_node: Node): + linear_b_users = list(bias_node.users.keys()) + + if len(linear_b_users) != 2: + return + + if _is_zeros_like(linear_b_users[0]): + zeros_node, maybe_reshape_node = linear_b_users + elif _is_zeros_like(linear_b_users[1]): + maybe_reshape_node, zeros_node = linear_b_users + else: + return + + if _is_reshape(maybe_reshape_node): + reshape_node = maybe_reshape_node + reshape_users = list(reshape_node.users.keys()) + + if len(reshape_users) != 1: + return + + add_node = reshape_users[0] + else: + # Handles no reshape node when bias is scalar + reshape_node = None + add_node = maybe_reshape_node + + if not _is_add(add_node): + return + + # Remove zeroed linear bias + zeros_node.replace_all_uses_with(bias_node) + graph_module.graph.erase_node(zeros_node) + + # Remove late bias addition + add_node.replace_all_uses_with(add_node.args[0]) + graph_module.graph.erase_node(add_node) + + if reshape_node: + graph_module.graph.erase_node(reshape_node) + + +def _remove_denorm_and_late_bias(graph_module: GraphModule): + named_modules = dict(graph_module.named_modules(remove_duplicate=False)) + + for node in graph_module.graph.nodes: + if not _is_linear(node): + continue + + linear_node = node + + if len(linear_node.args) <= 2: + continue + + linear_bias_fq_or_zeros = _unwrap_if_fq( + linear_node.args[2], named_modules=named_modules + ) + 
has_late_bias = _is_zeros_like(linear_bias_fq_or_zeros) + + if has_late_bias: + _remove_late_bias_pattern( + graph_module, bias_node=linear_bias_fq_or_zeros.args[0] + ) + + for user_node in linear_node.users: + if _is_denorm_pattern(user_node): + users_ops = [user.target for user in user_node.users.keys()] + + if torch.ops.aten.batch_norm.default in users_ops: + user_node.replace_all_uses_with(node) + graph_module.graph.erase_node(user_node) + break + + +class RemoveSimulatedLinearBatchNormFusionQATPass(PassBase): + """ + In order for QAT to work correctly with fused linear + batch norm operators, + simulated linear + batch norm fusion should be added using AddSimulatedLinearBatchNormFusionQATPass. + + After the QAT training, before inserting QDQ nodes, nodes added by the simulated fusion should be removed. + This pass removes all artifacts created by AddSimulatedLinearBatchNormFusionQATPass and reverts + the graph back to the layout before the simulated fusion was applied. + See `add_simulated_linear_bn_fusion_qat_pass.py` for more details. + """ + + def call(self, graph_module: GraphModule) -> PassResult | None: + """ + Given a graph of decomposed aten ops, removes nodes corresponding to linear + batch norm fusion. + """ + is_cuda = False + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + _scale_compute_example_inputs = ( + torch.randn(1), + torch.randn(1), + ) + _preprocess_example_inputs = ( + torch.randn(1, 1), + torch.randn(1), + ) + + scale_pattern = _get_compute_scale_factor_pattern() + scale_match_pattern = _get_aten_graph_module_for_pattern( + pattern=scale_pattern, + example_inputs=_scale_compute_example_inputs, + is_cuda=is_cuda, + ) + + weight_preprocess_pattern = _get_linear_weight_preprocess_pattern() + weight_preprocess_pattern = _get_aten_graph_module_for_pattern( + pattern=weight_preprocess_pattern, + example_inputs=_preprocess_example_inputs, + is_cuda=is_cuda, + ) + + _remove_pattern_from_graph(graph_module, pattern=scale_match_pattern) + _remove_pattern_from_graph(graph_module, pattern=weight_preprocess_pattern) + _remove_denorm_and_late_bias(graph_module) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, True) From b378f74c02205a78f8324bb40e9d01e63584f52d Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 13 Jan 2026 19:24:16 +0100 Subject: [PATCH 03/10] NXP backend: Add support for fusing in quantized graph --- .../fuse_batch_norm_with_linear_pass.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py b/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py index b6ab4489bb8..d8c22905e10 100644 --- a/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py +++ b/backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py @@ -10,6 +10,21 @@ from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.nn.parameter import Parameter from torch.nn.utils import fuse_linear_bn_weights +from torchao.quantization.pt2e.prepare import _is_activation_post_process_node + + +def _unwrap_if_fq(node: Node, named_modules: dict): + target_node = node + + if _is_activation_post_process_node(node, named_modules): + if len(node.args) >= 1: + target_node = node.args[0] + else: + raise ValueError( + f"FakeQuantize node '{node}' should have at least one argument, but has {len(node.args)}." 
+ ) + + return target_node class FuseBatchNormWithLinearPass(PassBase): @@ -76,6 +91,8 @@ def _is_linear(node_: Node): graph_module, made_changes ) # No batch norm nodes in the model. + named_modules = dict(graph_module.named_modules(remove_duplicate=False)) + for node in graph_module.graph.nodes: if not _is_batch_norm(node): continue # Not BatchNorm. @@ -86,11 +103,18 @@ def _is_linear(node_: Node): continue # Something other than a Linear node comes before the BatchNorm. linear_node = bn_node.args[0] - linear_weight_node = linear_node.args[1] - linear_bias_node = ( + linear_weight_node_or_fq = linear_node.args[1] + linear_bias_node_or_fq = ( linear_node.args[2] if len(linear_node.args) > 2 else None ) + linear_weight_node = _unwrap_if_fq( + linear_weight_node_or_fq, named_modules=named_modules + ) + linear_bias_node = _unwrap_if_fq( + linear_bias_node_or_fq, named_modules=named_modules + ) + linear_w = self._get_tensor_constant_from_node( graph_module, linear_weight_node ) From 1b3839f5f1e22ca87bb831f70dbbdfb1ec162689 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 13 Jan 2026 19:25:16 +0100 Subject: [PATCH 04/10] NXP backend: Remove linear output quantization in QAT --- backends/nxp/quantizer/patterns.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 2412fd1ea53..3c6fc38bcf9 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -691,6 +691,16 @@ def get_anchors( output = [] activation.meta["quantization_annotation"].input_qspec_map = {} + # In order for QAT to be numerically correct, there should be no quantization between + # linear node and batch norm node. + if self.is_qat: + linear_users = linear_node.users + possibly_bn = ( + list(linear_users.keys())[0] if len(linear_users) == 1 else None + ) + if possibly_bn and _is_batch_norm(possibly_bn): + output = [] + return PartitionAnchors( inputs=[(linear_node, NodeArgsIdx(0))], weights=[(linear_node, NodeArgsIdx(1))], From f9bfb1386ac370c737d72ac4c9510cb446a5af91 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 13 Jan 2026 19:27:01 +0100 Subject: [PATCH 05/10] NXP backend: Add test for Linear+BatchNorm fusing --- .../ir/edge_passes/test_linear_bn_fusing.py | 197 ++++++++++++++++++ backends/nxp/tests/models.py | 23 ++ 2 files changed, 220 insertions(+) create mode 100644 backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py diff --git a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py new file mode 100644 index 00000000000..edd3cc1df49 --- /dev/null +++ b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py @@ -0,0 +1,197 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
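+
+# Tests for the simulated Linear+BatchNorm QAT fusion passes: the graph layout after
+# inserting the simulated fusion, the full QAT prepare/fuse/convert flow, and graph
+# equivalence after an add + remove round trip.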
+ +import executorch.backends.nxp.tests.models as models +import pytest +import torch +from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import ( + AddSimulatedLinearBatchNormFusionQATPass, +) +from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( + FuseBatchNormWithLinearPass, +) +from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import ( + RemoveSimulatedLinearBatchNormFusionQATPass, +) + +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.tests.executorch_pipeline import ( + get_random_calibration_inputs, + neutron_target_spec, + to_model_input_spec, +) +from torch.export import export +from torch.fx import Node +from torchao.quantization.pt2e.prepare import _is_activation_post_process_node +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e + + +@pytest.mark.parametrize("input_shape", [(2, 3)]) +@pytest.mark.parametrize("linear_bias", [True, False]) +def test_add_simulated_linear_bn_fusing(input_shape, linear_bias): + calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) + input_sample = calibration_inputs[0] + model = models.LinearBNModule( + in_features=input_shape[-1], + out_features=5, + linear_bias=linear_bias, + ) + model.train() + raw_output = model(input_sample[0]) + + exported_model = export(model, input_sample, strict=True) + prepared_model = prepare_qat_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True) + ) + prepared_model = AddSimulatedLinearBatchNormFusionQATPass()( + prepared_model + ).graph_module + + graph_nodes = list(prepared_model.graph.nodes) + named_modules = dict(prepared_model.named_modules(remove_duplicate=False)) + fake_quantize_output = prepared_model(input_sample[0]) + + expected_number_of_nodes = 23 if linear_bias else 18 + linear_node = next( + ( + n + for n in graph_nodes + if hasattr(n, "target") and n.target == torch.ops.aten.linear.default + ), + None, + ) + + assert len(graph_nodes) == expected_number_of_nodes + + # Assert Linear weight being quantized and "normalized" + assert linear_node is not None + assert all( + _is_activation_post_process_node(n, named_modules) for n in linear_node.args + ) + assert linear_node.args[1].args[0].target == torch.ops.aten.mul.Tensor + + # Assert BatchNorm input being "denormalized" + assert graph_nodes[-3].target == torch.ops.aten.batch_norm.default + if linear_bias: + assert graph_nodes[-3].args[0].target == torch.ops.aten.add.Tensor + add_arg_targets = ( + n.target for n in graph_nodes[-3].args[0].args if hasattr(n, "target") + ) + assert torch.ops.aten.div.Tensor in add_arg_targets + else: + assert graph_nodes[-3].args[0].target == torch.ops.aten.div.Tensor + + assert raw_output.shape == fake_quantize_output.shape + + +@pytest.mark.parametrize("input_shape", [(2, 3)]) +@pytest.mark.parametrize("linear_bias", [True, False]) +def test_full_linear_bn_fusing(input_shape, linear_bias): + # TODO: Add pass for quantizing bias node when Linear has bias=False + if not linear_bias: + pytest.skip( + "Linear with bias=False is not yet supported. " + "The graph currently produces Linear layer without quantized bias which is incorrect." 
+ ) + + calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) + input_sample = calibration_inputs[0] + model = models.LinearBNModule( + in_features=input_shape[-1], + out_features=5, + linear_bias=linear_bias, + ) + model.train() + raw_output = model(input_sample[0]) + + exported_model = export(model, input_sample, strict=True) + prepared_model = prepare_qat_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True) + ) + + prepared_model = AddSimulatedLinearBatchNormFusionQATPass()( + prepared_model + ).graph_module + prepared_model(input_sample[0]) + prepared_model = RemoveSimulatedLinearBatchNormFusionQATPass()( + prepared_model + ).graph_module + prepared_model = FuseBatchNormWithLinearPass()(prepared_model).graph_module + converted_model = convert_pt2e(prepared_model) + + quantized_output = converted_model(input_sample[0]) + graph_nodes = list(converted_model.graph.nodes) + linear_node = graph_nodes[-4] + + def _is_bn(node_: Node) -> bool: + return ( + hasattr(node_, "target") + and node_.target == torch.ops.aten.batch_norm.default + ) + + assert len(graph_nodes) == 11 + + assert not any(_is_bn(node) for node in graph_nodes) + + # Assert linear inputs being quantized + assert linear_node.target == torch.ops.aten.linear.default + assert ( + linear_node.args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + assert ( + linear_node.args[1].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + + # Assert linear outputs being quantized + assert len(linear_node.users) == 1 + assert ( + list(linear_node.users.keys())[0].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + assert raw_output.shape == quantized_output.shape + + +@pytest.mark.parametrize("input_shape", [(2, 3)]) +@pytest.mark.parametrize("linear_bias", [True, False]) +@pytest.mark.parametrize("bn_eps", [1e-5, 1e-6]) +def test_input_output_graph_equivalence(input_shape, linear_bias, bn_eps): + # TODO: Add pass for quantizing bias node when Linear has bias=False + if not linear_bias: + pytest.skip( + "Linear with bias=False is not yet supported. " + "The graph currently produces Linear layer without quantized bias which is incorrect." 
+ ) + + calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) + input_sample = calibration_inputs[0] + model = models.LinearBNModule( + in_features=input_shape[-1], + out_features=5, + linear_bias=linear_bias, + bn_eps=bn_eps, + ) + model.eval() + + original_model = export(model, input_sample, strict=True).module() + + processed_model = export(model, input_sample, strict=True).module() + processed_model = AddSimulatedLinearBatchNormFusionQATPass()( + processed_model + ).graph_module + + assert list(processed_model.graph.nodes)[8].args[1] == bn_eps + + processed_model = RemoveSimulatedLinearBatchNormFusionQATPass()( + processed_model + ).graph_module + + assert list(processed_model.graph.nodes)[-2].args[7] == bn_eps + assert torch.equal( + original_model(input_sample[0]), processed_model(input_sample[0]) + ) + assert len(original_model.graph.nodes) == len(processed_model.graph.nodes) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index d8ca09a1bd9..e0eec619ce8 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -481,6 +481,29 @@ def forward(self, x): return self.bn(x) +class LinearBNModule(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + linear_bias: bool, + bn_eps: float = 1e-5, + act: nn.Module | None = None, + ): + super().__init__() + + self.linear = torch.nn.Linear( + in_features=in_features, out_features=out_features, bias=linear_bias + ) + self.bn = torch.nn.BatchNorm1d(out_features, eps=bn_eps) + self.act = act + + def forward(self, x): + x = self.linear(x) + x = self.bn(x) + return self.act(x) if self.act is not None else x + + class MulTensorModule(torch.nn.Module): def __init__(self): super().__init__() From 2115c52d04f97bf273c72a973b9b00a718036cc5 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 10 Feb 2026 15:24:16 +0100 Subject: [PATCH 06/10] NXP backend: Extract generally used functions into utils file --- ...add_simulated_linear_bn_fusion_qat_pass.py | 32 ++--------------- ...ove_simulated_linear_bn_fusion_qat_pass.py | 21 ++++-------- backends/nxp/backend/graph_utils.py | 34 +++++++++++++++++++ .../ir/edge_passes/test_linear_bn_fusing.py | 10 ++---- 4 files changed, 45 insertions(+), 52 deletions(-) create mode 100644 backends/nxp/backend/graph_utils.py diff --git a/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py index 08b83eb9f0b..53aa51796a4 100644 --- a/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py @@ -9,6 +9,7 @@ from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( _unwrap_if_fq, ) +from executorch.backends.nxp.backend.graph_utils import is_batch_norm, is_op_node from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult from torchao.quantization.pt2e.export_utils import WrapperModule @@ -16,13 +17,6 @@ from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern -_bn_ops = [ - torch.ops.aten.batch_norm.default, - torch.ops.aten.native_batch_norm.default, - torch.ops.aten._native_batch_norm_legit_no_training.default, -] - - def _get_linear_weight_preprocess_pattern() -> Callable: def _preprocess( weight: torch.Tensor, @@ -48,26 +42,6 @@ def _compute_scale( return WrapperModule(_compute_scale) -def _is_batch_norm(node: Node) -> bool: - return ( - node is not None - 
and hasattr(node, "op") - and node.op == "call_function" - and hasattr(node, "target") - and node.target in _bn_ops - ) - - -def _is_linear(node: Node) -> bool: - return ( - node is not None - and hasattr(node, "op") - and node.op == "call_function" - and hasattr(node, "target") - and node.target == torch.ops.aten.linear.default - ) - - def _get_input_nodes(graph_module: GraphModule) -> tuple[Node]: input_nodes = [] for node in graph_module.graph.nodes: @@ -276,13 +250,13 @@ def call(self, graph_module: GraphModule) -> PassResult | None: named_modules = dict(graph_module.named_modules(remove_duplicate=False)) for node in graph_module.graph.nodes: - if not _is_batch_norm(node): + if not is_batch_norm(node): continue bn_node = node bn_in = bn_node.args[0] - if not _is_linear(bn_in): + if not is_op_node(bn_in, torch.ops.aten.linear.default): continue linear_node = bn_in diff --git a/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py index 1f7e5409e1a..b72a25c6cfc 100644 --- a/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py @@ -9,31 +9,22 @@ from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import ( _get_compute_scale_factor_pattern, _get_linear_weight_preprocess_pattern, - _is_linear, ) from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( _unwrap_if_fq, ) +from executorch.backends.nxp.backend.graph_utils import is_op_node from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern -def _is_op_node(node: Node, op) -> bool: - return ( - node is not None - and hasattr(node, "op") - and node.op == "call_function" - and hasattr(node, "target") - and node.target == op - ) - - -_is_div = partial(_is_op_node, op=torch.ops.aten.div.Tensor) -_is_add = partial(_is_op_node, op=torch.ops.aten.add.Tensor) -_is_zeros_like = partial(_is_op_node, op=torch.ops.aten.zeros_like) -_is_reshape = partial(_is_op_node, op=torch.ops.aten.reshape) +_is_add = partial(is_op_node, target_op=torch.ops.aten.add.Tensor) +_is_div = partial(is_op_node, target_op=torch.ops.aten.div.Tensor) +_is_linear = partial(is_op_node, target_op=torch.ops.aten.linear.default) +_is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape) +_is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like) def _is_denorm_pattern(node: Node) -> bool: diff --git a/backends/nxp/backend/graph_utils.py b/backends/nxp/backend/graph_utils.py new file mode 100644 index 00000000000..7de92d573c0 --- /dev/null +++ b/backends/nxp/backend/graph_utils.py @@ -0,0 +1,34 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
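+
+# Shared torch.fx graph helpers (op checks such as is_op_node and is_batch_norm)
+# reused across the NXP aten passes.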
+ +import torch + +from torch.fx import Node + + +batch_norm_target_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten.native_batch_norm.default, + torch.ops.aten._native_batch_norm_legit_no_training.default, +] + + +def is_op_node(node: Node, target_op) -> bool: + if isinstance(target_op, list): + target_ops = target_op + else: + target_ops = [target_op] + + return ( + node is not None + and hasattr(node, "op") + and node.op == "call_function" + and hasattr(node, "target") + and node.target in target_ops + ) + + +def is_batch_norm(node: Node) -> bool: + return is_op_node(node, batch_norm_target_ops) diff --git a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py index edd3cc1df49..f6fd72b99f4 100644 --- a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py +++ b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py @@ -15,6 +15,7 @@ from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import ( RemoveSimulatedLinearBatchNormFusionQATPass, ) +from executorch.backends.nxp.backend.graph_utils import is_batch_norm from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer from executorch.backends.nxp.tests.executorch_pipeline import ( @@ -23,7 +24,6 @@ to_model_input_spec, ) from torch.export import export -from torch.fx import Node from torchao.quantization.pt2e.prepare import _is_activation_post_process_node from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e @@ -125,15 +125,9 @@ def test_full_linear_bn_fusing(input_shape, linear_bias): graph_nodes = list(converted_model.graph.nodes) linear_node = graph_nodes[-4] - def _is_bn(node_: Node) -> bool: - return ( - hasattr(node_, "target") - and node_.target == torch.ops.aten.batch_norm.default - ) - assert len(graph_nodes) == 11 - assert not any(_is_bn(node) for node in graph_nodes) + assert not any(is_batch_norm(node) for node in graph_nodes) # Assert linear inputs being quantized assert linear_node.target == torch.ops.aten.linear.default From 506ba8190fd7859cf5d0a8c933ed2d27fb556616 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 10 Feb 2026 15:45:20 +0100 Subject: [PATCH 07/10] NXP backend: Relocate Linear+BN related passes --- .../simulated_linear_bn_fusion_passes/__init__.py | 11 +++++++++++ .../add_simulated_linear_bn_fusion_qat_pass.py | 2 +- .../remove_simulated_linear_bn_fusion_qat_pass.py | 8 ++++---- .../nxp/tests/ir/edge_passes/test_linear_bn_fusing.py | 6 ++---- 4 files changed, 18 insertions(+), 9 deletions(-) create mode 100644 backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py rename backends/nxp/aten_passes/{ => simulated_linear_bn_fusion_passes}/add_simulated_linear_bn_fusion_qat_pass.py (99%) rename backends/nxp/aten_passes/{ => simulated_linear_bn_fusion_passes}/remove_simulated_linear_bn_fusion_qat_pass.py (98%) diff --git a/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py new file mode 100644 index 00000000000..b7c93bf6496 --- /dev/null +++ b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py @@ -0,0 +1,11 @@ +from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import ( + AddSimulatedLinearBatchNormFusionQATPass, +) +from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.remove_simulated_linear_bn_fusion_qat_pass import ( + 
RemoveSimulatedLinearBatchNormFusionQATPass, +) + +__all__ = [ + "AddSimulatedLinearBatchNormFusionQATPass", + "RemoveSimulatedLinearBatchNormFusionQATPass", +] diff --git a/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py similarity index 99% rename from backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py rename to backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py index 53aa51796a4..94e8553f2b3 100644 --- a/backends/nxp/aten_passes/add_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py @@ -237,7 +237,7 @@ def call(self, graph_module: GraphModule) -> PassResult | None: The normalization follows this equation: linear_w_fused = linear_w * (gamma / sqrt(var + eps)) - while `gamma` being the batch norm weights. + where `gamma` is the batch norm weight. The denormalization is done by dividing the linear layer output by the scale factor: y_denorm = y / (gamma / sqrt(var + eps)) diff --git a/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py similarity index 98% rename from backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py rename to backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py index b72a25c6cfc..5bdb1f72379 100644 --- a/backends/nxp/aten_passes/remove_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py @@ -6,13 +6,13 @@ from functools import partial import torch -from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import ( - _get_compute_scale_factor_pattern, - _get_linear_weight_preprocess_pattern, -) from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( _unwrap_if_fq, ) +from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import ( + _get_compute_scale_factor_pattern, + _get_linear_weight_preprocess_pattern, +) from executorch.backends.nxp.backend.graph_utils import is_op_node from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py index f6fd72b99f4..a5be4ef77dc 100644 --- a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py +++ b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py @@ -6,13 +6,11 @@ import executorch.backends.nxp.tests.models as models import pytest import torch -from executorch.backends.nxp.aten_passes.add_simulated_linear_bn_fusion_qat_pass import ( - AddSimulatedLinearBatchNormFusionQATPass, -) from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( FuseBatchNormWithLinearPass, ) -from executorch.backends.nxp.aten_passes.remove_simulated_linear_bn_fusion_qat_pass import ( +from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import ( + AddSimulatedLinearBatchNormFusionQATPass, RemoveSimulatedLinearBatchNormFusionQATPass, ) from executorch.backends.nxp.backend.graph_utils import is_batch_norm From 
4ee316d149759b57dc0daccbd44b1a16d1f0e130 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Tue, 10 Feb 2026 16:30:15 +0100 Subject: [PATCH 08/10] NXP backend: Add Linear+BN fusing to the quantization pipeline --- backends/nxp/quantizer/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index 459f31ec7da..bebf3689626 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -13,6 +13,11 @@ from typing import Any, Dict, List, Tuple, Type import torch + +from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import ( + AddSimulatedLinearBatchNormFusionQATPass, + RemoveSimulatedLinearBatchNormFusionQATPass, +) from torch import fx from torch._ops import OpOverload from torch.export import ExportedProgram @@ -184,12 +189,17 @@ def calibrate_and_quantize( if is_qat: m = prepare_qat_pt2e(model, quantizer) + m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module m = move_exported_model_to_eval(m) else: m = prepare_pt2e(model, quantizer) for data in calibration_inputs: m(*data) + + if is_qat: + m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module + m = convert_pt2e(m) return m From d0897654c98392a601550827a2b1b3862e50099f Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Wed, 11 Feb 2026 09:17:39 +0100 Subject: [PATCH 09/10] NXP backend: Unify BatchNorm op checks --- .../remove_simulated_linear_bn_fusion_qat_pass.py | 14 +++++--------- .../tests/ir/edge_passes/test_linear_bn_fusing.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py index 5bdb1f72379..71ea7083acf 100644 --- a/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py @@ -13,7 +13,7 @@ _get_compute_scale_factor_pattern, _get_linear_weight_preprocess_pattern, ) -from executorch.backends.nxp.backend.graph_utils import is_op_node +from executorch.backends.nxp.backend.graph_utils import is_batch_norm, is_op_node from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher @@ -34,13 +34,11 @@ def _is_denorm_pattern(node: Node) -> bool: if not hasattr(node, "users"): return False - div_user_ops = [ - user.target for user in node.users.keys() if hasattr(user, "target") - ] - if len(list(div_user_ops)) < 1: + div_users = node.users.keys() + if len(list(div_users)) < 1: return False - if torch.ops.aten.batch_norm.default in div_user_ops: + if any(is_batch_norm(user) for user in div_users): return True return False @@ -133,9 +131,7 @@ def _remove_denorm_and_late_bias(graph_module: GraphModule): for user_node in linear_node.users: if _is_denorm_pattern(user_node): - users_ops = [user.target for user in user_node.users.keys()] - - if torch.ops.aten.batch_norm.default in users_ops: + if any(is_batch_norm(user) for user in user_node.users.keys()): user_node.replace_all_uses_with(node) graph_module.graph.erase_node(user_node) break diff --git a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py index 
a5be4ef77dc..c93ad066832 100644 --- a/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py +++ b/backends/nxp/tests/ir/edge_passes/test_linear_bn_fusing.py @@ -71,7 +71,7 @@ def test_add_simulated_linear_bn_fusing(input_shape, linear_bias): assert linear_node.args[1].args[0].target == torch.ops.aten.mul.Tensor # Assert BatchNorm input being "denormalized" - assert graph_nodes[-3].target == torch.ops.aten.batch_norm.default + assert is_batch_norm(graph_nodes[-3]) if linear_bias: assert graph_nodes[-3].args[0].target == torch.ops.aten.add.Tensor add_arg_targets = ( From d70a29edc7e2da8bd402c985c00ad8e52761c862 Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Wed, 11 Feb 2026 10:09:23 +0100 Subject: [PATCH 10/10] NXP backend: Adjust MM converter test tolerance --- .../nxp/tests/ir/converter/node_converter/test_mm_converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py index 962a4f4b0c1..60dbfd1b215 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py @@ -60,6 +60,7 @@ def test_mm_conversion(self, _, use_qat: bool): exported_program, input_data, tfl_model=tflite_flatbuffers_model, + atol=1.0, ) @parameterized.expand([("QAT", True), ("PTQ", False)])