22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use std:: fmt:: Debug ;
5+ use std:: sync:: Arc ;
56
67use async_trait:: async_trait;
78use cudarc:: driver:: DeviceRepr ;
@@ -16,6 +17,8 @@ use vortex_array::buffer::DeviceBufferExt;
1617use vortex_cuda_macros:: cuda_tests;
1718use vortex_dtype:: NativePType ;
1819use vortex_dtype:: match_each_integer_ptype;
20+ use vortex_dtype:: match_each_unsigned_integer_ptype;
21+ use vortex_error:: VortexExpect ;
1922use vortex_error:: VortexResult ;
2023use vortex_error:: vortex_ensure;
2124use vortex_error:: vortex_err;
@@ -29,6 +32,7 @@ use crate::CudaDeviceBuffer;
2932use crate :: executor:: CudaExecute ;
3033use crate :: executor:: CudaExecutionCtx ;
3134use crate :: kernel:: launch_cuda_kernel_with_config;
35+ use crate :: kernel:: patches:: execute_patches;
3236
3337/// CUDA decoder for ALP (Adaptive Lossless floating-Point) decompression.
3438#[ derive( Debug ) ]
7478 } = array. into_parts ( ) ;
7579
7680 vortex_ensure ! ( len > 0 , "Non empty array" ) ;
77- vortex_ensure ! ( patches. is_none( ) , "Patches not supported" ) ;
7881 let offset = offset as usize ;
7982
8083 let device_input: BufferHandle = if packed. is_on_device ( ) {
@@ -97,27 +100,46 @@ where
97100 let thread_count = if bits == 64 { 16 } else { 32 } ;
98101 let suffixes: [ & str ; _] = [ & format ! ( "{bit_width}bw" ) , & format ! ( "{thread_count}t" ) ] ;
99102 let cuda_function = ctx. load_function ( & format ! ( "bit_unpack_{}" , bits) , & suffixes) ?;
100- let mut launch_builder = ctx. launch_builder ( & cuda_function) ;
101103
102- // Build launch args: input, output, f, e, length
103- launch_builder. arg ( & input_view) ;
104- launch_builder. arg ( & output_view) ;
104+ {
105+ let mut launch_builder = ctx. launch_builder ( & cuda_function) ;
105106
106- let num_blocks = u32:: try_from ( len. div_ceil ( 1024 ) ) ?;
107+ // Build launch args: input, output, f, e, length
108+ launch_builder. arg ( & input_view) ;
109+ launch_builder. arg ( & output_view) ;
107110
108- let config = LaunchConfig {
109- grid_dim : ( num_blocks, 1 , 1 ) ,
110- block_dim : ( thread_count, 1 , 1 ) ,
111- shared_mem_bytes : 0 ,
112- } ;
111+ let num_blocks = u32:: try_from ( len. div_ceil ( 1024 ) ) ?;
112+
113+ let config = LaunchConfig {
114+ grid_dim : ( num_blocks, 1 , 1 ) ,
115+ block_dim : ( thread_count, 1 , 1 ) ,
116+ shared_mem_bytes : 0 ,
117+ } ;
113118
114- // Launch kernel
115- let _cuda_events =
116- launch_cuda_kernel_with_config ( & mut launch_builder, config, CU_EVENT_DISABLE_TIMING ) ?;
119+ // Launch kernel
120+ let _cuda_events =
121+ launch_cuda_kernel_with_config ( & mut launch_builder, config, CU_EVENT_DISABLE_TIMING ) ?;
122+ }
123+
124+ let output_handle = match patches {
125+ None => BufferHandle :: new_device ( output_buf. slice_typed :: < A > ( offset..( offset + len) ) ) ,
126+ Some ( p) => {
127+ let output_buf = output_buf. slice_typed :: < A > ( offset..( offset + len) ) ;
128+ let buf = output_buf
129+ . as_any ( )
130+ . downcast_ref :: < CudaDeviceBuffer > ( )
131+ . vortex_expect ( "we created this as CudaDeviceBuffer" )
132+ . clone ( ) ;
133+
134+ let patched_buf = match_each_unsigned_integer_ptype ! ( p. indices_ptype( ) ?, |I | {
135+ execute_patches:: <A , I >( p, buf, ctx) . await ?
136+ } ) ;
137+
138+ BufferHandle :: new_device ( Arc :: new ( patched_buf) )
139+ }
140+ } ;
117141
118142 // Build result with newly allocated buffer
119- let output_handle =
120- BufferHandle :: new_device ( output_buf. slice_typed :: < A > ( offset..( offset + len) ) ) ;
121143 Ok ( Canonical :: Primitive ( PrimitiveArray :: from_buffer_handle (
122144 output_handle,
123145 A :: PTYPE ,
@@ -141,6 +163,34 @@ mod tests {
141163 use crate :: CanonicalCudaExt ;
142164 use crate :: session:: CudaSession ;
143165
166+ #[ test]
167+ fn test_patches ( ) -> VortexResult < ( ) > {
168+ let mut cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) )
169+ . vortex_expect ( "failed to create execution context" ) ;
170+
171+ let array = PrimitiveArray :: new ( ( 0u16 ..=513 ) . collect :: < Buffer < _ > > ( ) , NonNullable ) ;
172+
173+ // Last two items should be patched
174+ let bp_with_patches = BitPackedArray :: encode ( array. as_ref ( ) , 9 ) ?;
175+ assert ! ( bp_with_patches. patches( ) . is_some( ) ) ;
176+
177+ let cpu_result = bp_with_patches. to_canonical ( ) ?. into_array ( ) ;
178+
179+ let gpu_result = block_on ( async {
180+ BitPackedExecutor
181+ . execute ( bp_with_patches. to_array ( ) , & mut cuda_ctx)
182+ . await
183+ . vortex_expect ( "GPU decompression failed" )
184+ . into_host ( )
185+ . await
186+ . map ( |a| a. into_array ( ) )
187+ } ) ?;
188+
189+ assert_arrays_eq ! ( cpu_result, gpu_result) ;
190+
191+ Ok ( ( ) )
192+ }
193+
144194 #[ rstest]
145195 #[ case:: bw_1( 1 ) ]
146196 #[ case:: bw_2( 2 ) ]
0 commit comments