Skip to content

Commit ac67ae4

Browse files
authored
Ensure null inputs to array setop functions return null output (#19683)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Closes #19682

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

Explained in issue.

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

Change array_except, array_intersect and array_union UDFs to return null if either input is null.

## Are these changes tested?

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

Added & fixed tests.

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

Behaviour change to a function output.

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent ab81d3b commit ac67ae4

File tree

4 files changed

+115
-115
lines changed

4 files changed

+115
-115
lines changed

datafusion/functions-nested/src/except.rs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! [`ScalarUDFImpl`] definitions for array_except function.
18+
//! [`ScalarUDFImpl`] definition for array_except function.
1919
2020
use crate::utils::{check_datatypes, make_scalar_function};
21+
use arrow::array::new_null_array;
2122
use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait, cast::AsArray};
22-
use arrow::buffer::OffsetBuffer;
23+
use arrow::buffer::{NullBuffer, OffsetBuffer};
2324
use arrow::datatypes::{DataType, FieldRef};
2425
use arrow::row::{RowConverter, SortField};
2526
use datafusion_common::utils::{ListCoercion, take_function_args};
@@ -28,6 +29,7 @@ use datafusion_expr::{
2829
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2930
};
3031
use datafusion_macros::user_doc;
32+
use itertools::Itertools;
3133
use std::any::Any;
3234
use std::sync::Arc;
3335

@@ -104,8 +106,11 @@ impl ScalarUDFImpl for ArrayExcept {
104106
}
105107

106108
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
107-
match (&arg_types[0].clone(), &arg_types[1].clone()) {
108-
(DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()),
109+
match (&arg_types[0], &arg_types[1]) {
110+
(DataType::Null, DataType::Null) => {
111+
Ok(DataType::new_list(DataType::Null, true))
112+
}
113+
(DataType::Null, dt) | (dt, DataType::Null) => Ok(dt.clone()),
109114
(dt, _) => Ok(dt.clone()),
110115
}
111116
}
@@ -129,8 +134,16 @@ impl ScalarUDFImpl for ArrayExcept {
129134
fn array_except_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
130135
let [array1, array2] = take_function_args("array_except", args)?;
131136

137+
let len = array1.len();
132138
match (array1.data_type(), array2.data_type()) {
133-
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
139+
(DataType::Null, DataType::Null) => Ok(new_null_array(
140+
&DataType::new_list(DataType::Null, true),
141+
len,
142+
)),
143+
(DataType::Null, dt @ DataType::List(_))
144+
| (DataType::Null, dt @ DataType::LargeList(_))
145+
| (dt @ DataType::List(_), DataType::Null)
146+
| (dt @ DataType::LargeList(_), DataType::Null) => Ok(new_null_array(dt, len)),
134147
(DataType::List(field), DataType::List(_)) => {
135148
check_datatypes("array_except", &[array1, array2])?;
136149
let list1 = array1.as_list::<i32>();
@@ -169,15 +182,27 @@ fn general_except<OffsetSize: OffsetSizeTrait>(
169182
let mut rows = Vec::with_capacity(l_values.num_rows());
170183
let mut dedup = HashSet::new();
171184

172-
for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
173-
let l_slice = l_w[0].as_usize()..l_w[1].as_usize();
174-
let r_slice = r_w[0].as_usize()..r_w[1].as_usize();
175-
for i in r_slice {
176-
let right_row = r_values.row(i);
185+
let nulls = NullBuffer::union(l.nulls(), r.nulls());
186+
187+
let l_offsets_iter = l.offsets().iter().tuple_windows();
188+
let r_offsets_iter = r.offsets().iter().tuple_windows();
189+
for (list_index, ((l_start, l_end), (r_start, r_end))) in
190+
l_offsets_iter.zip(r_offsets_iter).enumerate()
191+
{
192+
if nulls
193+
.as_ref()
194+
.is_some_and(|nulls| nulls.is_null(list_index))
195+
{
196+
offsets.push(OffsetSize::usize_as(rows.len()));
197+
continue;
198+
}
199+
200+
for element_index in r_start.as_usize()..r_end.as_usize() {
201+
let right_row = r_values.row(element_index);
177202
dedup.insert(right_row);
178203
}
179-
for i in l_slice {
180-
let left_row = l_values.row(i);
204+
for element_index in l_start.as_usize()..l_end.as_usize() {
205+
let left_row = l_values.row(element_index);
181206
if dedup.insert(left_row) {
182207
rows.push(left_row);
183208
}
@@ -192,7 +217,7 @@ fn general_except<OffsetSize: OffsetSizeTrait>(
192217
field.to_owned(),
193218
OffsetBuffer::new(offsets.into()),
194219
values.to_owned(),
195-
l.nulls().cloned(),
220+
nulls,
196221
))
197222
} else {
198223
internal_err!("array_except failed to convert rows")

datafusion/functions-nested/src/set_ops.rs

Lines changed: 27 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
2020
use crate::utils::make_scalar_function;
2121
use arrow::array::{
22-
Array, ArrayRef, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait,
23-
new_null_array,
22+
Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_empty_array, new_null_array,
2423
};
2524
use arrow::buffer::{NullBuffer, OffsetBuffer};
2625
use arrow::compute;
@@ -69,7 +68,7 @@ make_udf_expr_and_func!(
6968

7069
#[user_doc(
7170
doc_section(label = "Array Functions"),
72-
description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.",
71+
description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.",
7372
syntax_example = "array_union(array1, array2)",
7473
sql_example = r#"```sql
7574
> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
@@ -136,8 +135,7 @@ impl ScalarUDFImpl for ArrayUnion {
136135
let [array1, array2] = take_function_args(self.name(), arg_types)?;
137136
match (array1, array2) {
138137
(Null, Null) => Ok(DataType::new_list(Null, true)),
139-
(Null, dt) => Ok(dt.clone()),
140-
(dt, Null) => Ok(dt.clone()),
138+
(Null, dt) | (dt, Null) => Ok(dt.clone()),
141139
(dt, _) => Ok(dt.clone()),
142140
}
143141
}
@@ -221,8 +219,7 @@ impl ScalarUDFImpl for ArrayIntersect {
221219
let [array1, array2] = take_function_args(self.name(), arg_types)?;
222220
match (array1, array2) {
223221
(Null, Null) => Ok(DataType::new_list(Null, true)),
224-
(Null, dt) => Ok(dt.clone()),
225-
(dt, Null) => Ok(dt.clone()),
222+
(Null, dt) | (dt, Null) => Ok(dt.clone()),
226223
(dt, _) => Ok(dt.clone()),
227224
}
228225
}
@@ -363,23 +360,19 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
363360

364361
let mut offsets = vec![OffsetSize::usize_as(0)];
365362
let mut new_arrays = vec![];
366-
let mut new_null_buf = vec![];
367363
let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
368-
for (first_arr, second_arr) in l.iter().zip(r.iter()) {
369-
let mut ele_should_be_null = false;
364+
for (l_arr, r_arr) in l.iter().zip(r.iter()) {
365+
let last_offset = *offsets.last().unwrap();
370366

371-
let l_values = if let Some(first_arr) = first_arr {
372-
converter.convert_columns(&[first_arr])?
373-
} else {
374-
ele_should_be_null = true;
375-
converter.empty_rows(0, 0)
376-
};
377-
378-
let r_values = if let Some(second_arr) = second_arr {
379-
converter.convert_columns(&[second_arr])?
380-
} else {
381-
ele_should_be_null = true;
382-
converter.empty_rows(0, 0)
367+
let (l_values, r_values) = match (l_arr, r_arr) {
368+
(Some(l_arr), Some(r_arr)) => (
369+
converter.convert_columns(&[l_arr])?,
370+
converter.convert_columns(&[r_arr])?,
371+
),
372+
_ => {
373+
offsets.push(last_offset);
374+
continue;
375+
}
383376
};
384377

385378
let l_iter = l_values.iter().sorted().dedup();
@@ -405,11 +398,6 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
405398
}
406399
}
407400

408-
let last_offset = match offsets.last() {
409-
Some(offset) => *offset,
410-
None => return internal_err!("offsets should not be empty"),
411-
};
412-
413401
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
414402
let arrays = converter.convert_rows(rows)?;
415403
let array = match arrays.first() {
@@ -419,18 +407,21 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
419407
}
420408
};
421409

422-
new_null_buf.push(!ele_should_be_null);
423410
new_arrays.push(array);
424411
}
425412

426413
let offsets = OffsetBuffer::new(offsets.into());
427414
let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect();
428-
let values = compute::concat(&new_arrays_ref)?;
415+
let values = if new_arrays_ref.is_empty() {
416+
new_empty_array(&l.value_type())
417+
} else {
418+
compute::concat(&new_arrays_ref)?
419+
};
429420
let arr = GenericListArray::<OffsetSize>::try_new(
430421
field,
431422
offsets,
432423
values,
433-
Some(NullBuffer::new(new_null_buf.into())),
424+
NullBuffer::union(l.nulls(), r.nulls()),
434425
)?;
435426
Ok(Arc::new(arr))
436427
}
@@ -440,59 +431,13 @@ fn general_set_op(
440431
array2: &ArrayRef,
441432
set_op: SetOp,
442433
) -> Result<ArrayRef> {
443-
fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result<ArrayRef> {
444-
let field = Arc::new(Field::new_list_field(data_type.clone(), true));
445-
let values = new_null_array(data_type, len);
446-
if large {
447-
Ok(Arc::new(LargeListArray::try_new(
448-
field,
449-
OffsetBuffer::new_zeroed(len),
450-
values,
451-
None,
452-
)?))
453-
} else {
454-
Ok(Arc::new(ListArray::try_new(
455-
field,
456-
OffsetBuffer::new_zeroed(len),
457-
values,
458-
None,
459-
)?))
460-
}
461-
}
462-
434+
let len = array1.len();
463435
match (array1.data_type(), array2.data_type()) {
464-
(Null, Null) => Ok(Arc::new(ListArray::new_null(
465-
Arc::new(Field::new_list_field(Null, true)),
466-
array1.len(),
467-
))),
468-
(Null, List(field)) => {
469-
if set_op == SetOp::Intersect {
470-
return empty_array(field.data_type(), array1.len(), false);
471-
}
472-
let array = as_list_array(&array2)?;
473-
general_array_distinct::<i32>(array, field)
474-
}
475-
(List(field), Null) => {
476-
if set_op == SetOp::Intersect {
477-
return empty_array(field.data_type(), array1.len(), false);
478-
}
479-
let array = as_list_array(&array1)?;
480-
general_array_distinct::<i32>(array, field)
481-
}
482-
(Null, LargeList(field)) => {
483-
if set_op == SetOp::Intersect {
484-
return empty_array(field.data_type(), array1.len(), true);
485-
}
486-
let array = as_large_list_array(&array2)?;
487-
general_array_distinct::<i64>(array, field)
488-
}
489-
(LargeList(field), Null) => {
490-
if set_op == SetOp::Intersect {
491-
return empty_array(field.data_type(), array1.len(), true);
492-
}
493-
let array = as_large_list_array(&array1)?;
494-
general_array_distinct::<i64>(array, field)
495-
}
436+
(Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)),
437+
(Null, dt @ List(_))
438+
| (Null, dt @ LargeList(_))
439+
| (dt @ List(_), Null)
440+
| (dt @ LargeList(_), Null) => Ok(new_null_array(dt, len)),
496441
(List(field), List(_)) => {
497442
let array1 = as_list_array(&array1)?;
498443
let array2 = as_list_array(&array2)?;

0 commit comments

Comments
 (0)