Skip to content

Commit c09db81

Browse files
authored
Add to_out_var pass variant that support cadence custom inplace ops.
Differential Revision: D93004811 Pull Request resolved: #17406
1 parent 3f16e5b commit c09db81

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

backends/cadence/aot/BUCK

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,20 @@ fbcode_target(_kind = runtime.python_library,
187187
],
188188
)
189189

190+
fbcode_target(_kind = runtime.python_library,
191+
name = "to_out_var_pass",
192+
srcs = [
193+
"to_out_var_pass.py",
194+
],
195+
typing = True,
196+
deps = [
197+
":ops_registrations",
198+
"//caffe2:torch",
199+
"//executorch/exir/dialects:lib",
200+
"//executorch/exir/passes:lib",
201+
],
202+
)
203+
190204
fbcode_target(_kind = runtime.python_library,
191205
name = "graph_builder",
192206
srcs = [
@@ -649,3 +663,20 @@ fbcode_target(_kind = python_unittest,
649663
"//pytorch/ao:torchao",
650664
],
651665
)
666+
667+
fbcode_target(_kind = python_unittest,
668+
name = "test_to_out_var_pass",
669+
srcs = [
670+
"tests/test_to_out_var_pass.py",
671+
],
672+
typing = True,
673+
deps = [
674+
":ops_registrations",
675+
":program_builder",
676+
":to_out_var_pass",
677+
"//caffe2:torch",
678+
"//executorch/exir:lib",
679+
"//executorch/exir/dialects:lib",
680+
"//later:lib",
681+
],
682+
)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import executorch.backends.cadence.aot.ops_registrations # noqa
10+
import torch
11+
from executorch.backends.cadence.aot.program_builder import ProgramBuilder
12+
from executorch.backends.cadence.aot.to_out_var_pass import CadenceToOutVarPass
13+
from executorch.exir import ExecutorchBackendConfig
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from later.unittest import TestCase
16+
17+
18+
class TestCadenceToOutVarPass(TestCase):
19+
def test_serialize_with_slice_scatter_inplace(self) -> None:
20+
"""Test that a graph with slice_scatter_ can be serialized after CadenceToOutVarPass."""
21+
builder = ProgramBuilder()
22+
# Create input tensor placeholder
23+
x = builder.placeholder("x", torch.randn(10, dtype=torch.float32))
24+
# Create source tensor placeholder
25+
src = builder.placeholder("src", torch.randn(3, dtype=torch.float32))
26+
27+
# Call slice_scatter_ inplace op
28+
result = builder.call_operator(
29+
op=exir_ops.edge.cadence.slice_scatter_.default,
30+
args=(x, src, 0, 2, 5, 1),
31+
)
32+
builder.output([result])
33+
34+
# Get the edge program
35+
edge_program = builder.get_edge_program()
36+
37+
# Apply CadenceToOutVarPass and serialize
38+
exec_program = edge_program.to_executorch(
39+
ExecutorchBackendConfig(
40+
to_out_var_pass=CadenceToOutVarPass(),
41+
)
42+
)
43+
44+
# Verify serialization succeeded by checking the buffer is non-empty
45+
buffer = exec_program.buffer
46+
self.assertIsNotNone(buffer)
47+
self.assertGreater(len(buffer), 0)
48+
49+
def test_serialize_with_mixed_ops(self) -> None:
50+
"""Test that a graph with mixed ops including slice_scatter_ can be serialized."""
51+
builder = ProgramBuilder()
52+
x = builder.placeholder("x", torch.randn(10, dtype=torch.float32))
53+
y = builder.placeholder("y", torch.randn(10, dtype=torch.float32))
54+
src = builder.placeholder("src", torch.randn(3, dtype=torch.float32))
55+
56+
# Add operation
57+
add_result = builder.call_operator(exir_ops.edge.aten.add.Tensor, (x, y))
58+
59+
# Slice scatter inplace operation
60+
scatter_result = builder.call_operator(
61+
op=exir_ops.edge.cadence.slice_scatter_.default,
62+
args=(add_result, src, 0, 2, 5, 1),
63+
)
64+
65+
builder.output([scatter_result])
66+
67+
edge_program = builder.get_edge_program()
68+
69+
exec_program = edge_program.to_executorch(
70+
ExecutorchBackendConfig(
71+
to_out_var_pass=CadenceToOutVarPass(),
72+
)
73+
)
74+
75+
buffer = exec_program.buffer
76+
self.assertIsNotNone(buffer)
77+
self.assertGreater(len(buffer), 0)
78+
79+
def test_serialize_with_add_tensor(self) -> None:
80+
"""Test that a simple add graph can be serialized with CadenceToOutVarPass."""
81+
builder = ProgramBuilder()
82+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
83+
y = builder.placeholder("y", torch.randn(3, 5, dtype=torch.float32))
84+
85+
add = builder.call_operator(exir_ops.edge.aten.add.Tensor, (x, y))
86+
builder.output([add])
87+
88+
edge_program = builder.get_edge_program()
89+
90+
exec_program = edge_program.to_executorch(
91+
ExecutorchBackendConfig(
92+
to_out_var_pass=CadenceToOutVarPass(),
93+
)
94+
)
95+
96+
buffer = exec_program.buffer
97+
self.assertIsNotNone(buffer)
98+
self.assertGreater(len(buffer), 0)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import executorch.backends.cadence.aot.ops_registrations # noqa
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.passes import ToOutVarPass
13+
from torch.fx.passes.infra.pass_base import PassResult
14+
15+
16+
class CadenceToOutVarPass(ToOutVarPass):
17+
"""Adds support for custom cadence inplace ops."""
18+
19+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
20+
for slice_scatter_inplace in graph_module.graph.find_nodes(
21+
op="call_function", target=exir_ops.edge.cadence.slice_scatter_.default
22+
):
23+
slice_scatter_inplace.target = slice_scatter_inplace.target._op
24+
return super().call(graph_module)

0 commit comments

Comments
 (0)