Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vortex-array/src/arrays/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mod masked;
mod null;
mod primitive;
mod scalar_fn;
mod shared;
mod slice;
mod struct_;
mod varbin;
Expand All @@ -55,6 +56,7 @@ pub use masked::*;
pub use null::*;
pub use primitive::*;
pub use scalar_fn::*;
pub use shared::*;
pub use slice::*;
pub use struct_::*;
pub use varbin::*;
Expand Down
97 changes: 97 additions & 0 deletions vortex-array/src/arrays/shared/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use parking_lot::RwLock;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::stats::ArrayStats;
use vortex_dtype::DType;

/// An array wrapper that remembers the canonical form of its source once it has been computed.
///
/// The state sits behind an `Arc<RwLock<_>>`, so values produced by the derived `Clone` point at
/// the same state cell: a canonical form cached through one clone is observed by the others.
#[derive(Debug, Clone)]
pub struct SharedArray {
/// Either the not-yet-canonicalized source array or the cached canonical result.
pub(super) state: Arc<RwLock<SharedState>>,
/// Logical type, copied from the source array whenever a source is installed.
pub(super) dtype: DType,
/// Statistics holder; `new` initializes it with `ArrayStats::default()`.
pub(super) stats: ArrayStats,
}

/// The two states a [`SharedArray`] can be in.
#[derive(Debug, Clone)]
pub(super) enum SharedState {
/// No canonical form has been cached yet; holds the array to canonicalize.
Source(ArrayRef),
/// A canonical form has been stored and is handed out instead of the source.
Cached(Canonical),
}

impl SharedArray {
    /// Wraps `source` in a shared array with nothing cached yet.
    ///
    /// The dtype is copied from `source`; the stats start out as `ArrayStats::default()`.
    pub fn new(source: ArrayRef) -> Self {
        Self {
            dtype: source.dtype().clone(),
            state: Arc::new(RwLock::new(SharedState::Source(source))),
            stats: ArrayStats::default(),
        }
    }

    /// Returns a clone of the cached canonical form, or `None` while the array still only holds
    /// its source.
    pub fn cached(&self) -> Option<Canonical> {
        match &*self.state.read() {
            SharedState::Cached(canonical) => Some(canonical.clone()),
            SharedState::Source(_) => None,
        }
    }

    /// Stores `canonical` as the cached result unless a result is already cached.
    ///
    /// Returns the canonical form that is cached after the call: the pre-existing one if another
    /// holder got there first, otherwise `canonical` itself. The check and the store happen under
    /// a single write lock.
    pub fn cache_or_return(&self, canonical: Canonical) -> Canonical {
        let mut state = self.state.write();
        if let SharedState::Cached(existing) = &*state {
            return existing.clone();
        }
        *state = SharedState::Cached(canonical.clone());
        canonical
    }

    /// Returns a clone of the source array, or `None` once a canonical form has been cached.
    pub fn source_if_any(&self) -> Option<ArrayRef> {
        match &*self.state.read() {
            SharedState::Source(source) => Some(source.clone()),
            SharedState::Cached(_) => None,
        }
    }

    /// Returns the canonical form, executing the source and caching the result on first use.
    ///
    /// No lock is held while the source executes, so several holders may execute it concurrently;
    /// `cache_or_return` then makes all of them return the same cached value.
    ///
    /// # Errors
    ///
    /// Propagates any error from executing the source array.
    pub(super) fn canonicalize(&self, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
        if let Some(existing) = self.cached() {
            return Ok(existing);
        }
        let Some(source) = self.source_if_any() else {
            // Another holder cached a result between the two reads above.
            return Ok(self
                .cached()
                .vortex_expect("cache present when no source"));
        };
        let canonical = source.execute::<Canonical>(ctx)?;
        Ok(self.cache_or_return(canonical))
    }

    /// Returns the array currently backing this wrapper: the source if nothing is cached,
    /// otherwise the cached canonical form converted back into an array.
    pub(super) fn current_array_ref(&self) -> ArrayRef {
        match &*self.state.read() {
            SharedState::Source(source) => source.clone(),
            SharedState::Cached(canonical) => canonical.clone().into_array(),
        }
    }

    /// Replaces the source of *this* handle and drops any cached canonical form.
    ///
    /// A fresh state cell is installed instead of writing through the existing lock: the old cell
    /// may be shared with clones of this array, and overwriting it in place would silently change
    /// their contents (and discard their cache) while their `dtype` stayed the old one.
    pub(super) fn set_source(&mut self, source: ArrayRef) {
        self.dtype = source.dtype().clone();
        self.state = Arc::new(RwLock::new(SharedState::Source(source)));
    }

    /// Reports the currently backing array to `visitor` as the single child named `"source"`.
    pub(super) fn visit_children(&self, visitor: &mut dyn crate::ArrayChildVisitor) {
        let child = self.current_array_ref();
        visitor.visit_child("source", &child);
    }
}
11 changes: 11 additions & 0 deletions vortex-array/src/arrays/shared/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

mod array;
mod vtable;

pub use array::SharedArray;
pub use vtable::SharedVTable;

#[cfg(test)]
mod tests;
34 changes: 34 additions & 0 deletions vortex-array/src/arrays/shared/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::buffer;
use vortex_session::VortexSession;

use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::SharedArray;
use crate::hash::ArrayEq;
use crate::hash::Precision as HashPrecision;
use crate::session::ArraySession;
use crate::validity::Validity;

/// Canonicalizing a `SharedArray` populates its cache, and a second canonicalization yields an
/// array equal (by value) to the first result.
#[test]
fn shared_array_caches_on_canonicalize() -> vortex_error::VortexResult<()> {
    let source = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();
    let shared = SharedArray::new(source);

    // A freshly wrapped array has no cached canonical form.
    assert!(shared.cached().is_none());

    let mut ctx = ExecutionCtx::new(VortexSession::empty().with::<ArraySession>());

    // First canonicalization must leave an equal value in the cache.
    let first = shared.canonicalize(&mut ctx)?;
    let cached = shared.cached().expect("canonicalize should cache result");
    assert!(cached.as_ref().array_eq(first.as_ref(), HashPrecision::Value));

    // Second canonicalization must agree with the first.
    let second = shared.canonicalize(&mut ctx)?;
    assert!(second.as_ref().array_eq(first.as_ref(), HashPrecision::Value));

    Ok(())
}
148 changes: 148 additions & 0 deletions vortex-array/src/arrays/shared/vtable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::hash::Hash;
use std::ops::Range;

use vortex_dtype::DType;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::ArrayBufferVisitor;
use crate::ArrayChildVisitor;
use crate::ArrayRef;
use crate::Canonical;
use crate::EmptyMetadata;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::Precision;
use crate::arrays::shared::SharedArray;
use crate::hash::ArrayEq;
use crate::hash::ArrayHash;
use crate::stats::StatsSetRef;
use crate::validity::Validity;
use crate::vtable;
use crate::vtable::ArrayId;
use crate::vtable::BaseArrayVTable;
use crate::vtable::NotSupported;
use crate::vtable::OperationsVTable;
use crate::vtable::VTable;
use crate::vtable::ValidityVTable;
use crate::vtable::VisitorVTable;

vtable!(Shared);

/// Vtable marker type for [`SharedArray`]; the vtable trait impls in this file hang off it.
#[derive(Debug)]
pub struct SharedVTable;

impl SharedVTable {
/// Identifier of the shared encoding (`"vortex.shared"`), returned by `VTable::id`.
pub const ID: ArrayId = ArrayId::new_ref("vortex.shared");
}

impl VTable for SharedVTable {
type Array = SharedArray;
type Metadata = EmptyMetadata;

type ArrayVTable = Self;
type OperationsVTable = Self;
type ValidityVTable = Self;
type VisitorVTable = Self;
type ComputeVTable = NotSupported;

/// Returns the fixed encoding id; it does not depend on the array instance.
fn id(_array: &Self::Array) -> ArrayId {
Self::ID
}

/// Shared arrays carry no metadata, so this always yields `EmptyMetadata`.
fn metadata(_array: &Self::Array) -> VortexResult<Self::Metadata> {
Ok(EmptyMetadata)
}

/// Always fails: a shared array cannot be serialized.
///
/// # Errors
///
/// Unconditionally returns the "Shared array is not serializable" error.
fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
vortex_error::vortex_bail!("Shared array is not serializable")
}

/// Always fails: there is no serialized form of a shared array to read back.
///
/// # Errors
///
/// Unconditionally returns the "Shared array is not serializable" error.
// NOTE(review): the text between the `vortex_bail!` line and the closing brace is an inline
// review comment captured from the pull-request page, not part of the function body.
fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
vortex_error::vortex_bail!("Shared array is not serializable")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

}

/// Builds a `SharedArray` around the first child (index 0), fetched with the given dtype and
/// length. Metadata and buffers are ignored.
///
/// # Errors
///
/// Propagates the error from `children.get` if the child cannot be produced.
fn build(
    dtype: &DType,
    len: usize,
    _metadata: &Self::Metadata,
    _buffers: &[crate::buffer::BufferHandle],
    children: &dyn crate::serde::ArrayChildren,
) -> VortexResult<SharedArray> {
    children.get(0, dtype, len).map(SharedArray::new)
}

/// Replaces the array's source with the single array in `children`.
///
/// # Errors
///
/// Fails if `children` does not contain exactly one array.
fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
    vortex_error::vortex_ensure!(
        children.len() == 1,
        "SharedArray expects exactly 1 child, got {}",
        children.len()
    );
    // The length check above guarantees the iterator yields one element.
    let mut remaining = children.into_iter();
    let only_child = remaining
        .next()
        .vortex_expect("children length already validated");
    array.set_source(only_child);
    Ok(())
}

/// Delegates to `SharedArray::canonicalize`, which returns the cached canonical form or
/// executes the source and caches the result.
fn canonicalize(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
array.canonicalize(ctx)
}

/// Slices the currently backing array and wraps the result in a new `SharedArray`.
///
/// The returned array has its own state cell, so it does not share a cache with `array`.
///
/// # Errors
///
/// Propagates any error from slicing the backing array.
fn slice(array: &Self::Array, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
    let current = array.current_array_ref();
    let shared_slice = SharedArray::new(current.slice(range)?);
    Ok(Some(shared_slice.into_array()))
}
}

impl BaseArrayVTable<SharedVTable> for SharedVTable {
    /// Length of the array currently backing the wrapper (source or cached canonical form).
    fn len(array: &SharedArray) -> usize {
        let backing = array.current_array_ref();
        backing.len()
    }

    /// The dtype recorded on the wrapper itself.
    fn dtype(array: &SharedArray) -> &DType {
        &array.dtype
    }

    /// A stats reference tied to this array.
    fn stats(array: &SharedArray) -> StatsSetRef<'_> {
        array.stats.to_ref(array.as_ref())
    }

    /// Hashes the currently backing array first, then the wrapper's dtype.
    fn array_hash<H: std::hash::Hasher>(array: &SharedArray, state: &mut H, precision: Precision) {
        array.current_array_ref().array_hash(state, precision);
        array.dtype.hash(state);
    }

    /// Two shared arrays are equal when their currently backing arrays are equal at the given
    /// precision and their dtypes match.
    fn array_eq(array: &SharedArray, other: &SharedArray, precision: Precision) -> bool {
        let lhs = array.current_array_ref();
        let rhs = other.current_array_ref();
        let same_content = lhs.array_eq(&rhs, precision);
        same_content && array.dtype == other.dtype
    }
}

impl OperationsVTable<SharedVTable> for SharedVTable {
    /// Reads the scalar at `index` from the array currently backing the wrapper.
    fn scalar_at(array: &SharedArray, index: usize) -> VortexResult<Scalar> {
        let backing = array.current_array_ref();
        backing.scalar_at(index)
    }
}

impl ValidityVTable<SharedVTable> for SharedVTable {
    /// Reports the validity of the array currently backing the wrapper.
    fn validity(array: &SharedArray) -> VortexResult<Validity> {
        let backing = array.current_array_ref();
        backing.validity()
    }
}

impl VisitorVTable<SharedVTable> for SharedVTable {
    /// The wrapper owns no buffers of its own, so nothing is visited.
    fn visit_buffers(_array: &SharedArray, _visitor: &mut dyn ArrayBufferVisitor) {}

    /// Visits the currently backing array as the single child named `"source"`.
    fn visit_children(array: &SharedArray, visitor: &mut dyn ArrayChildVisitor) {
        let backing = array.current_array_ref();
        visitor.visit_child("source", &backing);
    }
}
2 changes: 2 additions & 0 deletions vortex-cuda/src/kernel/arrays/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

mod constant;
mod dict;
mod shared;

pub use constant::ConstantNumericExecutor;
pub use dict::DictExecutor;
pub use shared::SharedExecutor;
41 changes: 41 additions & 0 deletions vortex-cuda/src/kernel/arrays/shared.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use async_trait::async_trait;
use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::arrays::SharedVTable;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::executor::CudaArrayExt;
use crate::executor::CudaExecute;
use crate::executor::CudaExecutionCtx;

/// CUDA executor for `SharedArray`.
///
/// Returns the array's cached canonical form when there is one; otherwise executes the source
/// array on CUDA and stores the result in the shared cache.
#[derive(Debug)]
pub struct SharedExecutor;

#[async_trait]
impl CudaExecute for SharedExecutor {
    /// Canonicalizes a shared array, reusing the cached result when one exists.
    ///
    /// If nothing is cached, the source array is executed on CUDA and the result is offered to
    /// the cache; whichever canonical form ends up cached is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from executing the source array.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not a shared array, i.e. the executor was registered for the wrong
    /// encoding.
    async fn execute(
        &self,
        array: ArrayRef,
        ctx: &mut CudaExecutionCtx,
    ) -> VortexResult<Canonical> {
        let shared = array
            .try_into::<SharedVTable>()
            .ok()
            .vortex_expect("Array is not a Shared array");

        if let Some(cached) = shared.cached() {
            return Ok(cached);
        }

        // `cached()` and `source_if_any()` lock separately, so another holder may have cached a
        // result in between. In that case the source is gone and the cache must be re-read
        // instead of treating the missing source as a broken invariant.
        let Some(source) = shared.source_if_any() else {
            return Ok(shared
                .cached()
                .vortex_expect("cache present when no source"));
        };

        let canonical = source.execute_cuda(ctx).await?;
        Ok(shared.cache_or_return(canonical))
    }
}
1 change: 1 addition & 0 deletions vortex-cuda/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod slice;

pub use arrays::ConstantNumericExecutor;
pub use arrays::DictExecutor;
pub use arrays::SharedExecutor;
pub use encodings::*;
pub use filter::FilterExecutor;
pub use slice::SliceExecutor;
Expand Down
3 changes: 3 additions & 0 deletions vortex-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use kernel::DictExecutor;
use kernel::FilterExecutor;
use kernel::FoRExecutor;
use kernel::RunEndExecutor;
use kernel::SharedExecutor;
use kernel::ZigZagExecutor;
use kernel::ZstdExecutor;
pub use kernel::ZstdKernelPrep;
Expand All @@ -40,6 +41,7 @@ use vortex_alp::ALPVTable;
use vortex_array::arrays::ConstantVTable;
use vortex_array::arrays::DictVTable;
use vortex_array::arrays::FilterVTable;
use vortex_array::arrays::SharedVTable;
use vortex_array::arrays::SliceVTable;
use vortex_decimal_byte_parts::DecimalBytePartsVTable;
use vortex_fastlanes::BitPackedVTable;
Expand Down Expand Up @@ -69,6 +71,7 @@ pub fn initialize_cuda(session: &CudaSession) {
session.register_kernel(ConstantVTable::ID, &ConstantNumericExecutor);
session.register_kernel(DecimalBytePartsVTable::ID, &DecimalBytePartsExecutor);
session.register_kernel(DictVTable::ID, &DictExecutor);
session.register_kernel(SharedVTable::ID, &SharedExecutor);
session.register_kernel(FoRVTable::ID, &FoRExecutor);
session.register_kernel(RunEndVTable::ID, &RunEndExecutor);
session.register_kernel(SequenceVTable::ID, &SequenceExecutor);
Expand Down
Loading
Loading