Skip to content

Commit 5e545ea

Browse files
ssjia authored and SS-JIA committed
[ET-VK][qconv] Add flexible layout impl for im2col
Pull Request resolved: #17249

This implements an im2col-based approach for quantized conv2d, which transforms convolution into matrix multiplication. The im2col transformation extracts sliding windows from the input tensor and reshapes them into a 2D matrix, enabling reuse of the optimized pointwise convolution shader for the compute-intensive portion.

Two im2col shaders are added:
- `q8ta_im2col.glsl`: Generic shader with layout-agnostic input access via BufferMetadata and specialization constants
- `q8ta_im2col_4w4c.glsl`: Optimized shader for 4W4C input layout that exploits the alignment between consecutive width positions and packed channel values

The im2col output is always stored in 4W4C layout to match the expected input format of the pointwise convolution shader. The operator is registered as `etvk.q8ta_conv2d_im2col.default` and currently supports non-grouped convolutions where the number of input channels is a multiple of 4.

Authored with assistance from Claude.

ghstack-source-id: 338638552
@exported-using-ghexport

Differential Revision: [D92407723](https://our.internmc.facebook.com/intern/diff/D92407723/)
1 parent 4bd58bb commit 5e545ea

File tree

8 files changed

+631
-0
lines changed

8 files changed

+631
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Declarations for the generic q8ta_im2col compute shader: language version,
// tile macros, storage buffers, uniform blocks, specialization constants and
// the push-constant block used by the functions further down.
#version 450 core
10+
11+
// ${...} placeholders are presumably substituted when the shader template is
// expanded at build time -- confirm against the shader codegen.
#define PRECISION ${PRECISION}
12+
13+
// NOTE(review): not referenced by the code visible in this file; presumably
// consumed by one of the included headers to select the output storage type.
#define PACKED_INT8_OUTPUT_BUFFER
14+
15+
// Tile macros. They are not referenced directly by the code visible in this
// file; main() fills a single tile entry (data[0][0]), which is consistent
// with the *_M4 / *_N4 values of 1. Presumably read by the included tile
// headers -- confirm.
#define TILE_M4 1
16+
#define TILE_N4 1
17+
#define TILE_K4 1
18+
19+
#define TILE_M 4
20+
#define TILE_N 4
21+
#define TILE_K 4
22+
23+
// Default every buffer block in this shader to the std430 memory layout.
layout(std430) buffer;
24+
25+
#include "indexing.glslh"
26+
#include "conv2d_common.glslh"
27+
28+
// Output buffer (write-only), declared with is_scalar_array=False.
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)}
29+
// Input buffer (read-only), declared with is_scalar_array=True.
// load_input_element() reads it one int at a time and extracts one 8-bit
// value from each int.
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)}
30+
31+
// Metadata for im2col output and input tensors (layout-agnostic)
32+
${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")}
33+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
34+
// Convolution geometry (kernel size, stride, padding, dilation, channels per
// group, K per group) read by get_input_tidx(), load_im2col_block() and main().
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
35+
36+
// NOTE(review): apply_bias is not referenced by the code visible in this
// file. If specialization-constant ids follow declaration order, removing it
// would shift the ids of the layout constants below -- confirm before touching.
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
37+
38+
// Layout specialization constants
// inp_layout selects the addressing branch in load_input_element().
// im2col_outp_layout is not referenced by the code visible in this file.
39+
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
40+
${layout_declare_spec_const(C, "int", "im2col_outp_layout", "CONTIG_LAYOUT_INT")}
41+
42+
// zp is the fill value: load_im2col_block() writes it for k positions at or
// beyond logical_K_per_group, and load_input_element() returns it for
// coordinates outside the input extents.
layout(push_constant) uniform restrict Block {
43+
int zp;
44+
};
45+
46+
// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
47+
48+
// Presumably provides Int8OutTile and store_packed_int8_output_tile(), which
// main() uses -- confirm.
#include "conv2d_int8_output_tile_store.glslh"
49+
50+
/*
 * Translate an im2col coordinate into the 4D index of the input element that
 * feeds it.
 *
 * im2col_w, im2col_h: column and row of the sliding window in output space.
 * k_in_group: flattened (kernel_y, kernel_x, channel) offset inside the
 *     group, with the channel varying fastest.
 * group_idx: convolution group the channel belongs to.
 *
 * The returned x / y are not clamped and may fall outside the input (the
 * padding region), so the caller has to bounds-check before reading. The w
 * component is always 0.
 */
TensorIndex4D get_input_tidx(
    const int im2col_w,
    const int im2col_h,
    const int k_in_group,
    const int group_idx) {
  const int channels_per_group = conv2d_params.in_channels_per_group;
  const int kernel_width = conv2d_params.kernel_size.x;

  // Peel the channel off the flattened index; what is left is the kernel tap.
  const int channel = k_in_group % channels_per_group;
  const int tap = k_in_group / channels_per_group;
  const ivec2 tap_xy = ivec2(tap % kernel_width, tap / kernel_width);

  // Top-left input position of the window, then the dilated tap offset.
  const ivec2 window_origin =
      ivec2(im2col_w, im2col_h) * conv2d_params.stride.xy -
      conv2d_params.padding.xy;
  const ivec2 input_xy = window_origin + tap_xy * conv2d_params.dilation.xy;

  TensorIndex4D tidx;
  tidx.data = ivec4(input_xy, group_idx * channels_per_group + channel, 0);
  return tidx;
}
73+
74+
/*
 * Fetch one 8-bit value from the packed input buffer.
 *
 * tidx: input coordinate (x = width, y = height, z = channel, w = 0 as
 *     produced by get_input_tidx).
 * input_zp: value returned when any component of tidx lies outside
 *     [0, inp.sizes[0]), i.e. for reads that land in the padding region.
 *
 * Returns the 8-bit value selected by (channel mod 4) from the int that
 * holds it.
 */
int load_input_element(const TensorIndex4D tidx, const int input_zp) {
  const ivec4 pos = tidx.data;
  const ivec4 extents = ivec4(inp.sizes[0]);

  const bool inside =
      all(greaterThanEqual(pos, ivec4(0))) && all(lessThan(pos, extents));
  if (!inside) {
    return input_zp;
  }

  int texel_idx;
  if (get_outer_packed_dim_block_size(inp_layout) != 1) {
    // Width is packed in blocks of 4 as well (4W4C): address the 4x4 block
    // first, then select the int for this width position inside the block.
    const int row_stride = int(inp.strides[0][1]);
    const int w_block_stride = int(inp.strides[0][0]);
    const int block_idx =
        pos.y * row_stride + div_4(pos.x) * w_block_stride + div_4(pos.z);
    texel_idx = block_idx * 4 + mod_4(pos.x);
  } else {
    // Outer packed block size of 1: defer to the generic index helper.
    texel_idx = tensor4d_idx_to_texel_idx(inp, tidx, inp_layout);
  }

  // Each int holds four 8-bit channel values; pick the one for this channel.
  return extract_8bit_from_packed_int_le(
      t_packed_int8_input[texel_idx], mod_4(pos.z));
}
103+
104+
/*
 * Gather one 4x4 im2col block: 4 consecutive output width positions by
 * 4 consecutive k positions.
 *
 * im2col_w_start: first of the 4 output width positions.
 * im2col_h: output row.
 * k_in_group_start: first of the 4 k positions inside the group.
 * group_idx: convolution group.
 *
 * Returns an ivec4 whose component r packs the 4 values for width position
 * im2col_w_start + r. Positions at or beyond logical_K_per_group, and reads
 * that fall outside the input, yield zp.
 */
ivec4 load_im2col_block(
    const int im2col_w_start,
    const int im2col_h,
    const int k_in_group_start,
    const int group_idx) {
  ivec4 packed_rows;

  for (int w_off = 0; w_off < 4; ++w_off) {
    // Start from the padding value and overwrite only the valid k positions.
    ivec4 k_values = ivec4(zp);

    for (int k_off = 0; k_off < 4; ++k_off) {
      const int k = k_in_group_start + k_off;
      if (k < conv2d_params.logical_K_per_group) {
        const TensorIndex4D src_idx =
            get_input_tidx(im2col_w_start + w_off, im2col_h, k, group_idx);
        k_values[k_off] = load_input_element(src_idx, zp);
      }
    }

    packed_rows[w_off] = pack_into_int32(k_values);
  }

  return packed_rows;
}
133+
134+
/*
 * Entry point: each invocation produces one 4x4 block of the im2col output.
 * The linear invocation id is mapped to a block index; invocations whose
 * block lies outside the output extents exit without writing.
 */
void main() {
  Conv2dBlockExtents im2col_block_extents =
      make_block_extents(ivec4(im2col_outp.sizes[0]));

  Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx(
      int(gl_GlobalInvocationID.x), im2col_block_extents);

  if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) {
    return;
  }

  // Block index -> element coordinates. Width and k advance in steps of 4.
  const int w_start = mul_4(im2col_block_idx.data.x);
  const int h = im2col_block_idx.data.y;
  const int k_start = mul_4(im2col_block_idx.data.z);

  // Split the global k position into a group and an offset within the group.
  const int k_per_group = conv2d_params.K_per_group;
  const int group_idx = k_start / k_per_group;
  const int k_in_group = k_start % k_per_group;

  // Gather the block through the layout-agnostic loader and store it.
  Int8OutTile tile;
  tile.data[0][0] = load_im2col_block(w_start, h, k_in_group, group_idx);

  store_packed_int8_output_tile(tile, im2col_block_idx, im2col_block_extents);
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Shader variant configuration for q8ta_im2col.glsl.
# Nesting restored: the parameter defaults and the variant list belong to the
# q8ta_im2col entry rather than being top-level keys.
q8ta_im2col:
  parameter_names_with_default_values:
    DTYPE: float
  shader_variants:
    - NAME: q8ta_im2col
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Declarations for the q8ta_im2col_4w4c compute shader: language version,
// storage buffers, uniform blocks, specialization constants and the
// push-constant block used by main().
#version 450 core
10+
11+
// ${...} placeholders are presumably substituted when the shader template is
// expanded at build time -- confirm against the shader codegen.
#define PRECISION ${PRECISION}
12+
13+
// NOTE(review): not referenced by the code visible in this file; presumably
// consumed by an included header -- confirm.
#define PACKED_INT8_OUTPUT_BUFFER
14+
15+
// Default every buffer block in this shader to the std430 memory layout.
layout(std430) buffer;
16+
17+
#include "indexing.glslh"
18+
#include "conv2d_common.glslh"
19+
20+
// Output buffer (write-only), declared with is_scalar_array=False; main()
// stores one ivec4 per invocation into it.
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)}
21+
// Input buffer (read-only), declared with is_scalar_array=True; main() reads
// one int per output width position from it.
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)}
22+
23+
// Metadata for im2col output and input tensors (layout-agnostic)
24+
${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")}
25+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
26+
// Convolution geometry (kernel size, stride, padding, dilation, channels per
// group, K per group) read by main().
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
27+
28+
// NOTE(review): apply_bias is not referenced by the code visible in this
// file. If specialization-constant ids follow declaration order, removing it
// would shift the ids of the layout constants below -- confirm before touching.
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
29+
30+
// Layout specialization constants
// NOTE(review): the output layout is declared BEFORE inp_layout here, whereas
// q8ta_im2col.glsl declares inp_layout first. If ids are assigned in
// declaration order, the host code must pass the two values in a different
// order for each shader -- confirm against the dispatch code.
// outp_layout is not referenced by the code visible in this file; inp_layout
// selects the addressing branch inside main().
31+
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
32+
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
33+
34+
// zp is the fill value: main() packs four copies of it into one int and
// writes that for positions that fall outside the input.
layout(push_constant) uniform restrict Block {
35+
int zp;
36+
};
37+
38+
// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
39+
40+
/*
 * Entry point: each invocation produces one 4x4 block of the im2col output
 * (4 consecutive output width positions by 4 packed channel values) and
 * stores it as a single ivec4.
 *
 * Because every input int already holds 4 packed channel values, one int is
 * read per output width position instead of assembling the block value by
 * value. This relies on the k offset being a multiple of 4 that maps onto a
 * 4-aligned channel -- presumably guaranteed by the operator requiring the
 * number of input channels to be a multiple of 4; confirm.
 *
 * Positions that fall outside the input (padding) are filled with zp
 * replicated into all four packed lanes.
 */
void main() {
  const int out_buf_idx = int(gl_GlobalInvocationID.x);

  // Extract sizes from BufferMetadata
  const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]);
  const ivec4 input_sizes = ivec4(inp.sizes[0]);

  // im2col extents, counted in blocks of 4 along width and k
  const int im2col_W4 = div_up_4(im2col_sizes.x);
  const int im2col_H = im2col_sizes.y;
  const int im2col_Z4 = div_up_4(im2col_sizes.z);

  // Decompose the linear index: k-block fastest, then width-block, then row
  const int c4_idx = out_buf_idx % im2col_Z4;
  const int row = out_buf_idx / im2col_Z4;
  const int w4_idx = row % im2col_W4;
  const int h_idx = row / im2col_W4;

  // c4_idx and w4_idx are remainders and therefore always in range; only the
  // row index can run past the output for surplus invocations.
  if (h_idx >= im2col_H) {
    return;
  }

  const int im2col_w = mul_4(w4_idx);
  const int im2col_h = h_idx;
  const int im2col_k = mul_4(c4_idx);

  const int group_idx = im2col_k / conv2d_params.K_per_group;
  const int k_in_group = im2col_k % conv2d_params.K_per_group;

  // k_in_group enumerates (kernel_y, kernel_x, channel), channel fastest
  const int c_in_group = k_in_group % conv2d_params.in_channels_per_group;
  const int krow = k_in_group / conv2d_params.in_channels_per_group;
  const int kernel_x = krow % conv2d_params.kernel_size.x;
  const int kernel_y = krow / conv2d_params.kernel_size.x;

  // Input position sampled by the first of the 4 output width positions
  const int input_x_base =
      (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x +
      (kernel_x * conv2d_params.dilation.x);
  const int input_y =
      (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y +
      (kernel_y * conv2d_params.dilation.y);
  const int input_z =
      group_idx * conv2d_params.in_channels_per_group + c_in_group;

  // Input tensor extents
  const int input_W = input_sizes.x;
  const int input_H = input_sizes.y;
  const int input_Z4 = div_up_4(input_sizes.z);

  const int zp_packed = pack_into_int32(ivec4(zp));
  const int z4 = div_4(input_z);

  // Check if y and z are in bounds (constant for all 4 width elements)
  const bool y_z_in_bounds =
      (input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4);

  // Load 4 elements from input, one for each output width position.
  // Each loaded int contains 4 packed int8 channel values.
  ivec4 im2col_block;
  for (int i = 0; i < 4; i++) {
    // Output column (im2col_w + i) samples input column
    // (im2col_w + i) * stride.x - padding.x + kernel_x * dilation.x, so
    // consecutive output columns are stride.x input columns apart. Stepping
    // by 1 is only correct for stride.x == 1.
    const int x = input_x_base + i * conv2d_params.stride.x;
    if (!y_z_in_bounds || x < 0 || x >= input_W) {
      im2col_block[i] = zp_packed;
    } else {
      const int x4 = div_4(x);
      const int x_mod = mod_4(x);
      int scalar_idx;
      if (get_outer_packed_dim_block_size(inp_layout) == 1) {
        // Width is not blocked: plain strided addressing
        scalar_idx = input_y * int(inp.strides[0][1])
            + x * int(inp.strides[0][0])
            + z4 * int(inp.strides[0][2]);
      } else {
        // Width is blocked by 4: address the 4x4 block, then the int for this
        // width position inside the block
        scalar_idx = mul_4(
            input_y * int(inp.strides[0][1])
            + x4 * int(inp.strides[0][0])
            + z4) + x_mod;
      }
      im2col_block[i] = t_packed_int8_input[scalar_idx];
    }
  }

  // Equivalent of store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1).
  // The indices were validated by the early return above.
  const int buffer_idx = h_idx * int(im2col_outp.strides[0][1])
      + w4_idx * int(im2col_outp.strides[0][0])
      + c4_idx;

  t_packed_int8_output[buffer_idx] = im2col_block;
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Shader variant configuration for q8ta_im2col_4w4c.glsl.
# Nesting restored: the parameter defaults and the variant list belong to the
# q8ta_im2col_4w4c entry rather than being top-level keys.
q8ta_im2col_4w4c:
  parameter_names_with_default_values:
    DTYPE: float
  shader_variants:
    - NAME: q8ta_im2col_4w4c

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,18 @@ void add_q8ta_conv2d_node(
9999
const ValueRef groups,
100100
const ValueRef packed_int8_output);
101101

102+
// Declaration of add_q8ta_conv2d_pw_node, exposed through this header so that
// other translation units can call it.
// NOTE(review): judging by the name and parameter list this adds a quantized
// pointwise conv2d node to `graph`, taking a packed int8 input, quantization
// scales / zero points for input and output, packed weights with their sums
// and scales, bias data, and a packed int8 output. The definition is not
// visible here -- confirm semantics against the implementation.
void add_q8ta_conv2d_pw_node(
103+
ComputeGraph& graph,
104+
const ValueRef packed_int8_input,
105+
const ValueRef input_scale,
106+
const ValueRef input_zp,
107+
const ValueRef packed_weight,
108+
const ValueRef packed_weight_sums,
109+
const ValueRef packed_weight_scales,
110+
const ValueRef output_scale,
111+
const ValueRef output_zp,
112+
const ValueRef bias_data,
113+
const ValueRef packed_bias,
114+
const ValueRef packed_int8_output);
115+
102116
} // namespace vkcompute

0 commit comments

Comments
 (0)