Skip to content

Commit e232519

Browse files
authored
Make portable recipe export into one pte with multiple methods (#171)
1 parent aa7831c commit e232519

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

optimum/exporters/executorch/recipes/portable.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import logging
1615
from typing import Dict, Union
1716

1817
from torch.export import ExportedProgram
@@ -58,24 +57,22 @@ def _lower_to_executorch(
5857
exported_programs: Dict[str, ExportedProgram],
5958
metadata=None,
6059
) -> Dict[str, ExecutorchProgram]:
61-
et_progs = {}
60+
# If just one exported program, the method name in the .pte for it should be "forward".
61+
if len(exported_programs) == 1:
62+
exported_programs = {"forward": next(iter(exported_programs.values()))}
6263

63-
for pte_name, exported_program in exported_programs.items():
64-
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
65-
et_progs[pte_name] = to_edge_transform_and_lower(
66-
exported_program,
67-
partitioner=[],
68-
compile_config=EdgeCompileConfig(
69-
_check_ir_validity=False,
70-
_skip_dim_order=True,
71-
),
72-
constant_methods=metadata,
73-
transform_passes=[RemovePaddingIdxEmbeddingPass()],
74-
).to_executorch()
75-
logging.debug(
76-
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
77-
)
78-
return et_progs
64+
et_prog = to_edge_transform_and_lower(
65+
exported_programs,
66+
partitioner=[],
67+
compile_config=EdgeCompileConfig(
68+
_check_ir_validity=False,
69+
_skip_dim_order=True,
70+
),
71+
constant_methods=metadata,
72+
transform_passes=[RemovePaddingIdxEmbeddingPass()],
73+
).to_executorch()
74+
pte_name = "model"
75+
return {pte_name: et_prog}
7976

8077
exported_progs = model.export()
8178

0 commit comments

Comments
 (0)