Skip to content

Commit 782d650

Browse files
authored
feat[cuda]: patches kernel (#6231)
Apply patches in-place for BP and ALP. Added unit tests, and also added patches to the existing BP/ALP tests to verify it works --------- Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 76f19cb commit 782d650

File tree

10 files changed

+404
-41
lines changed

10 files changed

+404
-41
lines changed

vortex-cuda/benches/dict_cuda.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ fn launch_dict_kernel_timed<V: cudarc::driver::DeviceRepr, I: cudarc::driver::De
101101
let events = vortex_cuda::launch_cuda_kernel!(
102102
execution_ctx: cuda_ctx,
103103
module: "dict",
104-
ptypes: &[value_ptype.to_string().as_str(), code_ptype.to_string().as_str()],
104+
ptypes: &[value_ptype, code_ptype],
105105
launch_args: [codes_view, codes_len_u64, values_view, output_view],
106106
event_recording: CU_EVENT_BLOCKING_SYNC,
107107
array_len: codes_len

vortex-cuda/benches/for_cuda.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn launch_for_kernel_timed_u8(
8989
let events = vortex_cuda::launch_cuda_kernel!(
9090
execution_ctx: cuda_ctx,
9191
module: "for",
92-
ptypes: &[for_array.ptype().to_string().as_str()],
92+
ptypes: &[for_array.ptype()],
9393
launch_args: [device_data, reference, array_len_u64],
9494
event_recording: CU_EVENT_BLOCKING_SYNC,
9595
array_len: for_array.len()
@@ -110,7 +110,7 @@ fn launch_for_kernel_timed_u16(
110110
let events = vortex_cuda::launch_cuda_kernel!(
111111
execution_ctx: cuda_ctx,
112112
module: "for",
113-
ptypes: &[for_array.ptype().to_string().as_str()],
113+
ptypes: &[for_array.ptype()],
114114
launch_args: [device_data, reference, array_len_u64],
115115
event_recording: CU_EVENT_BLOCKING_SYNC,
116116
array_len: for_array.len()
@@ -131,7 +131,7 @@ fn launch_for_kernel_timed_u32(
131131
let events = vortex_cuda::launch_cuda_kernel!(
132132
execution_ctx: cuda_ctx,
133133
module: "for",
134-
ptypes: &[for_array.ptype().to_string().as_str()],
134+
ptypes: &[for_array.ptype()],
135135
launch_args: [device_data, reference, array_len_u64],
136136
event_recording: CU_EVENT_BLOCKING_SYNC,
137137
array_len: for_array.len()
@@ -152,7 +152,7 @@ fn launch_for_kernel_timed_u64(
152152
let events = vortex_cuda::launch_cuda_kernel!(
153153
execution_ctx: cuda_ctx,
154154
module: "for",
155-
ptypes: &[for_array.ptype().to_string().as_str()],
155+
ptypes: &[for_array.ptype()],
156156
launch_args: [device_data, reference, array_len_u64],
157157
event_recording: CU_EVENT_BLOCKING_SYNC,
158158
array_len: for_array.len()

vortex-cuda/kernels/src/config.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33

44
#pragma once
55

6+
#include <stdint.h>
7+
68
// Kernel launch configuration constants.
79
// Must match the Rust launch config in src/kernel/mod.rs.
810
//
911
// With THREADS_PER_BLOCK=64 (set by Rust) and ELEMENTS_PER_THREAD=32:
1012
// elements_per_block = 64 * 32 = 2048
1113
// grid_dim = ceil(array_len / 2048)
1214
constexpr uint32_t ELEMENTS_PER_THREAD = 32;
15+
16+
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
17+
18+
#define START_ELEM(idx, len) MIN((idx) * ELEMENTS_PER_THREAD, (len))
19+
#define STOP_ELEM(idx, len) MIN(START_ELEM(idx, len) + ELEMENTS_PER_THREAD, (len))

vortex-cuda/kernels/src/patches.cu

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
6+
// TODO(aduffy): this is very naive. In the future we need to
7+
// transpose the patches, see G-ALP paper.
8+
// Apply patches to a source array
9+
template<typename ValueT, typename IndexT>
10+
__device__ void patches(
11+
ValueT *const values,
12+
const IndexT *const patchIndices,
13+
const ValueT *const patchValues,
14+
uint64_t patchesLen
15+
) {
16+
const uint64_t worker = blockIdx.x * blockDim.x + threadIdx.x;
17+
const uint64_t startElem = START_ELEM(worker, patchesLen);
18+
const uint64_t stopElem = START_ELEM(worker, patchesLen);
19+
20+
if (startElem >= patchesLen) {
21+
return;
22+
}
23+
24+
for (uint64_t idx = startElem; idx < stopElem; idx++) {
25+
const IndexT patchIdx = patchIndices[idx];
26+
const ValueT patchVal = patchValues[idx];
27+
28+
const size_t valueIdx = static_cast<size_t>(patchIdx);
29+
values[valueIdx] = patchVal;
30+
}
31+
}
32+
33+
#define GENERATE_PATCHES_KERNEL(ValueT, value_suffix, IndexT, index_suffix) \
34+
extern "C" __global__ void patches_##value_suffix##_##index_suffix( \
35+
ValueT *const values, \
36+
const IndexT *const patchIndices, \
37+
const ValueT *const patchValues, \
38+
uint64_t patchesLen \
39+
) { \
40+
patches(values, patchIndices, patchValues, patchesLen); \
41+
}
42+
43+
#define GENERATE_PATCHES_KERNEL_FOR_VALUE(ValueT, value_suffix) \
44+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint8_t, u8) \
45+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint16_t, u16) \
46+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint32_t, u32) \
47+
GENERATE_PATCHES_KERNEL(ValueT, value_suffix, uint64_t, u64)
48+
49+
50+
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint8_t, u8)
51+
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint16_t, u16)
52+
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint32_t, u32)
53+
GENERATE_PATCHES_KERNEL_FOR_VALUE(uint64_t, u64)
54+
55+
GENERATE_PATCHES_KERNEL_FOR_VALUE(int8_t, i8)
56+
GENERATE_PATCHES_KERNEL_FOR_VALUE(int16_t, i16)
57+
GENERATE_PATCHES_KERNEL_FOR_VALUE(int32_t, i32)
58+
GENERATE_PATCHES_KERNEL_FOR_VALUE(int64_t, i64)
59+
60+
GENERATE_PATCHES_KERNEL_FOR_VALUE(float, f32)
61+
GENERATE_PATCHES_KERNEL_FOR_VALUE(double, f64)

vortex-cuda/src/device_buffer.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::stream::await_stream_callback;
2727
/// A [`DeviceBuffer`] wrapping a CUDA GPU allocation.
2828
///
2929
/// Like the host `BufferHandle` variant, all slicing/referencing works in terms of byte units.
30+
#[derive(Clone)]
3031
pub struct CudaDeviceBuffer {
3132
allocation: Arc<dyn private::DeviceAllocation>,
3233
/// Offset in bytes from the start of the allocation
@@ -39,8 +40,6 @@ pub struct CudaDeviceBuffer {
3940
alignment: Alignment,
4041
}
4142

42-
// We can call the sys methods, it's just a lot of extra code...fuck that lol
43-
4443
mod private {
4544
use std::fmt::Debug;
4645
use std::sync::Arc;

vortex-cuda/src/kernel/arrays/dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ async fn execute_dict_prim_typed<V: DeviceRepr + NativePType, I: DeviceRepr + Na
129129
let _cuda_events = crate::launch_cuda_kernel!(
130130
execution_ctx: ctx,
131131
module: "dict",
132-
ptypes: &[value_ptype.to_string().as_str(), I::PTYPE.to_string().as_str()],
132+
ptypes: &[value_ptype, I::PTYPE],
133133
launch_args: [codes_view, codes_len_u64, values_view, output_view],
134134
event_recording: cudarc::driver::sys::CUevent_flags::CU_EVENT_DISABLE_TIMING,
135135
array_len: codes_len

vortex-cuda/src/kernel/encodings/alp.rs

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use vortex_array::arrays::PrimitiveArrayParts;
2020
use vortex_array::buffer::BufferHandle;
2121
use vortex_cuda_macros::cuda_tests;
2222
use vortex_dtype::NativePType;
23+
use vortex_dtype::match_each_unsigned_integer_ptype;
2324
use vortex_error::VortexResult;
2425
use vortex_error::vortex_err;
2526

@@ -28,6 +29,7 @@ use crate::CudaDeviceBuffer;
2829
use crate::executor::CudaArrayExt;
2930
use crate::executor::CudaExecute;
3031
use crate::executor::CudaExecutionCtx;
32+
use crate::kernel::patches::execute_patches;
3133
use crate::launch_cuda_kernel_impl;
3234

3335
/// CUDA decoder for ALP (Adaptive Lossless floating-Point) decompression.
@@ -88,20 +90,33 @@ where
8890
// Load kernel function
8991
let kernel_ptypes = [A::ALPInt::PTYPE, A::PTYPE];
9092
let cuda_function = ctx.load_function_ptype("alp", &kernel_ptypes)?;
91-
let mut launch_builder = ctx.launch_builder(&cuda_function);
93+
{
94+
let mut launch_builder = ctx.launch_builder(&cuda_function);
95+
96+
// Build launch args: input, output, f, e, length
97+
launch_builder.arg(&input_view);
98+
launch_builder.arg(&output_view);
99+
launch_builder.arg(&f);
100+
launch_builder.arg(&e);
101+
launch_builder.arg(&array_len_u64);
102+
103+
// Launch kernel
104+
let _cuda_events =
105+
launch_cuda_kernel_impl(&mut launch_builder, CU_EVENT_DISABLE_TIMING, array_len)?;
106+
}
92107

93-
// Build launch args: input, output, f, e, length
94-
launch_builder.arg(&input_view);
95-
launch_builder.arg(&output_view);
96-
launch_builder.arg(&f);
97-
launch_builder.arg(&e);
98-
launch_builder.arg(&array_len_u64);
108+
// Check if there are any patches to decode here
109+
let output_buf = if let Some(patches) = array.patches() {
110+
match_each_unsigned_integer_ptype!(patches.indices_ptype()?, |I| {
111+
execute_patches::<A, I>(patches.clone(), output_buf, ctx).await?
112+
})
113+
} else {
114+
output_buf
115+
};
99116

100-
// Launch kernel
101-
let _cuda_events =
102-
launch_cuda_kernel_impl(&mut launch_builder, CU_EVENT_DISABLE_TIMING, array_len)?;
117+
// TODO(aduffy): scatter patch values validity. There are several places we'll need to start
118+
// handling validity.
103119

104-
// Build result with newly allocated buffer
105120
let output_handle = BufferHandle::new_device(Arc::new(output_buf));
106121
Ok(Canonical::Primitive(PrimitiveArray::from_buffer_handle(
107122
output_handle,
@@ -117,8 +132,10 @@ mod tests {
117132
use vortex_array::IntoArray;
118133
use vortex_array::arrays::PrimitiveArray;
119134
use vortex_array::assert_arrays_eq;
120-
use vortex_array::validity::Validity::NonNullable;
135+
use vortex_array::patches::Patches;
136+
use vortex_array::validity::Validity;
121137
use vortex_buffer::Buffer;
138+
use vortex_buffer::buffer;
122139
use vortex_error::VortexExpect;
123140
use vortex_session::VortexSession;
124141

@@ -138,13 +155,24 @@ mod tests {
138155
let encoded_data: Vec<i32> = vec![100, 200, 300, 400, 500];
139156
let exponents = Exponents { e: 0, f: 2 }; // multiply by 100
140157

158+
// Patches
159+
let patches = Patches::new(
160+
5,
161+
0,
162+
PrimitiveArray::new(buffer![0u32, 4u32], Validity::NonNullable).into_array(),
163+
PrimitiveArray::new(buffer![0.0f32, 999f32], Validity::NonNullable).into_array(),
164+
None,
165+
)
166+
.unwrap();
167+
141168
let alp_array = ALPArray::try_new(
142-
PrimitiveArray::new(Buffer::from(encoded_data.clone()), NonNullable).into_array(),
169+
PrimitiveArray::new(Buffer::from(encoded_data.clone()), Validity::NonNullable)
170+
.into_array(),
143171
exponents,
144-
None,
172+
Some(patches),
145173
)?;
146174

147-
let cpu_result = alp_array.to_canonical()?;
175+
let cpu_result = alp_array.to_canonical()?.into_array();
148176

149177
let gpu_result = ALPExecutor
150178
.execute(alp_array.to_array(), &mut cuda_ctx)
@@ -154,7 +182,7 @@ mod tests {
154182
.await?
155183
.into_array();
156184

157-
assert_arrays_eq!(cpu_result.into_array(), gpu_result);
185+
assert_arrays_eq!(cpu_result, gpu_result);
158186

159187
Ok(())
160188
}

vortex-cuda/src/kernel/encodings/bitpacked.rs

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use std::fmt::Debug;
5+
use std::sync::Arc;
56

67
use async_trait::async_trait;
78
use cudarc::driver::DeviceRepr;
@@ -16,6 +17,8 @@ use vortex_array::buffer::DeviceBufferExt;
1617
use vortex_cuda_macros::cuda_tests;
1718
use vortex_dtype::NativePType;
1819
use vortex_dtype::match_each_integer_ptype;
20+
use vortex_dtype::match_each_unsigned_integer_ptype;
21+
use vortex_error::VortexExpect;
1922
use vortex_error::VortexResult;
2023
use vortex_error::vortex_ensure;
2124
use vortex_error::vortex_err;
@@ -29,6 +32,7 @@ use crate::CudaDeviceBuffer;
2932
use crate::executor::CudaExecute;
3033
use crate::executor::CudaExecutionCtx;
3134
use crate::kernel::launch_cuda_kernel_with_config;
35+
use crate::kernel::patches::execute_patches;
3236

3337
/// CUDA decoder for ALP (Adaptive Lossless floating-Point) decompression.
3438
#[derive(Debug)]
@@ -74,7 +78,6 @@ where
7478
} = array.into_parts();
7579

7680
vortex_ensure!(len > 0, "Non empty array");
77-
vortex_ensure!(patches.is_none(), "Patches not supported");
7881
let offset = offset as usize;
7982

8083
let device_input: BufferHandle = if packed.is_on_device() {
@@ -97,27 +100,46 @@ where
97100
let thread_count = if bits == 64 { 16 } else { 32 };
98101
let suffixes: [&str; _] = [&format!("{bit_width}bw"), &format!("{thread_count}t")];
99102
let cuda_function = ctx.load_function(&format!("bit_unpack_{}", bits), &suffixes)?;
100-
let mut launch_builder = ctx.launch_builder(&cuda_function);
101103

102-
// Build launch args: input, output, f, e, length
103-
launch_builder.arg(&input_view);
104-
launch_builder.arg(&output_view);
104+
{
105+
let mut launch_builder = ctx.launch_builder(&cuda_function);
105106

106-
let num_blocks = u32::try_from(len.div_ceil(1024))?;
107+
// Build launch args: input, output, f, e, length
108+
launch_builder.arg(&input_view);
109+
launch_builder.arg(&output_view);
107110

108-
let config = LaunchConfig {
109-
grid_dim: (num_blocks, 1, 1),
110-
block_dim: (thread_count, 1, 1),
111-
shared_mem_bytes: 0,
112-
};
111+
let num_blocks = u32::try_from(len.div_ceil(1024))?;
112+
113+
let config = LaunchConfig {
114+
grid_dim: (num_blocks, 1, 1),
115+
block_dim: (thread_count, 1, 1),
116+
shared_mem_bytes: 0,
117+
};
113118

114-
// Launch kernel
115-
let _cuda_events =
116-
launch_cuda_kernel_with_config(&mut launch_builder, config, CU_EVENT_DISABLE_TIMING)?;
119+
// Launch kernel
120+
let _cuda_events =
121+
launch_cuda_kernel_with_config(&mut launch_builder, config, CU_EVENT_DISABLE_TIMING)?;
122+
}
123+
124+
let output_handle = match patches {
125+
None => BufferHandle::new_device(output_buf.slice_typed::<A>(offset..(offset + len))),
126+
Some(p) => {
127+
let output_buf = output_buf.slice_typed::<A>(offset..(offset + len));
128+
let buf = output_buf
129+
.as_any()
130+
.downcast_ref::<CudaDeviceBuffer>()
131+
.vortex_expect("we created this as CudaDeviceBuffer")
132+
.clone();
133+
134+
let patched_buf = match_each_unsigned_integer_ptype!(p.indices_ptype()?, |I| {
135+
execute_patches::<A, I>(p, buf, ctx).await?
136+
});
137+
138+
BufferHandle::new_device(Arc::new(patched_buf))
139+
}
140+
};
117141

118142
// Build result with newly allocated buffer
119-
let output_handle =
120-
BufferHandle::new_device(output_buf.slice_typed::<A>(offset..(offset + len)));
121143
Ok(Canonical::Primitive(PrimitiveArray::from_buffer_handle(
122144
output_handle,
123145
A::PTYPE,
@@ -141,6 +163,34 @@ mod tests {
141163
use crate::CanonicalCudaExt;
142164
use crate::session::CudaSession;
143165

166+
#[test]
167+
fn test_patches() -> VortexResult<()> {
168+
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
169+
.vortex_expect("failed to create execution context");
170+
171+
let array = PrimitiveArray::new((0u16..=513).collect::<Buffer<_>>(), NonNullable);
172+
173+
// Last two items should be patched
174+
let bp_with_patches = BitPackedArray::encode(array.as_ref(), 9)?;
175+
assert!(bp_with_patches.patches().is_some());
176+
177+
let cpu_result = bp_with_patches.to_canonical()?.into_array();
178+
179+
let gpu_result = block_on(async {
180+
BitPackedExecutor
181+
.execute(bp_with_patches.to_array(), &mut cuda_ctx)
182+
.await
183+
.vortex_expect("GPU decompression failed")
184+
.into_host()
185+
.await
186+
.map(|a| a.into_array())
187+
})?;
188+
189+
assert_arrays_eq!(cpu_result, gpu_result);
190+
191+
Ok(())
192+
}
193+
144194
#[rstest]
145195
#[case::bw_1(1)]
146196
#[case::bw_2(2)]

0 commit comments

Comments
 (0)