|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import logging |
16 | 15 | from typing import Dict, Union |
17 | 16 |
|
18 | 17 | from torch.export import ExportedProgram |
@@ -58,24 +57,22 @@ def _lower_to_executorch( |
58 | 57 | exported_programs: Dict[str, ExportedProgram], |
59 | 58 | metadata=None, |
60 | 59 | ) -> Dict[str, ExecutorchProgram]: |
61 | | - et_progs = {} |
| 60 | + # If just one exported program, the method name in the .pte for it should be "forward". |
| 61 | + if len(exported_programs) == 1: |
| 62 | + exported_programs = {"forward": next(iter(exported_programs.values()))} |
62 | 63 |
|
63 | | - for pte_name, exported_program in exported_programs.items(): |
64 | | - logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}") |
65 | | - et_progs[pte_name] = to_edge_transform_and_lower( |
66 | | - exported_program, |
67 | | - partitioner=[], |
68 | | - compile_config=EdgeCompileConfig( |
69 | | - _check_ir_validity=False, |
70 | | - _skip_dim_order=True, |
71 | | - ), |
72 | | - constant_methods=metadata, |
73 | | - transform_passes=[RemovePaddingIdxEmbeddingPass()], |
74 | | - ).to_executorch() |
75 | | - logging.debug( |
76 | | - f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}" |
77 | | - ) |
78 | | - return et_progs |
| 64 | + et_prog = to_edge_transform_and_lower( |
| 65 | + exported_programs, |
| 66 | + partitioner=[], |
| 67 | + compile_config=EdgeCompileConfig( |
| 68 | + _check_ir_validity=False, |
| 69 | + _skip_dim_order=True, |
| 70 | + ), |
| 71 | + constant_methods=metadata, |
| 72 | + transform_passes=[RemovePaddingIdxEmbeddingPass()], |
| 73 | + ).to_executorch() |
| 74 | + pte_name = "model" |
| 75 | + return {pte_name: et_prog} |
79 | 76 |
|
80 | 77 | exported_progs = model.export() |
81 | 78 |
|
|
0 commit comments