|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 | # All rights reserved. |
3 | | -# Copyright 2025 Arm Limited and/or its affiliates. |
| 3 | +# Copyright 2025-2026 Arm Limited and/or its affiliates. |
4 | 4 | # |
5 | 5 | # This source code is licensed under the BSD-style license found in the |
6 | 6 | # LICENSE file in the root directory of this source tree. |
|
37 | 37 | get_mps_partitioner, |
38 | 38 | get_openvino_partitioner, |
39 | 39 | get_qnn_partitioner, |
| 40 | + get_tosa_partitioner, |
40 | 41 | get_vulkan_partitioner, |
41 | 42 | get_xnnpack_partitioner, |
42 | 43 | ) |
|
46 | 47 | get_pt2e_quantization_params, |
47 | 48 | get_pt2e_quantizers, |
48 | 49 | get_qnn_quantizer, |
| 50 | + get_tosa_quantizer, |
49 | 51 | get_vulkan_quantizer, |
50 | 52 | ) |
51 | 53 | from executorch.util.activation_memory_profiler import generate_memory_trace |
@@ -210,6 +212,7 @@ def build_args_parser() -> argparse.ArgumentParser: |
210 | 212 | "coreml_baseline_8a_c8w", |
211 | 213 | "coreml_baseline_8a_c4w", |
212 | 214 | "vulkan_8w", |
| 215 | + "tosa_8a8w", |
213 | 216 | ], |
214 | 217 | help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", |
215 | 218 | ) |
@@ -788,6 +791,11 @@ def get_quantizer_and_quant_params(llm_config): |
788 | 791 | llm_config.quantization.pt2e_quantize.value |
789 | 792 | ) |
790 | 793 | quantizers.append(coreml_quantizer) |
| 794 | + if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize: |
| 795 | + tosa_quantizer = get_tosa_quantizer( |
| 796 | + llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value |
| 797 | + ) |
| 798 | + quantizers.append(tosa_quantizer) |
791 | 799 | if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: |
792 | 800 | assert ( |
793 | 801 | len(quantizers) == 0 |
@@ -930,6 +938,32 @@ def _to_edge_and_lower_llama_openvino( |
930 | 938 | return builder.to_executorch(passes=additional_passes) |
931 | 939 |
|
932 | 940 |
|
| 941 | +def _to_edge_and_lower_llama_tosa( |
| 942 | + builder_exported, |
| 943 | + modelname, |
| 944 | + quantizers, |
| 945 | + additional_passes, |
| 946 | + tosa_spec, |
| 947 | + verbose: bool = False, |
| 948 | +) -> LLMEdgeManager: |
| 949 | + |
| 950 | + logging.info("Lowering model using TOSA partitioner") |
| 951 | + |
| 952 | + partitioners = [] |
| 953 | + partitioners.append(get_tosa_partitioner(tosa_spec)) |
| 954 | + |
| 955 | + modelname = f"tosa_{modelname}" |
| 956 | + |
| 957 | + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( |
| 958 | + partitioners |
| 959 | + ) |
| 960 | + |
| 961 | + if verbose: |
| 962 | + print_delegation_info(builder.edge_manager.exported_program().graph_module) |
| 963 | + |
| 964 | + return builder.to_executorch(passes=additional_passes) |
| 965 | + |
| 966 | + |
933 | 967 | def _to_edge_and_lower_llama( # noqa: C901 |
934 | 968 | builder_exported, |
935 | 969 | modelname, |
@@ -1119,7 +1153,10 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 |
1119 | 1153 | additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] |
1120 | 1154 |
|
1121 | 1155 | # export_to_edge |
1122 | | - builder_exported = _prepare_for_llama_export(llm_config).export() |
| 1156 | + builder_manager = _prepare_for_llama_export(llm_config) |
| 1157 | + if llm_config.backend.tosa.enabled: |
| 1158 | + builder_manager.skip_dim_order = False |
| 1159 | + builder_exported = builder_manager.export() |
1123 | 1160 | builder_exported.run_canonical_optimizations() |
1124 | 1161 | modelname = builder_exported.modelname |
1125 | 1162 |
|
@@ -1162,6 +1199,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 |
1162 | 1199 | openvino_device=llm_config.backend.openvino.device, |
1163 | 1200 | verbose=llm_config.debug.verbose, |
1164 | 1201 | ) |
| 1202 | + elif llm_config.backend.tosa.enabled: |
| 1203 | + builder = _to_edge_and_lower_llama_tosa( |
| 1204 | + builder_exported, |
| 1205 | + modelname, |
| 1206 | + quantizers, |
| 1207 | + additional_passes, |
| 1208 | + llm_config.backend.tosa.version, |
| 1209 | + verbose=llm_config.debug.verbose, |
| 1210 | + ) |
1165 | 1211 | else: |
1166 | 1212 | builder = _to_edge_and_lower_llama( |
1167 | 1213 | builder_exported, |
|
0 commit comments