// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>

#include "config.cuh"
#include "types.cuh"

constexpr uint32_t MAX_CACHED_RUNS = 512;

/// Binary search for the first element strictly greater than `value`.
///
/// Uses `thrust::upper_bound` with the sequential execution policy. `thrust::seq`
/// is chosen because the search runs within a single GPU thread; `thrust::device`
/// would spawn an additional kernel launch.
/// See: https://nvidia.github.io/cccl/thrust/api/group__binary__search_1gac85cc9ea00f4bdd8f80ad25fff16741d.html#thrust-upper-bound
///
/// Returns the index of the first element that is greater than `value`, or
/// `len` if no such element exists.
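///
/// Illustrative example (values chosen for this comment only): for
/// `ends = {4, 6, 10}`, `upper_bound(ends, 3, 2) == 0`,
/// `upper_bound(ends, 3, 4) == 1`, and `upper_bound(ends, 3, 10) == 3`
/// (no element is strictly greater than 10).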
template<typename T>
__device__ __forceinline__ uint64_t upper_bound(const T *data, uint64_t len, uint64_t value) {
    auto it = thrust::upper_bound(thrust::seq, data, data + len, value);
    return it - data;
}

// Decodes run-end encoded data on the GPU.
//
// Run-end encoding stores data as (value, end_position) pairs: each run repeats
// its value over the output positions from the previous run's end position
// (inclusive) up to its own end position (exclusive).
//
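// Illustrative example (values chosen for this comment only):
//   values = [7, 3, 9], ends = [4, 6, 10]
// decodes to
//   [7, 7, 7, 7, 3, 3, 9, 9, 9, 9]
// i.e. run i covers output positions [ends[i - 1], ends[i]), with an implicit
// previous end of 0 for the first run.
//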
// Steps:
// 1. Each CUDA block processes a contiguous chunk of output elements (elements_per_block).
//
// 2. Block Initialization (Thread 0 only):
//    - Compute the global position range [block_start + offset, block_end + offset) for this block
//    - Use binary search (upper_bound) to find the first and last runs that overlap this range
//      (see the worked example below)
//    - Store the run range in shared memory (block_first_run, block_num_runs)
//
// 3. Shared Memory Caching:
//    - If the number of runs for this block fits in shared memory (< MAX_CACHED_RUNS),
//      all threads cooperatively load the relevant ends[] and values[] into shared memory
//    - This reduces global memory traffic during decoding
//
// 4. Decoding:
//    a) Cached path: Each thread decodes multiple elements using a forward scan.
//       Since thread positions are strided (idx += blockDim.x) and therefore monotonically
//       increasing across iterations, each thread maintains a current_run index that only
//       moves forward.
//
//    b) Fallback path: If the block spans too many runs to cache (block_num_runs reaches
//       MAX_CACHED_RUNS), fall back to a binary search in global memory for each element.
//
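// Worked example for steps 1 and 2 (same illustrative data as above): with
// ends = [4, 6, 10], offset = 0, and a block covering output positions [5, 9),
// thread 0 computes first_run = upper_bound(ends, 3, 5) = 1 and
// last_run = upper_bound(ends, 3, 8) = 2, so the block caches runs 1..2 and
// decodes [3, 9, 9, 9].
//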
// TODO(0ax1): Investigate whether there are faster solutions.
template<typename ValueT, typename EndsT>
__device__ void runend_decode_kernel(
    const EndsT *const __restrict ends,
    uint64_t num_runs,
    const ValueT *const __restrict values,
    uint64_t offset,
    uint64_t output_len,
    ValueT *const __restrict output
) {
    __shared__ EndsT shared_ends[MAX_CACHED_RUNS];
    __shared__ ValueT shared_values[MAX_CACHED_RUNS];
    __shared__ uint64_t block_first_run;
    __shared__ uint32_t block_num_runs;

    const uint32_t elements_per_block = blockDim.x * ELEMENTS_PER_THREAD;
    const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * elements_per_block;
    const uint64_t block_end = min(block_start + elements_per_block, output_len);

    if (block_start >= output_len) return;

    // Thread 0 finds the run range for this block.
    if (threadIdx.x == 0) {
        uint64_t first_pos = block_start + offset;
        uint64_t last_pos = (block_end - 1) + offset;

        uint64_t first_run = upper_bound(ends, num_runs, first_pos);
        uint64_t last_run = upper_bound(ends, num_runs, last_pos);

        block_first_run = first_run;
        block_num_runs = static_cast<uint32_t>(
            min(last_run - first_run + 1, static_cast<uint64_t>(MAX_CACHED_RUNS)));
    }
    __syncthreads();

    // Cooperatively load ends and values into shared memory.
    if (block_num_runs < MAX_CACHED_RUNS) {
        for (uint32_t i = threadIdx.x; i < block_num_runs; i += blockDim.x) {
            shared_ends[i] = ends[block_first_run + i];
            shared_values[i] = values[block_first_run + i];
        }
    }
    __syncthreads();

    if (block_num_runs < MAX_CACHED_RUNS) {
        uint32_t current_run = 0;
        for (uint64_t idx = block_start + threadIdx.x; idx < block_end; idx += blockDim.x) {
            uint64_t pos = idx + offset;

            // Scan forward to find the run containing this position.
            while (current_run < block_num_runs && static_cast<uint64_t>(shared_ends[current_run]) <= pos) {
                current_run++;
            }

            // Defensive clamp to the last cached run; for well-formed input the scan
            // stops within the cached range.
            output[idx] = shared_values[current_run < block_num_runs ? current_run : block_num_runs - 1];
        }
    } else {
        // Fallback for blocks that span too many (very short) runs to cache:
        // `block_num_runs` was clamped to `MAX_CACHED_RUNS`, so binary search the
        // full `ends` array in global memory for each element instead.
        for (uint64_t idx = block_start + threadIdx.x; idx < block_end; idx += blockDim.x) {
            uint64_t pos = idx + offset;
            uint64_t run_idx = upper_bound(ends, num_runs, pos);
            if (run_idx >= num_runs) run_idx = num_runs - 1;
            output[idx] = values[run_idx];
        }
    }
}

#define GENERATE_RUNEND_KERNEL(value_suffix, ValueType, ends_suffix, EndsType) \
extern "C" __global__ void runend_##value_suffix##_##ends_suffix( \
    const EndsType *const __restrict ends, \
    uint64_t num_runs, \
    const ValueType *const __restrict values, \
    uint64_t offset, \
    uint64_t output_len, \
    ValueType *const __restrict output \
) { \
    runend_decode_kernel<ValueType, EndsType>(ends, num_runs, values, offset, output_len, output); \
}

#define GENERATE_RUNEND_KERNELS_FOR_VALUE(value_suffix, ValueType) \
    GENERATE_RUNEND_KERNEL(value_suffix, ValueType, u8, uint8_t) \
    GENERATE_RUNEND_KERNEL(value_suffix, ValueType, u16, uint16_t) \
    GENERATE_RUNEND_KERNEL(value_suffix, ValueType, u32, uint32_t) \
    GENERATE_RUNEND_KERNEL(value_suffix, ValueType, u64, uint64_t)
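
// For example, GENERATE_RUNEND_KERNELS_FOR_VALUE(u32, uint32_t) expands
// GENERATE_RUNEND_KERNEL four times, producing the extern "C" entry points
// runend_u32_u8, runend_u32_u16, runend_u32_u32, and runend_u32_u64, each of
// which forwards to runend_decode_kernel<uint32_t, EndsT> for the matching
// ends type.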

GENERATE_RUNEND_KERNELS_FOR_VALUE(u8, uint8_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(i8, int8_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(u16, uint16_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(i16, int16_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(u32, uint32_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(i32, int32_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(u64, uint64_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(i64, int64_t)
GENERATE_RUNEND_KERNELS_FOR_VALUE(f16, __half)
GENERATE_RUNEND_KERNELS_FOR_VALUE(f32, float)
GENERATE_RUNEND_KERNELS_FOR_VALUE(f64, double)
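
// -----------------------------------------------------------------------------
// Illustrative host-side launch sketch for the u32-values / u32-ends entry point.
// This is an assumption for documentation only: Vortex drives these kernels from
// its own host code, the block size of 256 below is an assumed value (not taken
// from config.cuh), and the sketch assumes ELEMENTS_PER_THREAD is visible to host
// code. Its purpose is to show how the grid size follows from
// elements_per_block = blockDim.x * ELEMENTS_PER_THREAD.
void launch_runend_u32_u32_example(
    const uint32_t *d_ends,
    uint64_t num_runs,
    const uint32_t *d_values,
    uint64_t offset,
    uint64_t output_len,
    uint32_t *d_output,
    cudaStream_t stream
) {
    const uint32_t block_size = 256;  // Assumed; pick to match the real launch configuration.
    const uint64_t elements_per_block = static_cast<uint64_t>(block_size) * ELEMENTS_PER_THREAD;
    const uint32_t grid_size =
        static_cast<uint32_t>((output_len + elements_per_block - 1) / elements_per_block);
    runend_u32_u32<<<grid_size, block_size, 0, stream>>>(
        d_ends, num_runs, d_values, offset, output_len, d_output);
}
// -----------------------------------------------------------------------------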