3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -139,4 +139,7 @@
 from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip
     ReplaceInfAndLimitValuesPass,
 )
+from .control_flow_const_inline import ( # noqa # usort: skip
+    ControlFlowConstInlinePass,
+)
 from .arm_pass_manager import ArmPassManager # noqa # usort: skip
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -18,6 +18,7 @@
     CastInt64BuffersToInt32Pass,
     CastToInt32Pass,
     ComputeConstantOpsAOTPass,
+    ControlFlowConstInlinePass,
     Conv1dUnsqueezePass,
     ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
@@ -121,6 +122,7 @@
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
+
 from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
 from executorch.backends.arm.common.pipeline_config import (
@@ -240,6 +242,7 @@ def _tosa_pipeline(
                 DecomposeVarPass(),
                 DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
                 ConvertELUParamsPass(),
+                ControlFlowConstInlinePass(),
                 NormalizeWhileInitialArgsPass(use_exir_clone=True),
             ]
         )
@@ -416,6 +419,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
             [
                 ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
                 ScalarsToAttributePass(tfa_pass=True),
+                ControlFlowConstInlinePass(tfa_pass=True),
             ]
         )
 
58 changes: 58 additions & 0 deletions backends/arm/_passes/control_flow_const_inline.py
@@ -0,0 +1,58 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.transforms.utils import is_get_attr_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.graph_module import get_cond_while_submodules
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+
+
+class ControlFlowConstInlinePass(ArmPass):
+    """
+    When each control-flow body is lifted out as its own GraphModule, any
+    scalar constants captured from Python become module attributes, which FX
+    represents as get_attr nodes in the submodule graph.
+
+    This pass converts such scalar tensors in control-flow submodules from
+    get_attr nodes into the expected call_function full ops.
+    """

+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+
+        for _, submodule, _ in get_cond_while_submodules(graph_module):
+            for submodule_node in submodule.graph.nodes:
+                if is_get_attr_node(submodule_node):
+                    # Fetch the lifted scalar constant from the owning module.
+                    val = getattr(
+                        submodule_node.graph.owning_module, submodule_node.target
+                    )
+                    with submodule.graph.inserting_before(submodule_node):
+                        # Materialize the scalar as a full op so downstream
+                        # passes see a regular constant-producing node.
+                        const_node = submodule.graph.create_node(
+                            op="call_function",
+                            target=exir_ops.edge.aten.full.default,
+                            args=(val.shape, val.item()),
+                            kwargs={
+                                "device": submodule_node.meta["val"].device,
+                                "dtype": submodule_node.meta["val"].dtype,
+                            },
+                        )
+                    const_node.meta = submodule_node.meta
+                    submodule_node.replace_all_uses_with(const_node)
+                    submodule.graph.erase_node(submodule_node)
+                    modified = True
+
+        if modified:
+            graph_module.recompile()
+            graph_module.graph.eliminate_dead_code()
+
+        return PassResult(graph_module, modified)
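
A minimal usage sketch (assumptions: `edge_gm` is an edge-dialect GraphModule that already contains lifted cond/while submodules; the names here are illustrative, not part of this PR):

    from executorch.backends.arm._passes import ControlFlowConstInlinePass

    result = ControlFlowConstInlinePass().call(edge_gm)  # rewrite get_attr scalars in submodules
    if result.modified:
        edge_gm = result.graph_module  # submodule scalars are now full ops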
104 changes: 62 additions & 42 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2026 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,6 +10,7 @@
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
+from executorch.exir.graph_module import get_cond_while_submodules
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule, Node
@@ -37,49 +37,53 @@ class ScalarsToAttributePass(ArmPass):
         torch.ops.aten.div_.Tensor,
     ]
 
-    def call(self, graph_module: GraphModule) -> PassResult:
-        for n in graph_module.graph.nodes:
-            n = cast(Node, n)
-            if n.op != "call_function" or n.target not in self.targeted_ops:
-                continue
-
-            biggest_rank = 1
-            for arg in n.args:
-                if isinstance(arg, Node):
-                    shape = get_first_fake_tensor(arg).shape
-                    biggest_rank = max(biggest_rank, len(shape))
-
-            output_fake_tensor = get_first_fake_tensor(n)
-            new_args: list[Node | int] = []
-            for arg in n.args:
-                if isinstance(arg, Node):
-                    new_args.append(arg)
-                    continue
-                if isinstance(arg, int) and not torch.is_floating_point(
-                    output_fake_tensor
-                ):
-                    new_args.append(arg)
-                    continue
-
-                prefix = "_tensor_constant_"
-                get_new_attr_name = get_new_attr_name_with_prefix(prefix)
-                tensor_constant_name = get_new_attr_name(graph_module)
-                float_tensor = torch.tensor(
-                    float(cast(Union[int, float], arg)),
-                    device=output_fake_tensor.device,
-                    dtype=output_fake_tensor.dtype,
-                ).reshape((1,) * biggest_rank)
-                graph_module.register_buffer(tensor_constant_name, float_tensor)
-                fake_mode = n.meta["val"].fake_mode
-
-                with graph_module.graph.inserting_before(n):
-                    get_attr_node = graph_module.graph.create_node(
-                        "get_attr", tensor_constant_name, (), {}
-                    )
-                    get_attr_node.meta["val"] = fake_mode.from_tensor(
-                        float_tensor, static_shapes=True
-                    )
-                    new_args.append(get_attr_node)
-            n.args = tuple(new_args)
+    def _convert_scalar_args(
+        self,
+        graph_module: GraphModule,
+        n: Node,
+    ) -> None:
+        """
+        Convert scalar literal args of targeted_ops in node n of graph_module
+        into get_attr nodes backed by registered buffers.
+        """
+        if n.op != "call_function" or n.target not in self.targeted_ops:
+            return
+
+        biggest_rank = 1
+        for arg in n.args:
+            if isinstance(arg, Node):
+                shape = get_first_fake_tensor(arg).shape
+                biggest_rank = max(biggest_rank, len(shape))
+
+        output_fake_tensor = get_first_fake_tensor(n)
+        new_args: list[Node | int] = []
+        for arg in n.args:
+            if isinstance(arg, Node):
+                new_args.append(arg)
+                continue
+            if isinstance(arg, int) and not torch.is_floating_point(output_fake_tensor):
+                new_args.append(arg)
+                continue
+
+            prefix = "_tensor_constant_"
+            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+            tensor_constant_name = get_new_attr_name(graph_module)
+            float_tensor = torch.tensor(
+                float(cast(Union[int, float], arg)),
+                device=output_fake_tensor.device,
+                dtype=output_fake_tensor.dtype,
+            ).reshape((1,) * biggest_rank)
+            graph_module.register_buffer(tensor_constant_name, float_tensor)
+            fake_mode = n.meta["val"].fake_mode
+
+            with graph_module.graph.inserting_before(n):
+                get_attr_node = graph_module.graph.create_node(
+                    "get_attr", tensor_constant_name, (), {}
+                )
+                get_attr_node.meta["val"] = fake_mode.from_tensor(
+                    float_tensor, static_shapes=True
+                )
+                new_args.append(get_attr_node)
+        n.args = tuple(new_args)
 
         # Replace rsub.Scalar with sub.Tensor as retracing will fail otherwise
@@ -93,6 +97,22 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 sub.meta["val"] = n.meta["val"]
                 graph_module.graph.erase_node(n)
 
+    def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None:
+        """
+        Apply scalar argument conversion on subgraphs of control-flow nodes.
+        """
+        for _, submodule, _ in get_cond_while_submodules(graph_module):
+            for submodule_node in submodule.graph.nodes:
+                # scalars become get_attr buffers here; ControlFlowConstInlinePass inlines them later
+                self._convert_scalar_args(submodule, submodule_node)
+        graph_module.recompile()
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        # convert scalars in control-flow subgraphs and the main graph
+        for node in list(graph_module.graph.nodes):
+            n = cast(Node, node)
+            self.handle_control_nodes(n, graph_module)
+            self._convert_scalar_args(graph_module, n)
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
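
As a sketch of the net effect (values and node names hypothetical, not taken from the test suite): for a targeted op such as `torch.ops.aten.add.Tensor` with a float literal,

    # before: the scalar rides along as a literal argument
    #   %add = call_function[target=aten.add.Tensor](%x, 2.0)
    #
    # after _convert_scalar_args: the literal becomes a rank-matched buffer
    #   _tensor_constant_0 = torch.tensor(2.0).reshape((1,) * biggest_rank)  # registered buffer
    #   %c   = get_attr[target=_tensor_constant_0]
    #   %add = call_function[target=aten.add.Tensor](%x, %c)
    #
    # handle_control_nodes applies the same rewrite inside every cond/while
    # submodule, so scalar args there no longer block annotation.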
6 changes: 2 additions & 4 deletions backends/arm/test/ops/test_cond.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -227,7 +227,6 @@ def _set_branch_calibration_samples(
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.",
"nested_one_arg_one_output": "Not fully delegated.",
},
)
@@ -251,7 +250,6 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
},
)
@@ -287,13 +285,13 @@ def test_cond_u55_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]):
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
},
skips={
"one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"one_arg_const_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"multiple_one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"one_arg_and_scalar_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
},
)
@common.XfailIfNoCorstone320.with_args(raises=None)
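
For context, the newly passing `one_arg_and_scalar_one_output` case presumably has this general shape (a hypothetical reconstruction, not the test's actual code):

    import torch

    class OneArgAndScalarOneOutput(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            scale = 3.0  # Python scalar captured by both branches

            def true_fn(x: torch.Tensor) -> torch.Tensor:
                return x * scale

            def false_fn(x: torch.Tensor) -> torch.Tensor:
                return x + scale

            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))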