
Commit b2c9376

Authored and committed by ssjia
[ET-VK] Layout-flexible impl of quantized binary
This refactors the quantized binary add operator to support all PackedInt8 memory layouts (4W, 4C, 4W4C, 4H4W, 4C1W) instead of being hardcoded to 4W4C. The shader is rewritten to use the block indexing framework (BlockConfig, block_int8x4_load/store) and BufferMetadata for layout-agnostic tensor access, replacing the previous linear dispatch that assumed 4W4C ordering.

Key changes:

- Renames the shader from binary_q8ta_q8ta_q8to to q8ta_binary, and the op from add_q8ta_q8ta_q8to to q8ta_add
- The shader now uses contiguous_block_idx_to_tensor4d_idx_with_block_config for dispatch and generated load/store functions for layout-flexible int8x4 access
- The C++ dispatch uses pick_linear_global_wg_with_block_config and passes BufferMetadata UBOs for the output and both inputs, plus hashed_layout specialization constants
- Moves the test operator into a separate TestQ8taBinary.cpp file that parameterizes on GPUMemoryLayout, testing all 5 layouts
- Updates op_registry to accept PACKED_INT8_BUFFER (all layouts) instead of just PACKED_INT8_4W4C_BUFFER

This diff was authored with Claude.

Differential Revision: [D93000170](https://our.internmc.facebook.com/intern/diff/D93000170/)

[ghstack-poisoned]
1 parent f21bfb9 commit b2c9376
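For orientation, the numerics that q8ta_add is expected to compute can be illustrated with a short Python sketch. This is not code from this commit; it assumes conventional affine quantization (value ≈ (q − zero_point) × scale) and that alpha scales the second operand as in aten.add. The helper name is hypothetical.

# Hedged reference sketch of q8ta_add semantics (not from this commit).
import torch

def q8ta_add_reference(
    input_a: torch.Tensor,   # int8 quantized tensor
    input_b: torch.Tensor,   # int8 quantized tensor
    input_a_scale: float,
    input_a_zero_point: int,
    input_b_scale: float,
    input_b_zero_point: int,
    output_scale: float,
    output_zero_point: int,
    alpha: float = 1.0,
) -> torch.Tensor:
    # Dequantize both inputs to float
    a_fp = (input_a.to(torch.float32) - input_a_zero_point) * input_a_scale
    b_fp = (input_b.to(torch.float32) - input_b_zero_point) * input_b_scale
    # Elementwise add; alpha scales the second operand (aten.add convention)
    out_fp = a_fp + alpha * b_fp
    # Requantize to int8 with the output scale / zero point
    q = torch.round(out_fp / output_scale) + output_zero_point
    return q.clamp(-128, 127).to(torch.int8)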

13 files changed: +623 -455 lines changed


backends/vulkan/custom_ops_lib.py

Lines changed: 5 additions & 5 deletions
@@ -564,11 +564,11 @@ def apply_rotary_emb_impl(
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)

 ########################
-## add_q8ta_q8ta_q8to ##
+## q8ta_add ##
 ########################


-def add_q8ta_q8ta_q8to_impl(
+def q8ta_add_impl(
     input_a: torch.Tensor,
     input_b: torch.Tensor,
     input_a_scale: float,
@@ -598,12 +598,12 @@ def add_q8ta_q8ta_q8to_impl(
     return quantized_result


-name = "add_q8ta_q8ta_q8to"
+name = "q8ta_add"
 lib.define(
     f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor"
 )
-lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
-add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
+lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd")
+q8ta_add_op = getattr(getattr(torch.ops, namespace), name)

 #############################
 ## select_as_symint ##
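The hunk above registers the renamed op under its library schema. A minimal usage sketch follows; the et_vk namespace is inferred from the exir_ops.edge.et_vk references in this commit, and the module path, int8 input dtype, and scale/zero-point values are assumptions for illustration only.

# Hedged usage sketch: namespace, module path, and quantization parameters are assumed.
import torch
import executorch.backends.vulkan.custom_ops_lib  # registers the custom op (path assumed)

a = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)
b = torch.randint(-128, 128, (1, 8, 4, 4), dtype=torch.int8)

out = torch.ops.et_vk.q8ta_add(
    a, b,
    0.1, 0,    # input_a_scale, input_a_zero_point
    0.2, 0,    # input_b_scale, input_b_zero_point
    0.15, 0,   # output_scale, output_zero_point
    1.0,       # alpha
)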

backends/vulkan/op_registry.py

Lines changed: 4 additions & 4 deletions
@@ -495,14 +495,14 @@ def register_torchao_choose_qparams_affine():


 # =============================================================================
-# QuantizedBinary.cpp
+# Q8taBinary.cpp
 # =============================================================================


-@update_features(exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default)
-def register_add_q8ta_q8ta_q8to():
+@update_features(exir_ops.edge.et_vk.q8ta_add.default)
+def register_q8ta_add():
     return OpFeatures(
-        inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
+        inputs_storage=utils.PACKED_INT8_BUFFER,
         supports_resize=False,
         supports_prepacking=True,
     )

backends/vulkan/patterns/quantized_binary.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def make_add_q8ta_q8ta_q8to_custom_op(
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.add_.Tensor,
     }:
-        op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default
+        op_target = exir_ops.edge.et_vk.q8ta_add.default
     else:
         # For future binary operations, add more mappings here
         raise NotImplementedError(

backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl

Lines changed: 0 additions & 76 deletions
This file was deleted.
backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+${define_active_storage_type("buffer")}
+
+#define op(X, Y) ${OPERATOR}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+#include "common.glslh"
+#include "block_indexing.glslh"
+#include "block_int8x4_load.glslh"
+#include "block_int8x4_store.glslh"
+
+// Output buffer: packed int8x4 values
+${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
+// Input buffers: packed int8x4 values
+${layout_declare_tensor(B, "r", "t_in_a", "int", "buffer")}
+${layout_declare_tensor(B, "r", "t_in_b", "int", "buffer")}
+
+// Metadata for output and input tensors
+${layout_declare_ubo(B, "BufferMetadata", "out_meta")}
+${layout_declare_ubo(B, "BufferMetadata", "in_a_meta")}
+${layout_declare_ubo(B, "BufferMetadata", "in_b_meta")}
+
+layout(push_constant) uniform restrict Block {
+  float input_a_scale;
+  int input_a_zp;
+  float input_b_scale;
+  int input_b_zp;
+  float output_inv_scale;
+  int output_zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
+${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
+${layout_declare_spec_const(C, "int", "block_config", "0")}
+
+// Generate loading functions for input buffers
+define_load_int8x4_buffer_fns(t_in_a)
+define_load_int8x4_buffer_fns(t_in_b)
+
+// Generate storing functions for output buffer
+define_store_int8x4_buffer_fns(t_out)
+
+void main() {
+  // Buffer storage: use linear dispatch
+  const uint contig_block_idx = gl_GlobalInvocationID.x;
+  TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config(
+      out_meta, contig_block_idx, block_config);
+
+  if (out_of_bounds(tidx, out_meta)) {
+    return;
+  }
+
+  const int block_outer_dim = get_block_outer_dim(block_config);
+
+  // Load int8x4 blocks from both inputs
+  ivec4 in_block_a = load_int8x4_block_from_t_in_a(
+      in_a_meta, tidx, in_layout, block_outer_dim);
+  ivec4 in_block_b = load_int8x4_block_from_t_in_b(
+      in_b_meta, tidx, in_layout, block_outer_dim);
+
+  ivec4 out_block;
+
+  for (int row = 0; row < 4; row++) {
+    vec4 in_texel_a = unpack_and_dequantize(
+        in_block_a[row], input_a_scale, input_a_zp);
+    vec4 in_texel_b = unpack_and_dequantize(
+        in_block_b[row], input_b_scale, input_b_zp);
+
+    vec4 out_texel = op(in_texel_a, in_texel_b);
+    out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
+  }
+
+  // Store to output buffer
+  store_int8x4_block_to_t_out(
+      out_meta, tidx, out_layout, block_outer_dim, out_block);
+}
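The shader leans on unpack_and_dequantize and quantize_and_pack from the included headers, whose definitions are not part of this diff. As a rough guide only, the following Python sketch shows the int8x4 packing and affine (de)quantization math such helpers presumably implement, assuming output_inv_scale is 1 / output_scale; the actual GLSL may differ.

# Hedged sketch of int8x4 pack/unpack plus affine (de)quantization;
# the real helpers live in common.glslh and are not shown in this diff.
import struct

def unpack_and_dequantize(packed: int, scale: float, zp: int) -> list[float]:
    # Split one 32-bit word into four signed int8 lanes, then dequantize each.
    lanes = struct.unpack("<4b", struct.pack("<i", packed))
    return [(q - zp) * scale for q in lanes]

def quantize_and_pack(vals: list[float], inv_scale: float, zp: int) -> int:
    # Quantize four floats to int8, clamp, and pack into one 32-bit word.
    lanes = [max(-128, min(127, round(v * inv_scale) + zp)) for v in vals]
    (packed,) = struct.unpack("<i", struct.pack("<4b", *lanes))
    return packed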

backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.yaml

Lines changed: 2 additions & 9 deletions
@@ -4,16 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-binary_q8ta_q8ta_q8to:
+q8ta_binary:
   parameter_names_with_default_values:
     OPERATOR: X + Y
-    NDIM: 3
-    DTYPE: float
-    PACKING: C_packed
-    IO_STORAGE: buffer
-  generate_variant_forall:
-    IO_STORAGE:
-      - VALUE: buffer
   shader_variants:
-    - NAME: add_q8ta_q8ta_q8to
+    - NAME: q8ta_add_buffer
      OPERATOR: X + Y
