Skip to content

Commit 51bd521

Browse files
committed
NXP backend: Add Linear+BN fusing to the quantization pipeline
1 parent: 155f003 · commit: 51bd521

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

backends/nxp/quantizer/utils.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,11 @@
1313
from typing import Any, Dict, List, Tuple, Type
1414

1515
import torch
16+
17+
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
18+
AddSimulatedLinearBatchNormFusionQATPass,
19+
RemoveSimulatedLinearBatchNormFusionQATPass,
20+
)
1621
from torch import fx
1722
from torch._ops import OpOverload
1823
from torch.export import ExportedProgram
@@ -184,12 +189,17 @@ def calibrate_and_quantize(
184189

185190
if is_qat:
186191
m = prepare_qat_pt2e(model, quantizer)
192+
m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
187193
m = move_exported_model_to_eval(m)
188194
else:
189195
m = prepare_pt2e(model, quantizer)
190196

191197
for data in calibration_inputs:
192198
m(*data)
199+
200+
if is_qat:
201+
m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
202+
193203
m = convert_pt2e(m)
194204

195205
return m

0 commit comments

Comments
 (0)