
Commit 0578684

Author: ssjia (committed)
Update on "[ET-VK][ez] Make q8ta_conv2d use 4C1W layout"
This changes the q8ta_conv2d and q8ta_conv2d_dw operators' input layout from PackedInt8_4W4C to PackedInt8_4C1W in the op registry. The 4C1W layout aligns with the natural output format of channel-packed convolutions, avoiding unnecessary layout conversions between consecutive conv layers. Also adds explicit `outputs_storage` declarations (PACKED_INT8_CHANNELS_PACKED_BUFFER) to both the PW and general q8ta_conv2d op registrations, ensuring the layout propagation pass can correctly determine output layouts.

Differential Revision: [D93000165](https://our.internmc.facebook.com/intern/diff/D93000165/)

[ghstack-poisoned]
2 parents f21bfb9 + b46c7be commit 0578684
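
The sketch below is a self-contained toy model (plain Python, not ExecuTorch code) of the reasoning in the commit message: channel-packed convolutions naturally produce the 4C1W layout, so registering 4C1W as the operators' input layout removes the repack step between consecutive quantized conv layers. Only the layout names come from the message above; the function and graph model are purely illustrative. The explicit `outputs_storage` declarations mentioned above serve the same goal from the producer side, letting the layout propagation pass see each conv's output layout instead of inferring it.

```python
# Toy model of layout transitions in a chain of q8ta convolutions.
# Layout names come from the commit message; everything else is illustrative.
CONV_OUTPUT_LAYOUT = "PackedInt8_4C1W"  # natural output of a channel-packed conv


def count_repacks(conv_input_layout: str, num_chained_convs: int) -> int:
    """Count layout-conversion nodes needed on conv-to-conv edges."""
    edges = num_chained_convs - 1
    return edges if conv_input_layout != CONV_OUTPUT_LAYOUT else 0


# Old registration (inputs expected as PackedInt8_4W4C): every edge repacks.
assert count_repacks("PackedInt8_4W4C", num_chained_convs=3) == 2
# New registration (inputs expected as PackedInt8_4C1W): the chain is conversion-free.
assert count_repacks("PackedInt8_4C1W", num_chained_convs=3) == 0
```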

File tree: 5 files changed (+54, -0 lines)

backends/vulkan/op_registry.py

Lines changed: 19 additions & 0 deletions

@@ -37,6 +37,8 @@ class OpFeatures:
         # bool indicating if the operator has a resize function, which allows it to
         # support models with dynamic shape
         "supports_resize",
+        # bool indicating if the operator supports tensors with more than 4 dimensions
+        "supports_highdim",
         # bool indicating if the operator handles its own prepacking. If this is True,
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
@@ -60,6 +62,7 @@ def __init__(
             Union[utils.TensorRepSet, List[utils.TensorRepSet]]
         ] = None,
         supports_resize: bool = False,
+        supports_highdim: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
         pick_io_storage_fn: Optional[Callable] = None,
@@ -85,6 +88,7 @@ def __init__(
             self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0])

         self.supports_resize = supports_resize
+        self.supports_highdim = supports_highdim
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -635,6 +641,7 @@ def register_reduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -656,6 +663,7 @@ def register_argreduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -851,6 +859,7 @@ def register_apply_rotary_emb():
         inputs_storage=utils.CONTIGUOUS_ANY,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -874,6 +883,7 @@ def register_permute_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -888,6 +898,7 @@ def register_view_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -897,6 +908,7 @@ def register_to_dim_order_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -911,6 +923,7 @@ def register_squeeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -925,6 +938,7 @@ def register_unsqueeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -939,6 +953,7 @@ def register_clone():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -978,6 +993,7 @@ def register_expand_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=False,
+        supports_highdim=True,
     )


@@ -1006,6 +1022,7 @@ def register_select_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -1020,6 +1037,7 @@ def register_slice_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -1034,6 +1052,7 @@ def register_split_with_sizes_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 0 deletions

@@ -266,6 +266,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
             self.log_skip(node, "op args not supported")
             return False

+        if not features.supports_highdim and utils.op_contains_high_dim_tensor(node):
+            self.log_skip(node, "op does not support high dim tensors")
+            return False
+
         if self.require_dynamic_shapes and not features.supports_resize:
             self.log_skip(node, "no dynamic shape support")
             return False

backends/vulkan/runtime/api/containers/StagingBuffer.cpp

Lines changed: 5 additions & 0 deletions

@@ -159,6 +159,11 @@ void StagingBuffer::cast_half_to_float_and_copy_from(
   for (size_t i = 0; i < numel; ++i) {
     dst[i] = half_to_float(src[i]);
   }
+  vmaFlushAllocation(
+      vulkan_buffer_.vma_allocator(),
+      vulkan_buffer_.allocation(),
+      0u,
+      VK_WHOLE_SIZE);
 }

 void StagingBuffer::cast_float_to_half_and_copy_to(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 10 additions & 0 deletions

@@ -88,6 +88,11 @@ class StagingBuffer final {
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);
     }
+    vmaFlushAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
   }

   void cast_half_to_float_and_copy_from(
@@ -109,6 +114,11 @@
   template <typename SRC_T, typename DST_T>
   void cast_and_copy_to(DST_T* dst, const size_t numel) {
     VK_CHECK_COND(numel <= this->numel());
+    vmaInvalidateAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
     const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);

backends/vulkan/utils.py

Lines changed: 16 additions & 0 deletions

@@ -468,6 +468,22 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
     return False


+def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the operator used to compute the given node contains a tensor
+    with more than 4 dimensions
+    """
+    if is_tensor_node(node) and tensor_node_is_high_dim(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_high_dim(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):
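
For reference, here is a minimal, self-contained sketch (plain Python, not the real torch.fx-based helpers) of how the pieces above interact: the check added in vulkan_partitioner.py skips an op whenever any tensor it touches has more than 4 dimensions and the op's OpFeatures entry did not set supports_highdim. The shapes and function names below are illustrative.

```python
from typing import Sequence


def tensor_is_high_dim(shape: Sequence[int]) -> bool:
    # Mirrors the docstring above: "high dim" means more than 4 dimensions.
    return len(shape) > 4


def partitioner_skips(supports_highdim: bool, tensor_shapes: Sequence[Sequence[int]]) -> bool:
    # Corresponds to: not features.supports_highdim and utils.op_contains_high_dim_tensor(node)
    return (not supports_highdim) and any(tensor_is_high_dim(s) for s in tensor_shapes)


# A 5-D activation (e.g. N, C, D, H, W) keeps the op off the Vulkan delegate unless
# its registration opted in with supports_highdim=True.
assert partitioner_skips(False, [(1, 8, 4, 16, 16)]) is True
assert partitioner_skips(True, [(1, 8, 4, 16, 16)]) is False
```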
