 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 from torch.nn.parameter import Parameter
 from torch.nn.utils import fuse_linear_bn_weights
+from torchao.quantization.pt2e.prepare import _is_activation_post_process_node
+
+
+def _unwrap_if_fq(node: Node, named_modules: dict):
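+    """Return the node feeding `node` if it is a FakeQuantize (activation
+    post-process) node; otherwise return `node` unchanged."""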
+    target_node = node
+
+    if _is_activation_post_process_node(node, named_modules):
+        if len(node.args) >= 1:
+            target_node = node.args[0]
+        else:
+            raise ValueError(
+                f"FakeQuantize node '{node}' should have at least one argument, but has {len(node.args)}."
+            )
+
+    return target_node
 
 
 class FuseBatchNormWithLinearPass(PassBase):
@@ -76,6 +93,9 @@ def _is_linear(node_: Node):
             graph_module, made_changes
         )  # No batch norm nodes in the model.
 
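+        # Module-name map used to recognize FakeQuantize (activation post-process) nodes.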
+        named_modules = dict(graph_module.named_modules(remove_duplicate=False))
+
         for node in graph_module.graph.nodes:
             if not _is_batch_norm(node):
                 continue  # Not BatchNorm.
@@ -86,11 +106,20 @@ def _is_linear(node_: Node):
                 continue  # Something other than a Linear node comes before the BatchNorm.
 
             linear_node = bn_node.args[0]
-            linear_weight_node = linear_node.args[1]
-            linear_bias_node = (
+            linear_weight_node_or_fq = linear_node.args[1]
+            linear_bias_node_or_fq = (
                 linear_node.args[2] if len(linear_node.args) > 2 else None
             )
 
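+            # The weight/bias inputs may be FakeQuantize nodes (e.g. in a
+            # prepared QAT graph); unwrap them to reach the parameter nodes.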
+            linear_weight_node = _unwrap_if_fq(
+                linear_weight_node_or_fq, named_modules=named_modules
+            )
+            linear_bias_node = _unwrap_if_fq(
+                linear_bias_node_or_fq, named_modules=named_modules
+            )
+
             linear_w = self._get_tensor_constant_from_node(
                 graph_module, linear_weight_node
             )