73 changes: 73 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.glsl
@@ -0,0 +1,73 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

${define_active_storage_type("buffer")}

layout(std430) buffer;

#include "indexing.glslh"

// Output buffer: packed int8x4 values
${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")}
// Input buffer: packed int8x4 values
${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")}

// Metadata for output tensor
${layout_declare_ubo(B, "BufferMetadata", "outp")}
// Metadata for input tensor
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
${layout_declare_spec_const(C, "int", "inp_block_config", "0")}
${layout_declare_spec_const(C, "int", "outp_block_config", "0")}

#include "block_indexing.glslh"
#include "block_int8x4_load.glslh"
#include "block_int8x4_store.glslh"

// Generate loading functions for t_inp buffer
define_load_int8x4_buffer_fns(t_inp)

// Generate storing functions for t_outp buffer
define_store_int8x4_buffer_fns(t_outp)

void main() {
  TensorIndex4D tidx;

  // Buffer storage: use linear dispatch
  const uint contig_block_idx = gl_GlobalInvocationID.x;
  tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config(
      inp, contig_block_idx, inp_block_config);

  if (out_of_bounds(tidx, inp)) {
    return;
  }

  // Load int8x4 block from input using the thread's block index
  const int inp_block_outer_dim = get_block_outer_dim(inp_block_config);
  ivec4 int8_block = load_int8x4_block_from_t_inp(
      inp, tidx, inp_layout, inp_block_outer_dim);

  // If input and output have different block configs (different packed dims),
  // transpose the block to match output's layout
  if (inp_block_config != outp_block_config) {
    int8_block = transpose_int8x4_block(int8_block);
  }

  // Store values to output buffer using output's block config
  const int outp_block_outer_dim = get_block_outer_dim(outp_block_config);
  store_int8x4_block_to_t_outp(
      outp, tidx, outp_layout, outp_block_outer_dim, int8_block);
}
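For intuition about the layout swap above, here is a minimal host-side C++ sketch of a 4x4 int8 block transpose, assuming each 32-bit word packs four int8 values with byte j of word i holding element (i, j) of the tile. This is an illustrative reference only, not the shader's actual transpose_int8x4_block implementation.

#include <array>
#include <cstdint>

// Illustrative reference: transpose a 4x4 int8 tile packed as four 32-bit
// words, assuming byte j of word i holds element (i, j) of the tile.
std::array<uint32_t, 4> transpose_int8x4_block_ref(
    const std::array<uint32_t, 4>& block) {
  std::array<uint32_t, 4> transposed{};
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      // Extract element (row, col) and deposit it at (col, row).
      const uint32_t value = (block[row] >> (8 * col)) & 0xFFu;
      transposed[col] |= value << (8 * row);
    }
  }
  return transposed;
}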
11 changes: 11 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/q8ta_clone.yaml
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q8ta_clone:
  parameter_names_with_default_values:
    DTYPE: int
  shader_variants:
    - NAME: q8ta_clone
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Clone.cpp
@@ -11,6 +11,7 @@
#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
@@ -132,6 +133,13 @@ void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {

  const utils::StorageType src_storage = graph.storage_type_of(src);
  const utils::StorageType dst_storage = graph.storage_type_of(dst);

  // Handle int8x4 (quantized) tensors with block-based clone
  if (graph.dtype_of(src) == vkapi::kInt8x4 &&
      graph.dtype_of(dst) == vkapi::kInt8x4) {
    return add_q8ta_clone_node(graph, src, dst);
  }

  if (src_storage == utils::kTexture3D && dst_storage == utils::kTexture3D) {
    if (graph.hashed_layout_of(src) == graph.hashed_layout_of(dst)) {
      return add_clone_node(graph, src, dst);
68 changes: 68 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Q8taClone.cpp
@@ -0,0 +1,68 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_q8ta_clone_node(
    ComputeGraph& graph,
    const ValueRef packed_int8_input,
    const ValueRef packed_int8_output) {
  VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4);
  VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4);

  // Build shader name - always buffer-to-buffer int8x4 clone
  std::string kernel_name = "q8ta_clone";

  // Pass metadata for both output and input tensors; both are always
  // buffer-backed int8x4 tensors
  vkapi::ParamsBindList param_buffers;
  param_buffers.append(graph.buffer_meta_ubo(packed_int8_output));
  param_buffers.append(graph.buffer_meta_ubo(packed_int8_input));

  // Create block config for output tensor: inner_dim = output's packed_dim
  const BlockConfig outp_block_config = create_block_config_from_io_packed_dims(
      graph, packed_int8_output, packed_int8_input);

  // Create block config for input tensor: based on outp_block_config but with
  // inner_dim = input's packed_dim. If input and output have different packed
  // dims, the block axes are transposed.
  const BlockConfig inp_block_config = create_block_config_from_other(
      graph, packed_int8_input, outp_block_config);

  // Cast block config to ValueRef for pick_*_global_wg_with_block_config.
  // Use inp_block_config since the shader indexes with it.
  const ValueRef block_config_ref =
      static_cast<ValueRef>(inp_block_config.as_packed_int());

  // Use linear dispatch for buffer-backed int8x4 tensors
  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_linear_global_wg_with_block_config,
      pick_square_local_wg_with_block_config,
      // Inputs and Outputs
      {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}},
      // Shader params buffers
      param_buffers,
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(packed_int8_input),
       graph.hashed_layout_of(packed_int8_output),
       inp_block_config.as_packed_int(),
       outp_block_config.as_packed_int()},
      // Resize args
      {block_config_ref}));
}

} // namespace vkcompute
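To make the linear dispatch concrete: each invocation handles one int8x4 block (a 4x4 int8 tile), so the global workgroup count reduces to the block count over the tensor's extents. The sketch below is a hypothetical model of that sizing under this assumption; the name num_int8x4_blocks and its parameters are illustrative, not the actual pick_linear_global_wg_with_block_config implementation.

#include <cstdint>

// Hypothetical sketch: one invocation per 4x4 int8 tile, with extents taken
// from the block config's inner and outer dims.
uint32_t num_int8x4_blocks(
    uint32_t inner_extent,
    uint32_t outer_extent,
    uint32_t remaining_extent) {
  const auto ceil_div = [](uint32_t n, uint32_t d) { return (n + d - 1) / d; };
  return ceil_div(inner_extent, 4) * ceil_div(outer_extent, 4) *
      remaining_extent;
}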
24 changes: 24 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Q8taClone.h
@@ -0,0 +1,24 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

//
// Clone for int8x4 tensors (memory layout agnostic)
//

void add_q8ta_clone_node(
    ComputeGraph& graph,
    const ValueRef packed_int8_input,
    const ValueRef packed_int8_output);

} // namespace vkcompute
1 change: 1 addition & 0 deletions backends/vulkan/test/custom_ops/CMakeLists.txt
@@ -99,6 +99,7 @@ if(TARGET vulkan_backend)
  add_operator_prototype(choose_qparams_per_row)
  add_operator_prototype(qdq8ta_conv2d_activations)
  add_operator_prototype(test_q8ta_qdq)
  add_operator_prototype(test_q8ta_clone)
  add_operator_prototype(q8ta_q8csw_q8to_conv2d)
  add_operator_prototype(test_q8ta_conv2d_dw)
  add_operator_prototype(q8ta_q8ta_q8to_add)
65 changes: 65 additions & 0 deletions backends/vulkan/test/custom_ops/impl/TestQ8taClone.cpp
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taClone.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h>

namespace vkcompute {

void q8ta_clone_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int32_t idx = 0;
  const ValueRef fp_input = args.at(idx++);
  const ValueRef scale = args.at(idx++);
  const ValueRef zero_point = args.at(idx++);
  const ValueRef inp_layout_int = args.at(idx++);
  const ValueRef outp_layout_int = args.at(idx++);
  const ValueRef fp_output = args.at(idx++);

  // Extract the layout parameters and cast to GPUMemoryLayout
  int32_t inp_layout_value = graph.extract_scalar<int32_t>(inp_layout_int);
  utils::GPUMemoryLayout inp_layout =
      static_cast<utils::GPUMemoryLayout>(inp_layout_value);

  int32_t outp_layout_value = graph.extract_scalar<int32_t>(outp_layout_int);
  utils::GPUMemoryLayout outp_layout =
      static_cast<utils::GPUMemoryLayout>(outp_layout_value);

  // Create temporary tensor for quantized input with input layout
  TmpTensor packed_int8_input(
      &graph,
      graph.sizes_of(fp_input),
      vkapi::kInt8x4,
      utils::kBuffer,
      inp_layout);

  // Create temporary tensor for quantized output with output layout
  TmpTensor packed_int8_output(
      &graph,
      graph.sizes_of(fp_input),
      vkapi::kInt8x4,
      utils::kBuffer,
      outp_layout);

  // Quantize: FP -> int8x4 with input layout
  add_q8ta_quantize_node(graph, fp_input, scale, zero_point, packed_int8_input);

  // Clone: int8x4 (input layout) -> int8x4 (output layout)
  add_q8ta_clone_node(graph, packed_int8_input, packed_int8_output);

  // Dequantize: int8x4 with output layout -> FP
  add_q8ta_dequantize_node(
      graph, packed_int8_output, scale, zero_point, fp_output);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(test_etvk.q8ta_clone_test.default, q8ta_clone_test);
}

} // namespace vkcompute
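As a scalar model of the round trip this test builds (quantize -> clone -> dequantize), the sketch below assumes the standard per-tensor affine int8 scheme; q8ta_quantize_ref and q8ta_dequantize_ref are illustrative names, not runtime APIs. Because the clone only rearranges packed bytes and never changes values, the cloned path should dequantize to exactly the same floats as a direct quantize -> dequantize.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative scalar model of per-tensor affine int8 quantization. The
// clone itself never changes values, so dequantize(clone(quantize(x)))
// should equal dequantize(quantize(x)) bit-for-bit.
int8_t q8ta_quantize_ref(float x, float scale, int32_t zero_point) {
  const int32_t q =
      static_cast<int32_t>(std::lround(x / scale)) + zero_point;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}

float q8ta_dequantize_ref(int8_t q, float scale, int32_t zero_point) {
  return static_cast<float>(static_cast<int32_t>(q) - zero_point) * scale;
}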
1 change: 1 addition & 0 deletions backends/vulkan/test/custom_ops/targets.bzl
@@ -93,6 +93,7 @@ def define_common_targets(is_fbcode = False):
    define_custom_op_test_binary("q4gsw_linear")
    define_custom_op_test_binary("qdq8ta_conv2d_activations")
    define_custom_op_test_binary("test_q8ta_qdq")
    define_custom_op_test_binary("test_q8ta_clone")
    define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d")
    define_custom_op_test_binary("test_q8ta_conv2d_dw")
    define_custom_op_test_binary("q8ta_q8ta_q8to_add")