@@ -37,6 +37,8 @@ class OpFeatures:
3737 # bool indicating if the operator has a resize function, which allows it to
3838 # support models with dynamic shape
3939 "supports_resize" ,
40+ # bool indicating if the operator supports tensors with more than 4 dimensions
41+ "supports_highdim" ,
4042 # bool indicating if the operator handles its own prepacking. If this is True,
4143 # then the insert_prepack_nodes pass will not insert prepack nodes for the args
4244 # of the op.
@@ -60,6 +62,7 @@ def __init__(
6062 Union [utils .TensorRepSet , List [utils .TensorRepSet ]]
6163 ] = None ,
6264 supports_resize : bool = False ,
65+ supports_highdim : bool = False ,
6366 supports_prepacking : bool = False ,
6467 are_node_inputs_supported_fn : Optional [Callable ] = allow_node ,
6568 pick_io_storage_fn : Optional [Callable ] = None ,
@@ -85,6 +88,7 @@ def __init__(
8588 self .outputs_storage = utils .TensorRepSetList (self .inputs_storage [0 ])
8689
8790 self .supports_resize = supports_resize
91+ self .supports_highdim = supports_highdim
8892 self .supports_prepacking = supports_prepacking
8993
9094 self .are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
239243 inputs_storage = utils .ANY_STORAGE ,
240244 inputs_dtypes = utils .FP_INT_T ,
241245 supports_resize = True ,
246+ supports_highdim = True ,
242247 )
243248
244249
@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
253258 inputs_storage = utils .ANY_STORAGE ,
254259 inputs_dtypes = utils .FP_T ,
255260 supports_resize = True ,
261+ supports_highdim = True ,
256262 )
257263
258264
@@ -630,6 +636,7 @@ def register_reduce_cpp_ops():
630636 inputs_storage = utils .ANY_TEXTURE ,
631637 inputs_dtypes = utils .FP_T ,
632638 supports_resize = True ,
639+ supports_highdim = True ,
633640 are_node_inputs_supported_fn = is_reduce_node_supported ,
634641 pick_io_storage_fn = pick_storage_for_reduce ,
635642 )
@@ -651,6 +658,7 @@ def register_argreduce_cpp_ops():
651658 inputs_storage = utils .ANY_TEXTURE ,
652659 inputs_dtypes = utils .FP_T ,
653660 supports_resize = True ,
661+ supports_highdim = True ,
654662 are_node_inputs_supported_fn = is_reduce_node_supported ,
655663 pick_io_storage_fn = pick_storage_for_reduce ,
656664 )
@@ -811,6 +819,7 @@ def register_apply_rotary_emb():
811819 inputs_storage = utils .CONTIGUOUS_ANY ,
812820 inputs_dtypes = utils .FP_T ,
813821 supports_resize = True ,
822+ supports_highdim = True ,
814823 )
815824
816825
@@ -834,6 +843,7 @@ def register_permute_copy():
834843 inputs_storage = utils .ANY_STORAGE ,
835844 inputs_dtypes = utils .FP_INT_BOOL_T ,
836845 supports_resize = True ,
846+ supports_highdim = True ,
837847 )
838848
839849
@@ -848,6 +858,7 @@ def register_view_copy():
848858 inputs_storage = utils .ANY_STORAGE ,
849859 inputs_dtypes = utils .FP_INT_BOOL_T ,
850860 supports_resize = True ,
861+ supports_highdim = True ,
851862 )
852863
853864
@@ -857,6 +868,7 @@ def register_to_dim_order_copy():
857868 inputs_storage = utils .ANY_BUFFER ,
858869 inputs_dtypes = utils .FP_INT_BOOL_T ,
859870 supports_resize = True ,
871+ supports_highdim = True ,
860872 )
861873
862874
@@ -871,6 +883,7 @@ def register_squeeze_copy():
871883 inputs_storage = utils .ANY_STORAGE ,
872884 inputs_dtypes = utils .FP_INT_BOOL_T ,
873885 supports_resize = True ,
886+ supports_highdim = True ,
874887 )
875888
876889
@@ -885,6 +898,7 @@ def register_unsqueeze_copy():
885898 inputs_storage = utils .ANY_STORAGE ,
886899 inputs_dtypes = utils .FP_INT_BOOL_T ,
887900 supports_resize = True ,
901+ supports_highdim = True ,
888902 )
889903
890904
@@ -899,6 +913,7 @@ def register_clone():
899913 inputs_storage = utils .ANY_STORAGE ,
900914 inputs_dtypes = utils .FP_INT_BOOL_T ,
901915 supports_resize = True ,
916+ supports_highdim = True ,
902917 )
903918
904919
@@ -908,6 +923,7 @@ def register_clone_dim_order():
908923 inputs_storage = utils .ANY_STORAGE ,
909924 inputs_dtypes = utils .FP_INT_BOOL_T ,
910925 supports_resize = True ,
926+ supports_highdim = True ,
911927 )
912928
913929
@@ -922,6 +938,7 @@ def register_gather():
922938 inputs_storage = utils .ANY_STORAGE ,
923939 inputs_dtypes = utils .FP_INT_BOOL_T ,
924940 supports_resize = True ,
941+ supports_highdim = True ,
925942 )
926943
927944
@@ -936,6 +953,7 @@ def register_expand_copy():
936953 inputs_storage = utils .ANY_BUFFER ,
937954 inputs_dtypes = utils .FP_INT_BOOL_T ,
938955 supports_resize = False ,
956+ supports_highdim = True ,
939957 )
940958
941959
@@ -964,6 +982,7 @@ def register_select_copy():
964982 inputs_storage = utils .ANY_STORAGE ,
965983 inputs_dtypes = utils .FP_INT_BOOL_T ,
966984 supports_resize = True ,
985+ supports_highdim = True ,
967986 )
968987
969988
@@ -978,6 +997,7 @@ def register_slice_copy():
978997 inputs_storage = utils .ANY_STORAGE ,
979998 inputs_dtypes = utils .FP_INT_BOOL_T ,
980999 supports_resize = True ,
1000+ supports_highdim = True ,
9811001 )
9821002
9831003
@@ -992,6 +1012,7 @@ def register_split_with_sizes_copy():
9921012 inputs_storage = utils .ANY_STORAGE ,
9931013 inputs_dtypes = utils .FP_INT_BOOL_T ,
9941014 supports_resize = True ,
1015+ supports_highdim = True ,
9951016 )
9961017
9971018
0 commit comments