Skip to content

Commit ff2a2ce

Browse files
committed
more test
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 5e0fbec commit ff2a2ce

File tree

8 files changed

+109
-16
lines changed

8 files changed

+109
-16
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-cuda/cudf-test/Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ version = { workspace = true }
2020
workspace = true
2121

2222
[dependencies]
23+
arrow-array = { workspace = true, features = ["ffi"] }
24+
arrow-schema = { workspace = true, features = ["ffi"] }
25+
futures = { workspace = true, features = ["executor"] }
26+
vortex-array = { workspace = true }
27+
vortex-buffer = { workspace = true }
28+
vortex-cuda = { path = "..", features = ["_test-harness"] }
29+
vortex-session = { workspace = true }
2330

2431
[build-dependencies]
2532
bindgen = { workspace = true }

vortex-cuda/cudf-test/cpp/cudf_arrow_ffi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ struct ArrowDeviceArray {
6161
int64_t device_id;
6262
ArrowDeviceType device_type;
6363
void* sync_event;
64+
int64_t reserved[3];
6465
};
6566

6667
// Error type: NULL on success, pointer to error string on failure.

vortex-cuda/cudf-test/src/lib.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,19 @@ impl Drop for CudfColumnView {
195195

196196
#[cfg(test)]
197197
mod tests {
198+
use arrow_array::ffi::FFI_ArrowSchema;
199+
use arrow_schema::DataType;
200+
use futures::executor::block_on;
201+
use vortex_array::Canonical;
202+
use vortex_array::IntoArray;
203+
use vortex_array::arrays::PrimitiveArray;
204+
use vortex_array::validity::Validity;
205+
use vortex_buffer::Buffer;
206+
use vortex_cuda::CudaSession;
207+
use vortex_cuda::arrow::CudaDeviceArrayExecute;
208+
use vortex_cuda::executor::CudaArrayExt;
209+
use vortex_session::VortexSession;
210+
198211
use super::*;
199212

200213
#[test]
@@ -210,4 +223,63 @@ mod tests {
210223
}
211224
}
212225
}
226+
227+
#[test]
228+
fn test_primitive_array_to_cudf_tableview() -> Result<()> {
229+
// Create a PrimitiveArray with 100 i64 values
230+
let data: Vec<i64> = (0..100).collect();
231+
let expected_len = data.len();
232+
let primitive_array =
233+
PrimitiveArray::new(Buffer::from(data), Validity::NonNullable).into_array();
234+
235+
// Create CUDA execution context
236+
let mut cuda_ctx = match CudaSession::create_execution_ctx(&VortexSession::empty()).unwrap();
237+
238+
// Export as ArrowDeviceArray using CudaDeviceArrayExecute
239+
let device_array = block_on(Canonical::execute(
240+
&primitive_array,
241+
primitive_array.clone(),
242+
&mut cuda_ctx,
243+
))
244+
.unwrap();
245+
246+
// Synchronize the CUDA stream to ensure the data is ready
247+
cuda_ctx.synchronize_stream().map_err(|e| CudfError {
248+
message: e.to_string(),
249+
})?;
250+
251+
// Create FFI_ArrowSchema from the data type
252+
let mut ffi_schema =
253+
FFI_ArrowSchema::try_from(&DataType::Int64).map_err(|e| CudfError {
254+
message: format!("Failed to create FFI schema: {}", e),
255+
})?;
256+
257+
// Create cudf context
258+
let cudf_ctx = CudfContext::new()?;
259+
260+
// Import into cudf tableview
261+
let tableview = unsafe {
262+
cudf_ctx.tableview_from_device(
263+
(&raw mut ffi_schema).cast::<ArrowSchema>(),
264+
(&raw const device_array).cast::<ArrowDeviceArray>(),
265+
)?
266+
};
267+
268+
// Verify row count
269+
let num_rows = tableview.num_rows()?;
270+
assert_eq!(num_rows, expected_len as i64, "Row count mismatch");
271+
println!(
272+
"Successfully imported PrimitiveArray into cudf tableview with {} rows",
273+
num_rows
274+
);
275+
276+
// Verify column count (should be 1 for a primitive array)
277+
let num_columns = tableview.num_columns()?;
278+
assert_eq!(num_columns, 1, "Column count mismatch");
279+
println!("Tableview has {} column(s)", num_columns);
280+
281+
// Tableview and cudf_ctx will be deallocated automatically via Drop
282+
283+
Ok(())
284+
}
213285
}

vortex-cuda/src/arrow/canonical.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use std::sync::Arc;
55

6+
use async_trait::async_trait;
67
use cudarc::driver::sys;
78
use vortex_array::ArrayRef;
89
use vortex_array::Canonical;
@@ -22,7 +23,7 @@ use crate::arrow::CudaPrivateData;
2223
use crate::arrow::DeviceType;
2324
use crate::executor::CudaArrayExt;
2425

25-
// Impl it for the execution context instead here...I think this is right?
26+
#[async_trait]
2627
impl CudaDeviceArrayExecute for Canonical {
2728
async fn execute(
2829
&self,
@@ -32,7 +33,7 @@ impl CudaDeviceArrayExecute for Canonical {
3233
let cuda_array = array.execute_cuda(ctx).await?;
3334

3435
let arrow_array = match cuda_array {
35-
Canonical::Primitive(primitive) => export_primitive(primitive, ctx)?,
36+
Canonical::Primitive(primitive) => export_primitive(primitive, ctx).await?,
3637
c => todo!("implement support for exporting {}", c.dtype()),
3738
};
3839

@@ -46,19 +47,25 @@ impl CudaDeviceArrayExecute for Canonical {
4647
}
4748
}
4849

49-
fn export_primitive(array: PrimitiveArray, ctx: &mut CudaExecutionCtx) -> VortexResult<ArrowArray> {
50+
async fn export_primitive(array: PrimitiveArray, ctx: &mut CudaExecutionCtx) -> VortexResult<ArrowArray> {
51+
unsafe extern "C" fn release(array: *mut ArrowArray) {
52+
// SAFETY: this is only safe if the caller provides a valid pointer to an `ArrowArray`.
53+
drop(unsafe { Box::from_raw(array) });
54+
}
55+
5056
let len = array.len();
5157
let PrimitiveArrayParts {
5258
buffer,
53-
ptype,
5459
validity,
5560
..
5661
} = array.into_parts();
5762

58-
unsafe extern "C" fn release(array: *mut ArrowArray) {
59-
// SAFETY: this is only safe if the caller provides a valid pointer to an `ArrowArray`.
60-
drop(unsafe { Box::from_raw(array) });
61-
}
63+
let buffer = if buffer.is_on_device() {
64+
buffer
65+
} else {
66+
// TODO(aduffy): I don't think this type parameter does anything
67+
ctx.move_to_device::<u8>(buffer)?.await?
68+
};
6269

6370
let null_count = match validity {
6471
Validity::NonNullable | Validity::AllValid => 0,
@@ -106,4 +113,3 @@ fn export_primitive(array: PrimitiveArray, ctx: &mut CudaExecutionCtx) -> Vortex
106113
private_data: Box::into_raw(private_data).cast(),
107114
})
108115
}
109-

vortex-cuda/src/arrow/mod.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@ use std::ffi::c_void;
1414
use std::ptr::NonNull;
1515
use std::sync::Arc;
1616

17+
use async_trait::async_trait;
1718
use cudarc::driver::CudaStream;
1819
use cudarc::driver::sys;
1920
use cudarc::runtime::sys::cudaEvent_t;
2021
use vortex_array::ArrayRef;
21-
use vortex_array::Executable;
2222
use vortex_array::buffer::BufferHandle;
2323
use vortex_error::VortexResult;
2424

2525
use crate::CudaExecutionCtx;
26-
use crate::executor::CudaArrayExt;
27-
use crate::executor::CudaExecute;
2826

2927
#[derive(Debug, Copy, Clone)]
3028
#[repr(i32)]
@@ -53,7 +51,7 @@ pub type SyncEvent = Option<NonNull<cudaEvent_t>>;
5351
/// event that the client must wait on.
5452
#[repr(C)]
5553
#[derive(Debug)]
56-
pub(crate) struct ArrowDeviceArray {
54+
pub struct ArrowDeviceArray {
5755
array: ArrowArray,
5856
device_id: i64,
5957
device_type: DeviceType,
@@ -85,6 +83,7 @@ pub(crate) struct ArrowArray {
8583
}
8684

8785
impl ArrowArray {
86+
#[allow(unused)]
8887
pub fn empty() -> Self {
8988
Self {
9089
length: 0,
@@ -101,6 +100,7 @@ impl ArrowArray {
101100
}
102101
}
103102

103+
#[expect(unused, reason = "cuda_stream and cuda_buffers need to have deferred drop")]
104104
pub(crate) struct CudaPrivateData {
105105
/// Hold a reference to the CudaStream so that it stays alive even after CudaExecutionCtx
106106
/// has been dropped.
@@ -113,7 +113,8 @@ pub(crate) struct CudaPrivateData {
113113
}
114114

115115
/// Trait implemented for types that can be exported to [`ArrowDeviceArray`].
116-
pub(crate) trait CudaDeviceArrayExecute {
116+
#[async_trait]
117+
pub trait CudaDeviceArrayExecute {
117118
async fn execute(
118119
&self,
119120
array: ArrayRef,

vortex-cuda/src/device_buffer.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use cudarc::driver::DevicePtr;
1212
use cudarc::driver::DeviceRepr;
1313
use cudarc::driver::sys;
1414
use futures::future::BoxFuture;
15-
use futures::future::ok;
1615
use vortex_array::buffer::BufferHandle;
1716
use vortex_array::buffer::DeviceBuffer;
1817
use vortex_buffer::Alignment;

vortex-cuda/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
use std::process::Command;
77

8-
mod arrow;
8+
pub mod arrow;
99
mod canonical;
1010
mod device_buffer;
1111
pub mod executor;

0 commit comments

Comments
 (0)