3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -139,4 +139,7 @@
from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip
ReplaceInfAndLimitValuesPass,
)
from .control_flow_const_inline import ( # noqa # usort: skip
ControlFlowConstInlinePass,
)
from .arm_pass_manager import ArmPassManager # noqa # usort: skip
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -18,6 +18,7 @@
CastInt64BuffersToInt32Pass,
CastToInt32Pass,
ComputeConstantOpsAOTPass,
ControlFlowConstInlinePass,
Conv1dUnsqueezePass,
ConvertELUParamsPass,
ConvertExpandCopyToRepeatPass,
@@ -121,6 +122,7 @@
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.pipeline_config import (
@@ -240,6 +242,7 @@ def _tosa_pipeline(
DecomposeVarPass(),
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
ConvertELUParamsPass(),
ControlFlowConstInlinePass(),
NormalizeWhileInitialArgsPass(use_exir_clone=True),
]
)
@@ -416,6 +419,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
[
ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
ScalarsToAttributePass(tfa_pass=True),
ControlFlowConstInlinePass(tfa_pass=True),
]
)

58 changes: 58 additions & 0 deletions backends/arm/_passes/control_flow_const_inline.py
@@ -0,0 +1,58 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.transforms.utils import is_get_attr_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_cond_while_submodules

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule


class ControlFlowConstInlinePass(ArmPass):
"""
    When each control flow body is lifted out as its own GraphModule, any
    scalar constants captured in Python become module attributes, which FX
    represents as get_attr nodes in the submodule graph.

    This pass converts such get_attr nodes for scalar tensors in control flow
    submodules into the expected call_function aten.full ops.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def call(self, graph_module: GraphModule) -> PassResult:
modified = False

for _, submodule, _ in get_cond_while_submodules(graph_module):
for submodule_node in submodule.graph.nodes:
                if is_get_attr_node(submodule_node):
                    # Look up the lifted scalar constant on the owning module.
                    val = getattr(
                        submodule_node.graph.owning_module, submodule_node.target
                    )
                    # Replace the get_attr with an equivalent full op.
                    with submodule.graph.inserting_before(submodule_node):
const_node = submodule.graph.create_node(
op="call_function",
target=exir_ops.edge.aten.full.default,
args=(val.shape, val.item()),
kwargs={
"device": submodule_node.meta["val"].device,
"dtype": submodule_node.meta["val"].dtype,
},
)
const_node.meta = submodule_node.meta
submodule_node.replace_all_uses_with(const_node)
submodule.graph.erase_node(submodule_node)
modified = True

if modified:
graph_module.recompile()
graph_module.graph.eliminate_dead_code()

return PassResult(graph_module, modified)
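For readers unfamiliar with the lifted-constant mechanics, here is a minimal standalone sketch of the same rewrite using plain torch.fx rather than the ExecuTorch pass infrastructure. The `Body` module and `_scale` attribute are illustrative only, and `torch.full` stands in for `exir_ops.edge.aten.full.default`:

```python
# Minimal sketch of the get_attr -> full rewrite, assuming plain torch.fx.
# Body/_scale are hypothetical; torch.full stands in for the edge full op.
import torch
from torch import fx


class Body(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A captured scalar lifted to a module attribute, as FX does for
        # control flow submodules.
        self.register_buffer("_scale", torch.tensor(2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self._scale


gm = fx.symbolic_trace(Body())
for node in list(gm.graph.nodes):
    if node.op != "get_attr":
        continue
    val = getattr(gm, node.target)
    if val.numel() != 1:
        continue  # only scalar tensors are inlined
    with gm.graph.inserting_before(node):
        const = gm.graph.call_function(
            torch.full,
            args=(tuple(val.shape), val.item()),
            kwargs={"dtype": val.dtype},
        )
    node.replace_all_uses_with(const)
    gm.graph.erase_node(node)
gm.recompile()
print(gm.graph)  # the get_attr is gone; a full node feeds the mul
```

The real pass additionally copies the node metadata onto the new full node and sets the device from `meta["val"]`, so downstream passes see the same fake-tensor information as before.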
104 changes: 62 additions & 42 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -1,5 +1,4 @@
# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,6 +10,7 @@
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.exir.graph_module import get_cond_while_submodules

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node
@@ -37,49 +37,53 @@ class ScalarsToAttributePass(ArmPass):
torch.ops.aten.div_.Tensor,
]

def call(self, graph_module: GraphModule) -> PassResult:
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in self.targeted_ops:
def _convert_scalar_args(
self,
graph_module: GraphModule,
n: Node,
) -> None:
"""
        Convert scalar literal args of targeted ops in node ``n`` of
        ``graph_module`` into get_attr nodes backed by registered buffers.
"""
if n.op != "call_function" or n.target not in self.targeted_ops:
return

biggest_rank = 1
for arg in n.args:
if isinstance(arg, Node):
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

output_fake_tensor = get_first_fake_tensor(n)
new_args: list[Node | int] = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue
if isinstance(arg, int) and not torch.is_floating_point(output_fake_tensor):
new_args.append(arg)
continue

biggest_rank = 1
for arg in n.args:
if isinstance(arg, Node):
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

output_fake_tensor = get_first_fake_tensor(n)
new_args: list[Node | int] = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue
if isinstance(arg, int) and not torch.is_floating_point(
output_fake_tensor
):
new_args.append(arg)
continue

prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg)),
device=output_fake_tensor.device,
dtype=output_fake_tensor.dtype,
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode

with graph_module.graph.inserting_before(n):
get_attr_node = graph_module.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
get_attr_node.meta["val"] = fake_mode.from_tensor(
float_tensor, static_shapes=True
)
new_args.append(get_attr_node)
prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg)),
device=output_fake_tensor.device,
dtype=output_fake_tensor.dtype,
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode

with graph_module.graph.inserting_before(n):
get_attr_node = graph_module.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
get_attr_node.meta["val"] = fake_mode.from_tensor(
float_tensor, static_shapes=True
)
new_args.append(get_attr_node)
n.args = tuple(new_args)

# Replace rsub.Scalar with sub.Tensor as retracing will fail otherwise
@@ -93,6 +97,22 @@ def call(self, graph_module: GraphModule) -> PassResult:
sub.meta["val"] = n.meta["val"]
graph_module.graph.erase_node(n)

    def handle_control_nodes(self, graph_module: GraphModule) -> None:
        """
        Apply scalar argument conversion to the subgraphs of control-flow
        (cond/while) nodes.
        """
        for _, submodule, _ in get_cond_while_submodules(graph_module):
            for submodule_node in submodule.graph.nodes:
                # Register scalar literals as buffers on the submodule itself.
                self._convert_scalar_args(submodule, submodule_node)
        graph_module.recompile()

    def call(self, graph_module: GraphModule) -> PassResult:
        # Convert scalars in control-flow subgraphs once, then in the main graph.
        self.handle_control_nodes(graph_module)
        for node in list(graph_module.graph.nodes):
            self._convert_scalar_args(graph_module, cast(Node, node))
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
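As a companion illustration, here is a minimal sketch, again using plain torch.fx rather than the Arm pass classes, of the scalar-to-buffer conversion that `_convert_scalar_args` performs. The `AddHalf` module and buffer name are hypothetical, and rank matching is simplified to rank 1:

```python
# Minimal sketch of scalar-literal -> registered-buffer conversion,
# assuming plain torch.fx. AddHalf and the buffer name are hypothetical.
import torch
from torch import fx


class AddHalf(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 0.5  # scalar literal argument


gm = fx.symbolic_trace(AddHalf())
count = 0
for node in list(gm.graph.nodes):
    if node.op != "call_function":
        continue
    new_args = []
    for arg in node.args:
        if isinstance(arg, (int, float)) and not isinstance(arg, bool):
            name = f"_tensor_constant_{count}"
            count += 1
            # The real pass reshapes to (1,) * biggest_rank so the buffer
            # broadcasts against the highest-rank operand; rank 1 suffices here.
            gm.register_buffer(name, torch.tensor(float(arg)).reshape(1))
            with gm.graph.inserting_before(node):
                arg = gm.graph.get_attr(name)
        new_args.append(arg)
    node.args = tuple(new_args)
gm.recompile()
print(gm.code)  # the scalar is now read from a registered buffer
```

In the actual pass, the buffer also gets a fake-tensor `meta["val"]` produced from the node's fake mode, which keeps shape propagation consistent for later passes.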
6 changes: 2 additions & 4 deletions backends/arm/test/ops/test_cond.py
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -227,7 +226,6 @@ def _set_branch_calibration_samples(
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.",
"nested_one_arg_one_output": "Not fully delegated.",
},
)
@@ -251,7 +250,6 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
},
)
@@ -287,13 +285,13 @@ def test_cond_u55_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]):
"case",
test_cases,
xfails={
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
},
skips={
"one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"one_arg_const_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"multiple_one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
"one_arg_and_scalar_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.",
},
)
@common.XfailIfNoCorstone320.with_args(raises=None)
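For context on the xfail removals above, here is a hypothetical reconstruction of the kind of module the `one_arg_and_scalar_one_output` case exercises, assuming `torch.cond` semantics; a Python scalar captured by both branch functions is exactly what the new pass inlines:

```python
import torch


class OneArgAndScalar(torch.nn.Module):
    """Hypothetical reconstruction of a cond case with a captured scalar."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = 2.0  # Python scalar captured by both branches

        def true_fn(x: torch.Tensor) -> torch.Tensor:
            return x * scale

        def false_fn(x: torch.Tensor) -> torch.Tensor:
            return x + scale

        # During export, the captured scalar is lifted into each branch
        # submodule, producing the get_attr nodes the pass inlines.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
```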