
Commit 7758fb1

Author: ssjia (committed)
Update base for Update on "[ET-VK] Layout-flexible impl of quantized binary"
This refactors the quantized binary add operator to support all PackedInt8 memory layouts (4W, 4C, 4W4C, 4H4W, 4C1W) instead of being hardcoded to 4W4C. The shader is rewritten to use the block indexing framework (BlockConfig, block_int8x4_load/store) and BufferMetadata for layout-agnostic tensor access, replacing the previous linear dispatch that assumed 4W4C ordering.

Key changes:
- Renames shader from binary_q8ta_q8ta_q8to to q8ta_binary, and op from add_q8ta_q8ta_q8to to q8ta_add
- Shader now uses contiguous_block_idx_to_tensor4d_idx_with_block_config for dispatch and generated load/store functions for layout-flexible int8x4 access
- C++ dispatch uses pick_linear_global_wg_with_block_config and passes BufferMetadata UBOs for output and both inputs, plus hashed_layout specialization constants
- Moves the test operator into a separate TestQ8taBinary.cpp file that parameterizes on GPUMemoryLayout, testing all 5 layouts
- Updates op_registry to accept PACKED_INT8_BUFFER (all layouts) instead of just PACKED_INT8_4W4C_BUFFER

This diff was authored with Claude.

Differential Revision: [D93000170](https://our.internmc.facebook.com/intern/diff/D93000170/)

[ghstack-poisoned]
1 parent f21bfb9 commit 7758fb1
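
Note: to illustrate the op_registry change called out in the last bullet, here is a hedged sketch. The storage-set names PACKED_INT8_BUFFER and PACKED_INT8_4W4C_BUFFER come from the commit message above; the surrounding q8ta_add registration code (and whether these constants live in utils) is assumed, not shown on this page.

    # Hypothetical registration fragment for the quantized binary add op.
    OpFeatures(
        # Before this stack: only the 4W4C packed-int8 buffer layout was accepted.
        #     inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
        # After: any PackedInt8 buffer layout (4W, 4C, 4W4C, 4H4W, 4C1W).
        inputs_storage=utils.PACKED_INT8_BUFFER,
    )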

File tree
5 files changed: +54 -0 lines changed


backends/vulkan/op_registry.py

Lines changed: 19 additions & 0 deletions
@@ -37,6 +37,8 @@ class OpFeatures:
         # bool indicating if the operator has a resize function, which allows it to
         # support models with dynamic shape
         "supports_resize",
+        # bool indicating if the operator supports tensors with more than 4 dimensions
+        "supports_highdim",
         # bool indicating if the operator handles its own prepacking. If this is True,
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
@@ -60,6 +62,7 @@ def __init__(
             Union[utils.TensorRepSet, List[utils.TensorRepSet]]
         ] = None,
         supports_resize: bool = False,
+        supports_highdim: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
         pick_io_storage_fn: Optional[Callable] = None,
@@ -85,6 +88,7 @@ def __init__(
             self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0])

         self.supports_resize = supports_resize
+        self.supports_highdim = supports_highdim
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
@@ -239,6 +243,7 @@ def register_binaryop_cpp_ops():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -253,6 +258,7 @@ def register_pow_tensor_scalar():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -635,6 +641,7 @@ def register_reduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -656,6 +663,7 @@ def register_argreduce_cpp_ops():
         inputs_storage=utils.ANY_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
         are_node_inputs_supported_fn=is_reduce_node_supported,
         pick_io_storage_fn=pick_storage_for_reduce,
     )
@@ -851,6 +859,7 @@ def register_apply_rotary_emb():
         inputs_storage=utils.CONTIGUOUS_ANY,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -874,6 +883,7 @@ def register_permute_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -888,6 +898,7 @@ def register_view_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -897,6 +908,7 @@ def register_to_dim_order_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -911,6 +923,7 @@ def register_squeeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -925,6 +938,7 @@ def register_unsqueeze_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -939,6 +953,7 @@ def register_clone():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -978,6 +993,7 @@ def register_expand_copy():
         inputs_storage=utils.ANY_BUFFER,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=False,
+        supports_highdim=True,
     )


@@ -1006,6 +1022,7 @@ def register_select_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -1020,6 +1037,7 @@ def register_slice_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )


@@ -1034,6 +1052,7 @@ def register_split_with_sizes_copy():
         inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_INT_BOOL_T,
         supports_resize=True,
+        supports_highdim=True,
     )

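For orientation, a minimal sketch of the registration pattern these hunks extend, assuming the register_* helpers above ultimately build an OpFeatures with the keyword arguments shown (the decorator plumbing is not part of this diff, and the op name below is hypothetical):

    def register_my_copy_like_op():
        # Hypothetical op following the pattern of the register_*_copy hunks above.
        return OpFeatures(
            inputs_storage=utils.ANY_STORAGE,
            inputs_dtypes=utils.FP_INT_BOOL_T,
            supports_resize=True,
            # New flag added in this diff: the op can handle tensors with more
            # than 4 dimensions, so the partitioner will not skip high-dim nodes.
            supports_highdim=True,
        )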

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 0 deletions
@@ -266,6 +266,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "op args not supported")
             return False

+        if not features.supports_highdim and utils.op_contains_high_dim_tensor(node):
+            self.log_skip(node, "op does not support high dim tensors")
+            return False
+
         if self.require_dynamic_shapes and not features.supports_resize:
             self.log_skip(node, "no dynamic shape support")
             return False
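As a concrete, hypothetical illustration of the new gate: a node that touches a tensor with more than 4 dimensions stays in the Vulkan partition only if its op was registered with supports_highdim=True (for example permute_copy in the op_registry hunks above); otherwise it is skipped with "op does not support high dim tensors". A minimal sketch of such a graph:

    import torch

    class HighDimPermute(torch.nn.Module):
        def forward(self, x):
            # x has 5 dims, so op_contains_high_dim_tensor() is True for this node;
            # it is only partitioned because permute_copy opts in via
            # supports_highdim=True in the op_registry change above.
            return x.permute(0, 1, 2, 4, 3)

    example_inputs = (torch.randn(2, 3, 4, 5, 6),)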

backends/vulkan/runtime/api/containers/StagingBuffer.cpp

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@ void StagingBuffer::cast_half_to_float_and_copy_from(
   for (size_t i = 0; i < numel; ++i) {
     dst[i] = half_to_float(src[i]);
   }
+  vmaFlushAllocation(
+      vulkan_buffer_.vma_allocator(),
+      vulkan_buffer_.allocation(),
+      0u,
+      VK_WHOLE_SIZE);
 }

 void StagingBuffer::cast_float_to_half_and_copy_to(

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 10 additions & 0 deletions
@@ -88,6 +88,11 @@ class StagingBuffer final {
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);
     }
+    vmaFlushAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
   }

   void cast_half_to_float_and_copy_from(
@@ -109,6 +114,11 @@ class StagingBuffer final {
   template <typename SRC_T, typename DST_T>
   void cast_and_copy_to(DST_T* dst, const size_t numel) {
     VK_CHECK_COND(numel <= this->numel());
+    vmaInvalidateAllocation(
+        vulkan_buffer_.vma_allocator(),
+        vulkan_buffer_.allocation(),
+        0u,
+        VK_WHOLE_SIZE);
     const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
     for (size_t i = 0; i < numel; ++i) {
       dst[i] = static_cast<DST_T>(src[i]);

backends/vulkan/utils.py

Lines changed: 16 additions & 0 deletions
@@ -468,6 +468,22 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
     return False


+def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the operator used to compute the given node contains a tensor
+    with more than 4 dimensions
+    """
+    if is_tensor_node(node) and tensor_node_is_high_dim(node):
+        return True
+
+    for arg_node in node.args:
+        # pyre-ignore[6]
+        if is_tensor_node(arg_node) and tensor_node_is_high_dim(arg_node):
+            return True
+
+    return False
+
+
 def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
     primary_arg_idx: Optional[int] = None
     for i, arg_node in enumerate(node.args):
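The helper tensor_node_is_high_dim referenced above is not added by this diff. A plausible sketch of what it checks, assuming the FX node carries FakeTensor metadata in node.meta["val"] and "high dim" means more than 4 dimensions, as the new OpFeatures comment states:

    import torch

    def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
        # Hypothetical sketch, not the actual implementation: treat a node as
        # high-dim if its tensor metadata reports more than 4 dimensions.
        val = node.meta.get("val", None)
        if isinstance(val, torch.Tensor):
            return val.dim() > 4
        return False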
