Skip to content

Commit 469e31f

Browse files
authored
DF Test refactoring (#6223)
This PR attempts to make DF tests easier to read and write by moving much of the boilerplate into `TestSessionContext` and using more SQL (where its applicable). --------- Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 848fd89 commit 469e31f

File tree

6 files changed

+374
-492
lines changed

6 files changed

+374
-492
lines changed

vortex-datafusion/src/lib.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ use vortex::expr::stats::Precision;
1111
mod convert;
1212
mod persistent;
1313

14+
#[cfg(test)]
15+
mod tests;
16+
1417
pub use convert::exprs::ExpressionConvertor;
1518
pub use persistent::*;
1619

@@ -46,3 +49,101 @@ where
4649
}
4750
}
4851
}
52+
53+
#[cfg(test)]
54+
mod common_tests {
55+
use std::sync::Arc;
56+
use std::sync::LazyLock;
57+
58+
use datafusion::arrow::array::RecordBatch;
59+
use datafusion::datasource::provider::DefaultTableFactory;
60+
use datafusion::execution::SessionStateBuilder;
61+
use datafusion::prelude::SessionContext;
62+
use datafusion_catalog::TableProvider;
63+
use datafusion_common::DFSchema;
64+
use datafusion_common::GetExt;
65+
use datafusion_expr::CreateExternalTable;
66+
use object_store::ObjectStore;
67+
use object_store::memory::InMemory;
68+
use url::Url;
69+
use vortex::VortexSessionDefault;
70+
use vortex::array::ArrayRef;
71+
use vortex::array::arrow::FromArrowArray;
72+
use vortex::file::WriteOptionsSessionExt;
73+
use vortex::io::ObjectStoreWriter;
74+
use vortex::io::VortexWrite;
75+
use vortex::session::VortexSession;
76+
77+
use crate::VortexFormatFactory;
78+
79+
static VX_SESSION: LazyLock<VortexSession> = LazyLock::new(VortexSession::default);
80+
81+
pub struct TestSessionContext {
82+
pub store: Arc<dyn ObjectStore>,
83+
pub session: SessionContext,
84+
}
85+
86+
impl Default for TestSessionContext {
87+
fn default() -> Self {
88+
let store = Arc::new(InMemory::new());
89+
let factory = Arc::new(VortexFormatFactory::new());
90+
let session_state_builder = SessionStateBuilder::new()
91+
.with_default_features()
92+
.with_table_factory(
93+
factory.get_ext().to_uppercase(),
94+
Arc::new(DefaultTableFactory::new()),
95+
)
96+
.with_file_formats(vec![factory])
97+
.with_object_store(&Url::try_from("file://").unwrap(), store.clone());
98+
99+
let session: SessionContext =
100+
SessionContext::new_with_state(session_state_builder.build()).enable_url_table();
101+
102+
Self { store, session }
103+
}
104+
}
105+
106+
impl TestSessionContext {
107+
// Write arrow data into a vortex file.
108+
pub async fn write_arrow_batch<P>(&self, path: P, batch: &RecordBatch) -> anyhow::Result<()>
109+
where
110+
P: Into<object_store::path::Path>,
111+
{
112+
let array = ArrayRef::from_arrow(batch, false)?;
113+
let mut write = ObjectStoreWriter::new(self.store.clone(), &path.into()).await?;
114+
VX_SESSION
115+
.write_options()
116+
.write(&mut write, array.to_array_stream())
117+
.await?;
118+
write.shutdown().await?;
119+
120+
Ok(())
121+
}
122+
123+
/// Creates a ListingTable provider targeted at the provided path
124+
pub async fn table_provider<S>(
125+
&self,
126+
name: &str,
127+
location: impl Into<String>,
128+
schema: S,
129+
) -> anyhow::Result<Arc<dyn TableProvider>>
130+
where
131+
DFSchema: TryFrom<S>,
132+
anyhow::Error: From<<S as TryInto<DFSchema>>::Error>,
133+
{
134+
let factory = self.session.table_factory("VORTEX").unwrap();
135+
136+
let cmd = CreateExternalTable::builder(
137+
name,
138+
location.into(),
139+
"vortex",
140+
DFSchema::try_from(schema)?.into(),
141+
)
142+
.build();
143+
144+
let table = factory.create(&self.session.state(), &cmd).await?;
145+
146+
Ok(table)
147+
}
148+
}
149+
}

vortex-datafusion/src/persistent/format.rs

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -503,59 +503,45 @@ impl FileFormat for VortexFormat {
503503

504504
#[cfg(test)]
505505
mod tests {
506-
use datafusion::execution::SessionStateBuilder;
507-
use datafusion::prelude::SessionContext;
508-
use tempfile::TempDir;
509506

510507
use super::*;
511-
use crate::persistent::register_vortex_format_factory;
508+
use crate::common_tests::TestSessionContext;
512509

513510
#[tokio::test]
514-
async fn create_table() {
515-
let dir = TempDir::new().unwrap();
511+
async fn create_table() -> anyhow::Result<()> {
512+
let ctx = TestSessionContext::default();
516513

517-
let factory: VortexFormatFactory = VortexFormatFactory::new();
518-
let mut session_state_builder = SessionStateBuilder::new().with_default_features();
519-
register_vortex_format_factory(factory, &mut session_state_builder);
520-
let session = SessionContext::new_with_state(session_state_builder.build());
521-
522-
let df = session
523-
.sql(&format!(
514+
ctx.session
515+
.sql(
524516
"CREATE EXTERNAL TABLE my_tbl \
525517
(c1 VARCHAR NOT NULL, c2 INT NOT NULL) \
526518
STORED AS vortex \
527-
LOCATION '{}'",
528-
dir.path().to_str().unwrap()
529-
))
530-
.await
531-
.unwrap();
519+
LOCATION 'table/'",
520+
)
521+
.await?;
532522

533-
assert_eq!(df.count().await.unwrap(), 0);
523+
assert!(ctx.session.table_exist("my_tbl")?);
524+
525+
Ok(())
534526
}
535527

536528
#[tokio::test]
537-
async fn configure_format_source() {
538-
let dir = TempDir::new().unwrap();
539-
540-
let factory = VortexFormatFactory::new();
541-
let mut session_state_builder = SessionStateBuilder::new().with_default_features();
542-
register_vortex_format_factory(factory, &mut session_state_builder);
543-
let session = SessionContext::new_with_state(session_state_builder.build());
529+
async fn configure_format_source() -> anyhow::Result<()> {
530+
let ctx = TestSessionContext::default();
544531

545-
session
546-
.sql(&format!(
532+
ctx.session
533+
.sql(
547534
"CREATE EXTERNAL TABLE my_tbl \
548535
(c1 VARCHAR NOT NULL, c2 INT NOT NULL) \
549536
STORED AS vortex \
550-
LOCATION '{}' \
537+
LOCATION 'table/' \
551538
OPTIONS( footer_initial_read_size_bytes '12345' );",
552-
dir.path().to_str().unwrap()
553-
))
554-
.await
555-
.unwrap()
539+
)
540+
.await?
556541
.collect()
557-
.await
558-
.unwrap();
542+
.await?;
543+
544+
Ok(())
559545
}
560546

561547
#[test]

0 commit comments

Comments
 (0)