
Commit 309784c

Committed by ssjia
[executorch][arm] Add Tosa to LLM extension
This diff is to facilitate the landing of #15556. Differential Revision: [D92280949](https://our.internmc.facebook.com/intern/diff/D92280949/) [ghstack-poisoned]
1 parent 477867a commit 309784c

File tree

8 files changed: +121 -4 lines changed

backends/arm/quantizer/TARGETS

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,11 @@ runtime.python_library(
     deps = [
         ":arm_quantizer_utils",
         ":quantization_annotator",
+        "//executorch/backends/arm:constants",
+        "//executorch/backends/arm:ethosu",
+        "//executorch/backends/arm:vgf",
+        "//executorch/backends/arm/tosa:specification",
+        "//executorch/backends/arm:arm_compile_spec",
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//pytorch/ao:torchao",

examples/models/llama/export_llama_lib.py

Lines changed: 48 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -37,6 +37,7 @@
     get_mps_partitioner,
     get_openvino_partitioner,
     get_qnn_partitioner,
+    get_tosa_partitioner,
     get_vulkan_partitioner,
     get_xnnpack_partitioner,
 )
@@ -46,6 +47,7 @@
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
+    get_tosa_quantizer,
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -210,6 +212,7 @@ def build_args_parser() -> argparse.ArgumentParser:
             "coreml_baseline_8a_c8w",
             "coreml_baseline_8a_c4w",
             "vulkan_8w",
+            "tosa_8a8w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -788,6 +791,11 @@ def get_quantizer_and_quant_params(llm_config):
             llm_config.quantization.pt2e_quantize.value
         )
         quantizers.append(coreml_quantizer)
+    if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize:
+        tosa_quantizer = get_tosa_quantizer(
+            llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value
+        )
+        quantizers.append(tosa_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
@@ -930,6 +938,32 @@ def _to_edge_and_lower_llama_openvino(
     return builder.to_executorch(passes=additional_passes)


+def _to_edge_and_lower_llama_tosa(
+    builder_exported,
+    modelname,
+    quantizers,
+    additional_passes,
+    tosa_spec,
+    verbose: bool = False,
+) -> LLMEdgeManager:
+
+    logging.info("Lowering model using TOSA partitioner")
+
+    partitioners = []
+    partitioners.append(get_tosa_partitioner(tosa_spec))
+
+    modelname = f"tosa_{modelname}"
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+
+    if verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
 def _to_edge_and_lower_llama(  # noqa: C901
     builder_exported,
     modelname,
@@ -1119,7 +1153,10 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config).export()
+    builder_manager = _prepare_for_llama_export(llm_config)
+    if llm_config.backend.tosa.enabled:
+        builder_manager.skip_dim_order = False
+    builder_exported = builder_manager.export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname

@@ -1162,6 +1199,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             openvino_device=llm_config.backend.openvino.device,
             verbose=llm_config.debug.verbose,
         )
+    elif llm_config.backend.tosa.enabled:
+        builder = _to_edge_and_lower_llama_tosa(
+            builder_exported,
+            modelname,
+            quantizers,
+            additional_passes,
+            llm_config.backend.tosa.version,
+            verbose=llm_config.debug.verbose,
+        )
     else:
         builder = _to_edge_and_lower_llama(
             builder_exported,
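
Taken together, the changes above make TOSA selectable purely through LlmConfig. A minimal sketch of driving the new path (the backend.tosa.* and quantization.pt2e_quantize fields come from this diff; model and checkpoint settings are omitted, so this is illustrative rather than a complete export recipe):

    from executorch.examples.models.llama.export_llama_lib import _export_llama
    from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

    llm_config = LlmConfig()
    llm_config.backend.tosa.enabled = True  # routes lowering through _to_edge_and_lower_llama_tosa
    llm_config.backend.tosa.version = "TOSA-1.0+INT"  # forwarded as the TOSA spec string
    llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w  # selects get_tosa_quantizer

    # skip_dim_order is flipped to False automatically for the TOSA path, then the
    # model is quantized, partitioned, and lowered to an ExecuTorch program.
    builder = _export_llama(llm_config)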

examples/models/llama/tests/BUCK

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ fbcode_target(_kind = python_unittest,
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/arm/quantizer:lib",
         "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
         "//pytorch/ao:torchao",
@@ -97,6 +98,7 @@ fbcode_target(_kind = python_unittest,
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/arm/quantizer:lib",
         "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
         "//executorch/extension/pybindings:portable_lib",

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 19 additions & 1 deletion
@@ -1,17 +1,21 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 import unittest

+from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer
+
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.examples.models.llama.export_llama_lib import (
     _export_llama,
     build_args_parser,
+    get_quantizer_and_quant_params,
 )
-from executorch.extension.llm.export.config.llm_config import LlmConfig
+from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

 UNWANTED_OPS = [
     "aten_permute_copy_default",
@@ -48,3 +52,17 @@ def test_has_expected_ops_and_op_counts(self):

         for op, _op_info in delegation_info.delegation_by_operator.items():
             self.assertTrue(op not in UNWANTED_OPS)
+
+    def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.tosa.enabled = True
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w
+
+        pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertIsNone(pt2e_quant_params)
+        self.assertIsNone(quant_dtype)
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], TOSAQuantizer)
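
With the repository root on PYTHONPATH, the new test can be exercised with the standard unittest runner, e.g. python -m unittest examples.models.llama.tests.test_export_llama_lib (module path assumed from the file location; the BUCK targets above cover the fbcode runner).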

extension/llm/export/builder.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -96,6 +97,7 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
         save_exported_program: bool = False,
         generate_etrecord: bool = False,
+        skip_dim_order: bool = True,
     ):
         # Store necessary constructor arguments.
         self.model = model
@@ -118,6 +120,7 @@ def __init__(
         self.dynamic_shapes = dynamic_shapes
         self.save_exported_program = save_exported_program
         self.generate_etrecord = generate_etrecord
+        self.skip_dim_order = skip_dim_order

         # Note: treat this as the source of truth for the result of
         # torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -197,7 +200,7 @@ def _get_dynamic_shape(self) -> Any:
     def _get_edge_config(self) -> EdgeCompileConfig:
         edge_config = EdgeCompileConfig(
             _check_ir_validity=False,
-            _skip_dim_order=True,
+            _skip_dim_order=self.skip_dim_order,
         )
         return edge_config
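
The new flag only affects _get_edge_config. A small sketch of the two configurations it can now produce, using EdgeCompileConfig from executorch.exir as this file already does:

    from executorch.exir import EdgeCompileConfig

    # Default (skip_dim_order=True): dim-order handling is skipped, the prior behavior.
    default_cfg = EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True)

    # TOSA path (skip_dim_order=False, set by export_llama_lib when
    # llm_config.backend.tosa.enabled is True): dim-order ops are kept through to_edge.
    tosa_cfg = EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=False)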

extension/llm/export/config/llm_config.py

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -288,6 +289,7 @@ class Pt2eQuantize(str, Enum):
     coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
     coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
     vulkan_8w = "vulkan_8w"
+    tosa_8a8w = "tosa_8a8w"


 class SpinQuant(str, Enum):
@@ -474,6 +476,16 @@ class TorchAOKernelsConfig:
     use_torchao_kernels_tied_embedding: bool = False


+@dataclass
+class TosaConfig:
+    """
+    Configures the TOSA backend.
+    """
+
+    enabled: bool = False
+    version: str = "TOSA-1.0+INT"
+
+
 @dataclass
 class BackendConfig:
     """
@@ -488,6 +500,7 @@ class BackendConfig:
     mps: MPSConfig = field(default_factory=MPSConfig)
     openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig)
     torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
+    tosa: TosaConfig = field(default_factory=TosaConfig)


 ################################################################################
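
A short sketch of the defaults this introduces (field names and default values are exactly those in the diff):

    from executorch.extension.llm.export.config.llm_config import LlmConfig

    cfg = LlmConfig()
    assert cfg.backend.tosa.enabled is False           # TOSA lowering is opt-in
    assert cfg.backend.tosa.version == "TOSA-1.0+INT"  # default TOSA spec string
    cfg.backend.tosa.enabled = True                    # opt in to the new path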

extension/llm/export/partitioner_lib.py

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -236,3 +237,12 @@ def get_qnn_partitioner(
         # TODO: if deprecated legacy export, skip_mutable_buffer can be set False
         skip_mutable_buffer=True,
     )
+
+
+def get_tosa_partitioner(version: str):
+    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
+    from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
+
+    compile_spec = TosaCompileSpec(version)
+
+    return TOSAPartitioner(compile_spec)
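
Usage mirrors the other backend helpers in this file; a sketch, using the default version string from TosaConfig:

    from executorch.extension.llm.export.partitioner_lib import get_tosa_partitioner

    # The version string is parsed into a TosaCompileSpec, which configures the
    # TOSAPartitioner handed to to_edge_transform_and_lower.
    partitioner = get_tosa_partitioner("TOSA-1.0+INT")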

extension/llm/export/quantizer_lib.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -320,3 +321,22 @@ def get_vulkan_quantizer(pt2e_quantize: str):

     quantizer = VulkanQuantizer().set_global(config)
     return quantizer
+
+
+def get_tosa_quantizer(version: str, pt2e_quantize: str):
+    from executorch.backends.arm.quantizer.arm_quantizer import (
+        get_symmetric_quantization_config,
+        TOSAQuantizer,
+    )
+    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
+
+    compile_spec = TosaCompileSpec(version)
+
+    quantizer = TOSAQuantizer(compile_spec)
+
+    if pt2e_quantize == "tosa_8a8w":
+        quantizer.set_global(get_symmetric_quantization_config())
+    else:
+        raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")
+
+    return quantizer
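
A sketch of the helper's contract: only the "tosa_8a8w" spec is accepted, and anything else raises (the second spec string below is hypothetical, used only to show the error path):

    from executorch.extension.llm.export.quantizer_lib import get_tosa_quantizer

    quantizer = get_tosa_quantizer("TOSA-1.0+INT", "tosa_8a8w")  # symmetric 8-bit config

    try:
        get_tosa_quantizer("TOSA-1.0+INT", "tosa_16a8w")  # hypothetical, unsupported
    except ValueError as err:
        print(err)  # Unsupported quantizer specification tosa_16a8w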
