From 34eb0c72e11f18ad3c71ffe86b7da1963f2f618b Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 3 Feb 2026 13:27:25 -0800 Subject: [PATCH 1/2] [ET-VK][quantization] Add layout-flexible clone for int8x4 tensors Implements q8ta_clone, a block-based shader for copying data between int8x4 (quantized) tensors with potentially different memory layouts. This is needed when quantized activations need to be copied between tensors with different packed int8 layouts (e.g., from kPackedInt8_4W4C to kPackedInt8_4C) without going through dequantize/requantize. Implementation details: 1. GLSL Shader (q8ta_clone.glsl): - Uses block-based dispatch pattern matching q8ta_quantize/dequantize - Each thread processes a 4x4 block of int8 values (16 elements) - Uses linear dispatch via gl_GlobalInvocationID.x for buffer tensors - Loads int8x4 blocks using load_int8x4_block_from_t_inp() - Transposes block via transpose_int8x4_block() when input/output have different packed dimensions - Stores using store_int8x4_block_to_t_outp() 2. C++ Dispatch (Q8taClone.cpp): - Creates BlockConfig for both input and output tensors using create_block_config_from_io_packed_dims() and create_block_config_from_other() - Uses pick_linear_global_wg_with_block_config for workgroup sizing - Passes hashed layouts and packed block configs as specialization constants 3. Clone.cpp Integration: - Added check for kInt8x4 dtype on both input and output tensors - Routes to add_q8ta_clone_node() for int8x4 tensor cloning - Preserves existing behavior for all other tensor types 4. Test Infrastructure: - TestQ8taClone.cpp: Custom op that chains quantize -> clone -> dequantize - test_q8ta_clone.cpp: Test driver with 800 test cases - Tests all 25 combinations of input/output quantized layouts - Tests multiple tensor shapes from 1x3x16x16 to 1x128x56x56 Differential Revision: [D92196648](https://our.internmc.facebook.com/intern/diff/D92196648/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/q8ta_clone.glsl | 73 ++++ .../runtime/graph/ops/glsl/q8ta_clone.yaml | 11 + .../vulkan/runtime/graph/ops/impl/Clone.cpp | 8 + .../runtime/graph/ops/impl/Q8taClone.cpp | 65 ++++ .../vulkan/runtime/graph/ops/impl/Q8taClone.h | 24 ++ .../test/custom_ops/impl/TestQ8taClone.cpp | 65 ++++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../test/custom_ops/test_q8ta_clone.cpp | 344 ++++++++++++++++++ 8 files changed, 591 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taClone.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taClone.h create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taClone.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_clone.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.glsl new file mode 100644 index 00000000000..0006311e13c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.glsl @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output buffer: packed int8x4 values +${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")} +// Input buffer: packed int8x4 values +${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")} + +// Metadata for output tensor +${layout_declare_ubo(B, "BufferMetadata", "outp")} +// Metadata for input tensor +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_block_config", "0")} +${layout_declare_spec_const(C, "int", "outp_block_config", "0")} + +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_int8x4_store.glslh" + +// Generate loading functions for t_inp buffer +define_load_int8x4_buffer_fns(t_inp) + +// Generate storing functions for t_outp buffer +define_store_int8x4_buffer_fns(t_outp) + +void main() { + TensorIndex4D tidx; + + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + inp, contig_block_idx, inp_block_config); + + if (out_of_bounds(tidx, inp)) { + return; + } + + // Load int8x4 block from input using the thread's block index + const int inp_block_outer_dim = get_block_outer_dim(inp_block_config); + ivec4 int8_block = load_int8x4_block_from_t_inp( + inp, tidx, inp_layout, inp_block_outer_dim); + + // If input and output have different block configs (different packed dims), + // transpose the block to match output's layout + if (inp_block_config != outp_block_config) { + int8_block = transpose_int8x4_block(int8_block); + } + + // Store values to output buffer using output's block config + const int outp_block_outer_dim = get_block_outer_dim(outp_block_config); + store_int8x4_block_to_t_outp( + outp, tidx, outp_layout, outp_block_outer_dim, int8_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.yaml new file mode 100644 index 00000000000..6548ee3b484 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_clone: + parameter_names_with_default_values: + DTYPE: int + shader_variants: + - NAME: q8ta_clone diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index a64cb0143a9..bf2cac6d220 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -132,6 +133,13 @@ void clone(ComputeGraph& graph, const std::vector& args) { const utils::StorageType src_storage = graph.storage_type_of(src); const utils::StorageType dst_storage = graph.storage_type_of(dst); + + // Handle int8x4 (quantized) tensors with block-based clone + if (graph.dtype_of(src) == vkapi::kInt8x4 && + graph.dtype_of(dst) == vkapi::kInt8x4) { + return add_q8ta_clone_node(graph, src, dst); + } + if (src_storage == utils::kTexture3D && dst_storage == utils::kTexture3D) { if (graph.hashed_layout_of(src) == graph.hashed_layout_of(dst)) { return add_clone_node(graph, src, dst); diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taClone.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taClone.cpp new file mode 100644 index 00000000000..d394b53892b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taClone.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void add_q8ta_clone_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef packed_int8_output) { + // Build shader name - always buffer-to-buffer int8x4 clone + std::string kernel_name = "q8ta_clone"; + + // Pass metadata for both output and input tensors + // Both are always buffer-backed int8x4 tensors + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(packed_int8_output)); + param_buffers.append(graph.buffer_meta_ubo(packed_int8_input)); + + // Create block config for output tensor: inner_dim = output's packed_dim + const BlockConfig outp_block_config = create_block_config_from_io_packed_dims( + graph, packed_int8_output, packed_int8_input); + + // Create block config for input tensor: based on outp_block_config but with + // inner_dim = input's packed_dim. If input and output have different packed + // dims, the block axes are transposed. + const BlockConfig inp_block_config = create_block_config_from_other( + graph, packed_int8_input, outp_block_config); + + // Cast block config to ValueRef for pick_*_global_wg_with_block_config + // Use inp_block_config since shader uses inp_block_config for indexing + const ValueRef block_config_ref = + static_cast(inp_block_config.as_packed_int()); + + // Use linear dispatch for buffer-backed int8x4 tensors + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_linear_global_wg_with_block_config, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {graph.hashed_layout_of(packed_int8_input), + graph.hashed_layout_of(packed_int8_output), + inp_block_config.as_packed_int(), + outp_block_config.as_packed_int()}, + // Resize args + {block_config_ref})); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h b/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h new file mode 100644 index 00000000000..91d14797400 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Clone for int8x4 tensors (memory layout agnostic) +// + +void add_q8ta_clone_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taClone.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taClone.cpp new file mode 100644 index 00000000000..9d4c46e79b0 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taClone.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void q8ta_clone_test(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef inp_layout_int = args.at(idx++); + const ValueRef outp_layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + // Extract the layout parameters and cast to GPUMemoryLayout + int32_t inp_layout_value = graph.extract_scalar(inp_layout_int); + utils::GPUMemoryLayout inp_layout = + static_cast(inp_layout_value); + + int32_t outp_layout_value = graph.extract_scalar(outp_layout_int); + utils::GPUMemoryLayout outp_layout = + static_cast(outp_layout_value); + + // Create temporary tensor for quantized input with input layout + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + inp_layout); + + // Create temporary tensor for quantized output with output layout + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + outp_layout); + + // Quantize: FP -> int8x4 with input layout + add_q8ta_quantize_node(graph, fp_input, scale, zero_point, packed_int8_input); + + // Clone: int8x4 (input layout) -> int8x4 (output layout) + add_q8ta_clone_node(graph, packed_int8_input, packed_int8_output); + + // Dequantize: int8x4 with output layout -> FP + add_q8ta_dequantize_node( + graph, packed_int8_output, scale, zero_point, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.q8ta_clone_test.default, q8ta_clone_test); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 763bfc24b27..596657be94d 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -93,6 +93,7 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("q4gsw_linear") define_custom_op_test_binary("qdq8ta_conv2d_activations") define_custom_op_test_binary("test_q_dq_8bit_per_tensor") + define_custom_op_test_binary("test_q8ta_clone") define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d") define_custom_op_test_binary("test_q8_conv2d_dw") define_custom_op_test_binary("q8ta_q8ta_q8to_add") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_clone.cpp b/backends/vulkan/test/custom_ops/test_q8ta_clone.cpp new file mode 100644 index 00000000000..b872169cc7f --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_clone.cpp @@ -0,0 +1,344 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +// Configuration struct for q8ta clone testing +struct Q8taCloneConfig { + std::vector shape; // Tensor shape (can be any dimensionality) + std::string test_case_name = "placeholder"; + std::string op_name = "q8ta_clone_test"; +}; + +// Utility function to create a test case from a Q8taCloneConfig +TestCase create_test_case_from_config( + const Q8taCloneConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype, + utils::GPUMemoryLayout fp_memory_layout, + utils::GPUMemoryLayout inp_quant_layout, + utils::GPUMemoryLayout outp_quant_layout) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string shape_str = shape_string(config.shape); + std::string test_name = config.test_case_name + " I=" + shape_str + " " + + repr_str(storage_type, fp_memory_layout) + "->" + + repr_str(utils::kBuffer, inp_quant_layout) + "->" + + repr_str(utils::kBuffer, outp_quant_layout); + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "test_etvk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) - any dimensionality + ValueSpec input_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec scale(scale_val); + + // Zero point for quantization + int32_t zero_point_val = 0; + ValueSpec zero_point(zero_point_val); + + // Input and output quantized layouts as integers + int32_t inp_layout_int = static_cast(inp_quant_layout); + ValueSpec inp_layout_spec(inp_layout_int); + + int32_t outp_layout_int = static_cast(outp_quant_layout); + ValueSpec outp_layout_spec(outp_layout_int); + + // Output tensor (float) - same shape as input + ValueSpec output_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(scale); + test_case.add_input_spec(zero_point); + test_case.add_input_spec(inp_layout_spec); + test_case.add_input_spec(outp_layout_spec); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + // Use layout-only filter for this test since clone IS the operation being + // tested + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// Generate easy test cases for q8ta_clone operation (for debugging) +std::vector generate_q8ta_clone_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Q8taCloneConfig config = { + {1, 16, 16, 16}, // shape: [N, C, H, W] + "ACCU", // test_case_name + }; + + // FP memory layouts to test + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + // Quantized memory layouts to test + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination (same layout for input and output) + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + // Same layout: should be a simple copy + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + input_dtype, + fp_layout, + quant_layout, + quant_layout)); + } + } + } + } + + return test_cases; +} + +// Generate test cases for q8ta_clone operation +std::vector generate_q8ta_clone_test_cases() { + std::vector test_cases; + + // Shapes to test + std::vector> shapes = { + // Small test cases for correctness + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + // Different tensor sizes + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Performance test cases (larger tensors) + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + {1, 128, 128, 128}, + }; + + // FP memory layouts to test + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + // Quantized memory layouts to test + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + // Test with buffer storage only + std::vector storage_types = {utils::kBuffer}; + + // Generate all combinations + for (const auto& shape : shapes) { + // Generate test case name prefix from shape dimensions + std::string prefix = "ACCU"; + for (const auto& dim : shape) { + if (dim > kRefDimSizeLimit) { + prefix = "PERF"; + break; + } + } + + for (const auto& fp_layout : fp_layouts) { + for (const auto& inp_quant_layout : quant_layouts) { + for (const auto& outp_quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + Q8taCloneConfig config; + config.shape = shape; + config.test_case_name = prefix; + + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + vkapi::kFloat, + fp_layout, + inp_quant_layout, + outp_quant_layout)); + } + } + } + } + } + + return test_cases; +} + +// Reference implementation for q8ta_clone operation +// Since clone just copies data, the result should be the same as +// quantize-dequantize +void q8ta_clone_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& scale_spec = test_case.inputs()[idx++]; + const ValueSpec& zero_point_spec = test_case.inputs()[idx++]; + const ValueSpec& inp_layout_spec = test_case.inputs()[idx++]; + const ValueSpec& outp_layout_spec = test_case.inputs()[idx++]; + (void)inp_layout_spec; // Not used in reference implementation + (void)outp_layout_spec; // Not used in reference implementation + + // Extract output specification + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions (arbitrary dimensionality) + auto input_sizes = input_spec.get_tensor_sizes(); + + // Calculate total number of elements + int64_t num_elements = 1; + for (const auto& dim : input_sizes) { + num_elements *= dim; + } + + // Skip for large tensors since computation time will be extremely slow + for (const auto& dim : input_sizes) { + if (dim > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference " + "implementation."); + } + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + + // Extract the randomized scale and zero point values + float scale = scale_spec.get_float_value(); + int32_t zero_point = zero_point_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + // Prepare output data + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + // Perform quantize-clone-dequantize operation on each element + // Clone preserves the quantized values, so result is same as Q-DQ + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize: quantized = round(input / scale + zero_point) + float quantized_float = std::round(input_val / scale) + zero_point; + + // Clamp to quantization range + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize: output = (quantized - zero_point) * scale + float dequantized = (quantized_int - zero_point) * scale; + + ref_data[i] = dequantized; + } +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(false); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8TA Clone Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q8ta_clone_reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_q8ta_clone_easy_cases, +#else + generate_q8ta_clone_test_cases, +#endif + "Q8taClone", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} From 75a2b576dc96ac996f88023a3195ed6440cd2c48 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 3 Feb 2026 13:30:39 -0800 Subject: [PATCH 2/2] Update on "[ET-VK][quantization] Add layout-flexible clone for int8x4 tensors" Implements q8ta_clone, a block-based shader for copying data between int8x4 (quantized) tensors with potentially different memory layouts. This is needed when quantized activations need to be copied between tensors with different packed int8 layouts (e.g., from kPackedInt8_4W4C to kPackedInt8_4C) without going through dequantize/requantize. Implementation details: 1. GLSL Shader (q8ta_clone.glsl): - Uses block-based dispatch pattern matching q8ta_quantize/dequantize - Each thread processes a 4x4 block of int8 values (16 elements) - Uses linear dispatch via gl_GlobalInvocationID.x for buffer tensors - Loads int8x4 blocks using load_int8x4_block_from_t_inp() - Transposes block via transpose_int8x4_block() when input/output have different packed dimensions - Stores using store_int8x4_block_to_t_outp() 2. C++ Dispatch (Q8taClone.cpp): - Creates BlockConfig for both input and output tensors using create_block_config_from_io_packed_dims() and create_block_config_from_other() - Uses pick_linear_global_wg_with_block_config for workgroup sizing - Passes hashed layouts and packed block configs as specialization constants 3. Clone.cpp Integration: - Added check for kInt8x4 dtype on both input and output tensors - Routes to add_q8ta_clone_node() for int8x4 tensor cloning - Preserves existing behavior for all other tensor types 4. Test Infrastructure: - TestQ8taClone.cpp: Custom op that chains quantize -> clone -> dequantize - test_q8ta_clone.cpp: Test driver with 800 test cases - Tests all 25 combinations of input/output quantized layouts - Tests multiple tensor shapes from 1x3x16x16 to 1x128x56x56 Differential Revision: [D92196648](https://our.internmc.facebook.com/intern/diff/D92196648/) [ghstack-poisoned]