-
Notifications
You must be signed in to change notification settings - Fork 834
NXP backend: Linear + BatchNorm QAT fusing #16623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
StrycekSimon
wants to merge
10
commits into
pytorch:main
Choose a base branch
from
nxp-upstream:feature/EIEX-641-create-a-pass-to-fuse-linear-batchnorm-after-qat-quantization
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
96b4c9e
NXP backend: Add pass for inserting simulated Linear+BatchNorm fusion
StrycekSimon 16be826
NXP backend: Add pass for removing simulated Linear+BatchNorm fusion
StrycekSimon b378f74
NXP backend: Add support for fusing in quantized graph
StrycekSimon 1b3839f
NXP backend: Remove linear output quantization in QAT
StrycekSimon f9bfb13
NXP backend: Add test for Linear+BatchNorm fusing
StrycekSimon 2115c52
NXP backend: Extract generally used functions into utils file
StrycekSimon 506ba81
NXP backend: Relocate Linear+BN related passes
StrycekSimon 4ee316d
NXP backend: Add Linear+BN fusing to the quantization pipeline
StrycekSimon d089765
NXP backend: Unify BatchNorm op checks
StrycekSimon d70a29e
NXP backend: Adjust MM converter test tolerance
StrycekSimon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import ( | ||
| AddSimulatedLinearBatchNormFusionQATPass, | ||
| ) | ||
| from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.remove_simulated_linear_bn_fusion_qat_pass import ( | ||
| RemoveSimulatedLinearBatchNormFusionQATPass, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "AddSimulatedLinearBatchNormFusionQATPass", | ||
| "RemoveSimulatedLinearBatchNormFusionQATPass", | ||
| ] |
384 changes: 384 additions & 0 deletions
384
.../aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py
Large diffs are not rendered by default.
Oops, something went wrong.
190 changes: 190 additions & 0 deletions
190
...en_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py
StrycekSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| # Copyright 2026 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from functools import partial | ||
|
|
||
| import torch | ||
| from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( | ||
| _unwrap_if_fq, | ||
| ) | ||
| from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import ( | ||
| _get_compute_scale_factor_pattern, | ||
| _get_linear_weight_preprocess_pattern, | ||
| ) | ||
| from executorch.backends.nxp.backend.graph_utils import is_batch_norm, is_op_node | ||
| from torch.fx import GraphModule, Node | ||
| from torch.fx.passes.infra.pass_base import PassBase, PassResult | ||
| from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher | ||
| from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern | ||
|
|
||
|
|
||
| _is_add = partial(is_op_node, target_op=torch.ops.aten.add.Tensor) | ||
| _is_div = partial(is_op_node, target_op=torch.ops.aten.div.Tensor) | ||
| _is_linear = partial(is_op_node, target_op=torch.ops.aten.linear.default) | ||
| _is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape) | ||
| _is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like) | ||
|
|
||
|
|
||
| def _is_denorm_pattern(node: Node) -> bool: | ||
| if not _is_div(node): | ||
| return False | ||
|
|
||
| if not hasattr(node, "users"): | ||
| return False | ||
|
|
||
| div_users = node.users.keys() | ||
| if len(list(div_users)) < 1: | ||
| return False | ||
|
|
||
| if any(is_batch_norm(user) for user in div_users): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule): | ||
| matcher = SubgraphMatcher( | ||
| pattern.graph, | ||
| match_output=False, | ||
| match_placeholder=False, | ||
| remove_overlapping_matches=True, | ||
| ignore_literals=True, | ||
| ) | ||
| matches: list[InternalMatch] = matcher.match(graph_module.graph, node_name_match="") | ||
|
|
||
| for match in matches: | ||
| last_pattern_node = match.anchors[0] | ||
| last_matched_subgraph_node = match.nodes_map[last_pattern_node] | ||
| weight = match.placeholder_nodes[0] | ||
|
|
||
| last_matched_subgraph_node.replace_all_uses_with(weight) | ||
|
|
||
| for node in match.nodes_map.values(): | ||
| if node not in match.placeholder_nodes: | ||
| graph_module.graph.erase_node(node) | ||
|
|
||
|
|
||
| def _remove_late_bias_pattern(graph_module: GraphModule, bias_node: Node): | ||
| linear_b_users = list(bias_node.users.keys()) | ||
|
|
||
| if len(linear_b_users) != 2: | ||
| return | ||
|
|
||
| if _is_zeros_like(linear_b_users[0]): | ||
| zeros_node, maybe_reshape_node = linear_b_users | ||
| elif _is_zeros_like(linear_b_users[1]): | ||
| maybe_reshape_node, zeros_node = linear_b_users | ||
| else: | ||
| return | ||
|
|
||
| if _is_reshape(maybe_reshape_node): | ||
| reshape_node = maybe_reshape_node | ||
| reshape_users = list(reshape_node.users.keys()) | ||
|
|
||
| if len(reshape_users) != 1: | ||
| return | ||
|
|
||
| add_node = reshape_users[0] | ||
| else: | ||
| # Handles no reshape node when bias is scalar | ||
| reshape_node = None | ||
| add_node = maybe_reshape_node | ||
|
|
||
| if not _is_add(add_node): | ||
| return | ||
|
|
||
| # Remove zeroed linear bias | ||
| zeros_node.replace_all_uses_with(bias_node) | ||
| graph_module.graph.erase_node(zeros_node) | ||
|
|
||
| # Remove late bias addition | ||
| add_node.replace_all_uses_with(add_node.args[0]) | ||
| graph_module.graph.erase_node(add_node) | ||
|
|
||
| if reshape_node: | ||
| graph_module.graph.erase_node(reshape_node) | ||
|
|
||
|
|
||
| def _remove_denorm_and_late_bias(graph_module: GraphModule): | ||
| named_modules = dict(graph_module.named_modules(remove_duplicate=False)) | ||
|
|
||
| for node in graph_module.graph.nodes: | ||
| if not _is_linear(node): | ||
| continue | ||
|
|
||
| linear_node = node | ||
|
|
||
| if len(linear_node.args) <= 2: | ||
| continue | ||
|
|
||
| linear_bias_fq_or_zeros = _unwrap_if_fq( | ||
| linear_node.args[2], named_modules=named_modules | ||
| ) | ||
| has_late_bias = _is_zeros_like(linear_bias_fq_or_zeros) | ||
|
|
||
| if has_late_bias: | ||
| _remove_late_bias_pattern( | ||
| graph_module, bias_node=linear_bias_fq_or_zeros.args[0] | ||
| ) | ||
|
|
||
| for user_node in linear_node.users: | ||
| if _is_denorm_pattern(user_node): | ||
| if any(is_batch_norm(user) for user in user_node.users.keys()): | ||
| user_node.replace_all_uses_with(node) | ||
| graph_module.graph.erase_node(user_node) | ||
| break | ||
|
|
||
|
|
||
| class RemoveSimulatedLinearBatchNormFusionQATPass(PassBase): | ||
| """ | ||
| In order for QAT to work correctly with fused linear + batch norm operators, | ||
| simulated linear + batch norm fusion should be added using AddSimulatedLinearBatchNormFusionQATPass. | ||
|
|
||
| After the QAT training, before inserting QDQ nodes, nodes added by the simulated fusion should be removed. | ||
| This pass removes all artifacts created by AddSimulatedLinearBatchNormFusionQATPass and reverts | ||
| the graph back to the layout before the simulated fusion was applied. | ||
| See `add_simulated_linear_bn_fusion_qat_pass.py` for more details. | ||
| """ | ||
|
|
||
| def call(self, graph_module: GraphModule) -> PassResult | None: | ||
| """ | ||
| Given a graph of decomposed aten ops, removes nodes corresponding to linear + batch norm fusion. | ||
| """ | ||
| is_cuda = False | ||
|
|
||
| graph_module.graph.eliminate_dead_code() | ||
| graph_module.recompile() | ||
|
|
||
| _scale_compute_example_inputs = ( | ||
| torch.randn(1), | ||
| torch.randn(1), | ||
| ) | ||
| _preprocess_example_inputs = ( | ||
| torch.randn(1, 1), | ||
| torch.randn(1), | ||
| ) | ||
|
|
||
| scale_pattern = _get_compute_scale_factor_pattern() | ||
| scale_match_pattern = _get_aten_graph_module_for_pattern( | ||
| pattern=scale_pattern, | ||
| example_inputs=_scale_compute_example_inputs, | ||
| is_cuda=is_cuda, | ||
| ) | ||
|
|
||
| weight_preprocess_pattern = _get_linear_weight_preprocess_pattern() | ||
| weight_preprocess_pattern = _get_aten_graph_module_for_pattern( | ||
| pattern=weight_preprocess_pattern, | ||
| example_inputs=_preprocess_example_inputs, | ||
| is_cuda=is_cuda, | ||
| ) | ||
|
|
||
| _remove_pattern_from_graph(graph_module, pattern=scale_match_pattern) | ||
| _remove_pattern_from_graph(graph_module, pattern=weight_preprocess_pattern) | ||
| _remove_denorm_and_late_bias(graph_module) | ||
|
|
||
| graph_module.graph.eliminate_dead_code() | ||
| graph_module.recompile() | ||
|
|
||
| return PassResult(graph_module, True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| # Copyright 2026 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
|
|
||
| from torch.fx import Node | ||
|
|
||
|
|
||
| batch_norm_target_ops = [ | ||
| torch.ops.aten.batch_norm.default, | ||
| torch.ops.aten.native_batch_norm.default, | ||
| torch.ops.aten._native_batch_norm_legit_no_training.default, | ||
| ] | ||
|
|
||
|
|
||
| def is_op_node(node: Node, target_op) -> bool: | ||
| if isinstance(target_op, list): | ||
| target_ops = target_op | ||
| else: | ||
| target_ops = [target_op] | ||
|
|
||
| return ( | ||
| node is not None | ||
| and hasattr(node, "op") | ||
| and node.op == "call_function" | ||
| and hasattr(node, "target") | ||
| and node.target in target_ops | ||
| ) | ||
|
|
||
|
|
||
| def is_batch_norm(node: Node) -> bool: | ||
| return is_op_node(node, batch_norm_target_ops) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.