Commit ed9f316

NXP backend: Enable aten.softmax.default delegation to Neutron.
1 parent 08282a4 commit ed9f316

3 files changed: +253, -79 lines
Lines changed: 103 additions & 18 deletions
@@ -1,12 +1,14 @@
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import numpy as np
+
 from executorch.backends.nxp.backend.custom_delegation_options import (
     CustomDelegationOptions,
 )
-from executorch.backends.nxp.backend.edge_helper import input_rank
+from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
 from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
     softmax_options,
@@ -17,46 +19,129 @@
 
 
 class SoftmaxConverter(NodeConverter):
+
+    @staticmethod
+    def _get_channels_dim(node: Node) -> int:
+        """Get the dimension index for channels, based on data format.
+        :return: 1 for the channels_first format (NCHW), rank-1 for the channels_last format (NHWC).
+        """
+        rank = len(node.meta["val"].shape)
+        return 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else rank - 1
+
+    @staticmethod
+    def _get_spatial_dims(node: Node) -> list[int]:
+        """Extract spatial dimensions from the node's input shape.
+        Returns a list with [N, H, W] (or equivalent for other ranks).
+        """
+        input_shape = list(node.meta["val"].shape)
+        if node.meta[NXP_NODE_FORMAT].is_channels_first():
+            # NCHW: skip the channel dimension at index 1
+            return [input_shape[0]] + input_shape[2:]
+        else:
+            # NHWC: skip the last dimension
+            return input_shape[:-1]
+
+    @staticmethod
+    def _get_total_spatial_size(node: Node) -> int:
+        """Calculate total spatial size (product of all spatial dimensions)."""
+        return int(np.prod(SoftmaxConverter._get_spatial_dims(node)))
+
+    @staticmethod
+    def _get_channels(node: Node) -> int:
+        """Get the number of channels from the node's input shape."""
+        return node.meta["val"].shape[SoftmaxConverter._get_channels_dim(node)]
+
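For intuition about the helpers above: given a concrete shape, the channels count and the pooled spatial size come out the same whether the tensor is NCHW or NHWC. A minimal standalone sketch, using plain shape tuples instead of the converter's node.meta plumbing (the channels_first flag here is a hypothetical stand-in for the NXP_NODE_FORMAT lookup):

    import numpy as np

    def channels_and_spatial(shape: tuple[int, ...], channels_first: bool) -> tuple[int, int]:
        # Mirrors _get_channels_dim / _get_spatial_dims / _get_total_spatial_size
        # for a plain shape tuple.
        channels_dim = 1 if channels_first else len(shape) - 1
        spatial = [shape[0], *shape[2:]] if channels_first else list(shape[:-1])
        return shape[channels_dim], int(np.prod(spatial))

    print(channels_and_spatial((1, 32, 56, 56), channels_first=True))   # (32, 3136), NCHW
    print(channels_and_spatial((1, 56, 56, 32), channels_first=False))  # (32, 3136), NHWC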
     @staticmethod
     def _is_supported_on_target(
         node: Node,
         neutron_target_spec: NeutronTargetSpec,
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        return False
+        """Check if the softmax operation can be executed on Neutron hardware.
+
+        Hardware constraints:
+        1. Input rank must be >= 2 (Neutron does not support 1D)
+        2. Channels must be a multiple of num_macs
+        3. Channels < 4096 / num_pipes * 4
+        4. Total spatial size (N*H*W) <= 4096
+        5. (channels * spatial_size) / num_macs <= 65536
+        """
+        input_shape = node.meta["val"].shape
+
+        # Constraint 1: Neutron does not support 1D SoftMax
+        if len(input_shape) == 1:
+            return False
+
+        num_macs = neutron_target_spec.get_num_macs()
+        num_pipes = neutron_target_spec.get_num_pipes()
+        channels = SoftmaxConverter._get_channels(node)
+        total_spatial_size = SoftmaxConverter._get_total_spatial_size(node)
+
+        # Constraint 2: Channels must be a multiple of num_macs
+        if channels % num_macs != 0:
+            return False
+
+        # Constraint 3: Channel size limit
+        if channels >= 4096 / num_pipes * 4:
+            return False
+
+        # Constraint 4: Spatial size limit
+        if total_spatial_size > 4096:
+            return False
+
+        # Constraint 5: Total processing size limit
+        if channels * total_spatial_size / num_macs > 65536:
+            return False
+
+        return True
+
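To make the five constraints concrete, here is a standalone re-statement with worked numbers. The values num_macs=8 and num_pipes=1 are illustrative assumptions; the real ones come from the NeutronTargetSpec of the concrete target:

    def softmax_fits_neutron(shape_nchw, num_macs=8, num_pipes=1):
        # Illustrative num_macs/num_pipes; real values come from NeutronTargetSpec.
        if len(shape_nchw) == 1:                   # 1. Neutron does not support 1D
            return False
        channels = shape_nchw[1]                   # channels dim of an NCHW shape
        spatial = 1
        for d in (shape_nchw[0], *shape_nchw[2:]): # N * H * W
            spatial *= d
        if channels % num_macs != 0:               # 2. channel alignment
            return False
        if channels >= 4096 / num_pipes * 4:       # 3. channel size limit
            return False
        if spatial > 4096:                         # 4. spatial size limit
            return False
        if channels * spatial / num_macs > 65536:  # 5. total processing limit
            return False
        return True

    print(softmax_fits_neutron((1, 32, 56, 56)))  # True: 32 % 8 == 0, 3136 <= 4096, 32*3136/8 = 12544 <= 65536
    print(softmax_fits_neutron((1, 30, 56, 56)))  # False: 30 is not a multiple of 8
    print(softmax_fits_neutron((1, 32, 80, 80)))  # False: spatial size 6400 > 4096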
+    @staticmethod
+    def _normalize_dim(dim: int, rank: int) -> int:
+        """Make sure the dimension index `dim` is positive.
+        :arg dim: The dimension index (can be negative)
+        :arg rank: The total number of dimensions
+
+        :return: Positive dimension index
+        """
+        return dim % rank
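The dim % rank form subsumes the old if dim < 0: dim += rank branch: for a positive rank, Python's modulo always yields a non-negative result, so negative and non-negative indices are handled by one expression:

    rank = 4
    assert -1 % rank == 3  # softmax(x, dim=-1): the last dimension
    assert -3 % rank == 1  # softmax(x, dim=-3): the channels dim of NCHW
    assert 2 % rank == 2   # non-negative indices pass through unchanged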
 
     @staticmethod
     def _is_supported_in_IR(
         node: Node,
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        # The IR only supports the `dim` as the last dimension. But that depends on the format of the input tensor,
-        # which is only known after the `Partitioner` has divided the model. So if the input shape can be channels
-        # first (i.e. is more than 2D), we cannot determine IR support (we assume it's not supported).
-        x_rank = input_rank(node, 0)
-        if x_rank > 2:
+        """Check if the softmax operation is supported in NeutronIR.
+        NeutronIR only supports softmax along the channels dimension.
+        """
+        dim = SoftmaxConverter._normalize_dim(node.args[1], len(node.meta["val"].shape))
+
+        # NeutronIR only supports the `dim` as the channels dimension
+        channels_dim = SoftmaxConverter._get_channels_dim(node)
+        if dim != channels_dim:
             return False
 
-        dim = SoftmaxConverter._normalize_dim(node.args[1], x_rank)
-        if dim != x_rank - 1:
+        half_to_float = node.args[2] if len(node.args) > 2 else False
+        if half_to_float:
+            # This argument states that the Softmax has a float16 input and output, but the computation is done in
+            # float32. An equivalent in NeutronIR would require explicit casting, which is currently not implemented.
             return False
 
         return True
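For a rank-4 input this check admits dim=1 (or dim=-3) when the tensor is channels-first, and dim=3 (or dim=-1) when it is channels-last. A sketch of the same normalization-plus-comparison over plain values (channels_first again stands in for the NXP_NODE_FORMAT lookup):

    def dim_matches_channels(dim: int, rank: int, channels_first: bool) -> bool:
        # Same logic as _normalize_dim followed by the channels-dim comparison.
        channels_dim = 1 if channels_first else rank - 1
        return dim % rank == channels_dim

    assert dim_matches_channels(-3, rank=4, channels_first=True)     # NCHW: -3 % 4 == 1
    assert dim_matches_channels(-1, rank=4, channels_first=False)    # NHWC: -1 % 4 == 3
    assert not dim_matches_channels(2, rank=4, channels_first=True)  # softmax over H: rejected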

-    @staticmethod
-    def _normalize_dim(dim, rank):
-        # convert negative index to positive
-        if dim < 0:
-            dim += rank
-        return dim
-
     def convert(self, node: Node):
+        """Convert `aten._softmax.default` node to NeutronIR.
+        The schema is:
+            aten::_softmax(
+                Tensor self,
+                int dim,
+                bool half_to_float
+            ) -> Tensor
+        """
         self.assert_convertible(node)
 
         t_op = self._create_tflite_op_with_io_tensors(node)
-
         t_op.builtin_options = softmax_options.Softmax(beta=1.0)
 
         self.builder.append_operators([t_op])
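With the delegation enabled, a softmax placed on the channels dimension becomes a candidate for the Neutron partition. A minimal sketch of a model that should satisfy both checks, assuming a target with num_macs=8 (the exact op name in the traced graph, aten.softmax vs. aten._softmax.default, depends on the decomposition stage of the lowering pipeline):

    import torch

    class ChannelSoftmax(torch.nn.Module):
        # Softmax over dim=1, the channels dimension of an NCHW tensor,
        # which is the only placement _is_supported_in_IR accepts.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.softmax(x, dim=1)

    # channels=32 is a multiple of the assumed num_macs=8, the spatial size
    # 1*56*56 = 3136 is <= 4096, and 32*3136/8 = 12544 is <= 65536.
    example_inputs = (torch.randn(1, 32, 56, 56),)
    exported = torch.export.export(ChannelSoftmax(), example_inputs)
    print(exported.graph)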
