
Commit 785adf0

Authored and committed by SS-JIA (ssjia)
Back out "[Diff Train][pytorch/executorch] Apply fixup patch to fbsource"
Pull Request resolved: #17399

Revert D92897428, which accidentally erased recent changes to the ExecuTorch Vulkan backend when syncing open source and fbsource.

ghstack-source-id: 340490031
@exported-using-ghexport

Differential Revision: [D93012332](https://our.internmc.facebook.com/intern/diff/D93012332/)
1 parent ce92183 · commit 785adf0

File tree

5 files changed: +56 −0 lines

backends/vulkan/op_registry.py

Lines changed: 21 additions & 0 deletions
@@ -37,6 +37,8 @@ class OpFeatures:
         # bool indicating if the operator has a resize function, which allows it to
         # support models with dynamic shape
         "supports_resize",
+        # bool indicating if the operator supports tensors with more than 4 dimensions
+        "supports_highdim",
         # bool indicating if the operator handles its own prepacking. If this is True,
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
@@ -60,6 +62,7 @@ def __init__(
             Union[utils.TensorRepSet, List[utils.TensorRepSet]]
         ] = None,
         supports_resize: bool = False,
+        supports_highdim: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
         pick_io_storage_fn: Optional[Callable] = None,
@@ -85,6 +88,7 @@ def __init__(
         self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0])

         self.supports_resize = supports_resize
+        self.supports_highdim = supports_highdim
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -630,6 +636,7 @@ def register_reduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -651,6 +658,7 @@ def register_argreduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -811,6 +819,7 @@ def register_apply_rotary_emb():
         inputs_storage=utils.CONTIGUOUS_ANY,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -834,6 +843,7 @@ def register_permute_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -848,6 +858,7 @@ def register_view_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -857,6 +868,7 @@ def register_to_dim_order_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -871,6 +883,7 @@ def register_squeeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -885,6 +898,7 @@ def register_unsqueeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -899,6 +913,7 @@ def register_clone():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -908,6 +923,7 @@ def register_clone_dim_order():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -922,6 +938,7 @@ def register_gather():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -936,6 +953,7 @@ def register_expand_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=False,
+        supports_highdim=True,
     )


@@ -964,6 +982,7 @@ def register_select_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -978,6 +997,7 @@ def register_slice_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -992,6 +1012,7 @@ def register_split_with_sizes_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


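Taken together, the op_registry.py changes introduce one opt-in flag. Below is a minimal sketch, assuming a simplified stand-in for OpFeatures, of what the flag means for an individual registration; only the keyword-argument names come from the diff above, the stand-in class itself is illustrative.

from dataclasses import dataclass


# Simplified stand-in for OpFeatures; the real class in backends/vulkan/op_registry.py
# tracks storage layouts, dtypes, and more. Only the flags relevant here are shown.
@dataclass
class OpFeaturesSketch:
    supports_resize: bool = False
    # Added in this commit: the op can handle tensors with more than 4 dimensions.
    supports_highdim: bool = False
    supports_prepacking: bool = False


# Mirrors the register_binaryop_cpp_ops() change: binary ops now opt in to
# high-dimensional tensors in addition to dynamic shapes.
binary_op = OpFeaturesSketch(supports_resize=True, supports_highdim=True)

# An op that does not opt in keeps the default (False), so the partitioner
# will leave it on CPU whenever a rank > 4 tensor is involved (see below).
default_op = OpFeaturesSketch(supports_resize=True)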
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 0 deletions
@@ -266,6 +266,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
             self.log_skip(node, "op args not supported")
             return False

+        if not features.supports_highdim and utils.op_contains_high_dim_tensor(node):
+            self.log_skip(node, "op does not support high dim tensors")
+            return False
+
         if self.require_dynamic_shapes and not features.supports_resize:
             self.log_skip(node, "no dynamic shape support")
             return False
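For illustration, a small sketch of the decision the added lines implement, with the torch.fx plumbing replaced by plain values; only the flag names, the rank > 4 rule, and the skip messages come from the diff, everything else is a stand-in.

def node_is_supported_sketch(
    supports_highdim: bool,
    tensor_ranks: list,  # ranks of the tensors produced and consumed by the node
    require_dynamic_shapes: bool = False,
    supports_resize: bool = True,
) -> bool:
    # Stand-in for utils.op_contains_high_dim_tensor: any tensor with rank > 4 counts.
    contains_high_dim = any(rank > 4 for rank in tensor_ranks)
    if not supports_highdim and contains_high_dim:
        return False  # logged as "op does not support high dim tensors"
    if require_dynamic_shapes and not supports_resize:
        return False  # logged as "no dynamic shape support"
    return True


# A 5-D op is partitioned once supports_highdim=True; without the flag it is skipped.
assert node_is_supported_sketch(supports_highdim=True, tensor_ranks=[5, 5])
assert not node_is_supported_sketch(supports_highdim=False, tensor_ranks=[5, 5])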

backends/vulkan/runtime/api/containers/StagingBuffer.cpp

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@ void StagingBuffer::cast_half_to_float_and_copy_from(
   for (size_t i = 0; i < numel; ++i) {
     dst[i] = half_to_float(src[i]);
   }
+  vmaFlushAllocation(
+      vulkan_buffer_.vma_allocator(),
+      vulkan_buffer_.allocation(),
+      0u,
+      VK_WHOLE_SIZE);
 }

 void StagingBuffer::cast_float_to_half_and_copy_to(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 10 additions & 0 deletions
@@ -88,6 +88,11 @@ class StagingBuffer final {
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);
     }
+    vmaFlushAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
   }

   void cast_half_to_float_and_copy_from(
@@ -109,6 +114,11 @@ class StagingBuffer final {
   template <typename SRC_T, typename DST_T>
   void cast_and_copy_to(DST_T* dst, const size_t numel) {
     VK_CHECK_COND(numel <= this->numel());
+    vmaInvalidateAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
     const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);

backends/vulkan/utils.py

Lines changed: 16 additions & 0 deletions
@@ -468,6 +468,22 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
     return False


+def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the operator used to compute the given node contains a tensor
+    with more than 4 dimensions
+    """
+    if is_tensor_node(node) and tensor_node_is_high_dim(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_high_dim(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):

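The new helper builds on two existing utilities, is_tensor_node and tensor_node_is_high_dim, which are not shown in this diff. A hedged sketch of the rank test the latter presumably performs, assuming the usual torch.fx convention that node.meta["val"] holds a fake tensor; the real implementation may differ.

import torch
import torch.fx


def tensor_node_is_high_dim_sketch(node: torch.fx.Node) -> bool:
    # Assumption: the exported graph stores a FakeTensor (or a list of them for
    # multi-output ops) under node.meta["val"]; "high dim" means rank > 4.
    val = node.meta.get("val", None)
    if isinstance(val, torch.Tensor):
        return val.dim() > 4
    if isinstance(val, (list, tuple)):
        return any(isinstance(t, torch.Tensor) and t.dim() > 4 for t in val)
    return False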
0 commit comments