Skip to content
Draft
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 231 additions & 2 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

mod literal_lookup_table;

use super::{Column, Literal};
use super::{CastExpr, Column, Literal};
use crate::PhysicalExpr;
use crate::expressions::{lit, try_cast};
use crate::expressions::{BinaryExpr, lit, try_cast};
use arrow::array::*;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{
Expand All @@ -33,6 +33,7 @@ use datafusion_common::{
internal_datafusion_err, internal_err,
};
use datafusion_expr::ColumnarValue;
use datafusion_expr_common::operator::Operator;
use indexmap::{IndexMap, IndexSet};
use std::borrow::Cow;
use std::hash::Hash;
Expand Down Expand Up @@ -81,6 +82,14 @@ enum EvalMethod {
///
/// See [`LiteralLookupTable`] for more details
WithExprScalarLookupTable(LiteralLookupTable),

/// This is a specialization for divide-by-zero protection pattern:
/// CASE WHEN y > 0 THEN x / y ELSE NULL END
/// CASE WHEN y != 0 THEN x / y ELSE NULL END
///
/// Instead of evaluating the full CASE expression, it is preferred to directly perform division
/// that return NULL when the divisor is zero.
DivideByZeroProtection,
}

/// Implementing hash so we can use `derive` on [`EvalMethod`].
Expand Down Expand Up @@ -647,6 +656,20 @@ impl CaseExpr {
return Ok(EvalMethod::WithExpression(body.project()?));
}

// Check for divide-by-zero protection pattern:
// CASE WHEN y > 0 THEN x / y ELSE NULL END
if body.when_then_expr.len() == 1 && body.else_expr.is_none() {
let (when_expr, then_expr) = &body.when_then_expr[0];

if let Some(checked_operand) = Self::extract_non_zero_operand(when_expr)
&& let Some((_numerator, divisor)) =
Self::extract_division_operands(then_expr)
&& divisor.eq(&checked_operand)
{
return Ok(EvalMethod::DivideByZeroProtection);
}
}

Ok(
if body.when_then_expr.len() == 1
&& is_cheap_and_infallible(&(body.when_then_expr[0].1))
Expand Down Expand Up @@ -681,6 +704,67 @@ impl CaseExpr {
pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
self.body.else_expr.as_ref()
}

/// Extract the operand being checked for non-zero from a comparison expression.
/// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`.
fn extract_non_zero_operand(
expr: &Arc<dyn PhysicalExpr>,
) -> Option<Arc<dyn PhysicalExpr>> {
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;

match binary.op() {
// y > 0 or y != 0
Operator::Gt | Operator::NotEq if Self::is_literal_zero(binary.right()) => {
Some(Arc::clone(binary.left()))
}
// 0 < y or 0 != y
Operator::Lt | Operator::NotEq if Self::is_literal_zero(binary.left()) => {
Some(Arc::clone(binary.right()))
}
_ => None,
}
}

/// Extract (numerator, divisor) from a division expression.
fn extract_division_operands(
expr: &Arc<dyn PhysicalExpr>,
) -> Option<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;

if binary.op() == &Operator::Divide {
let divisor =
if let Some(cast) = binary.right().as_any().downcast_ref::<CastExpr>() {
Arc::clone(cast.expr())
} else {
Arc::clone(binary.right())
};
Some((Arc::clone(binary.left()), divisor))
} else {
None
}
}

/// Check if an expression is a literal zero value
fn is_literal_zero(expr: &Arc<dyn PhysicalExpr>) -> bool {
if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
match lit.value() {
ScalarValue::Int8(Some(0))
| ScalarValue::Int16(Some(0))
| ScalarValue::Int32(Some(0))
| ScalarValue::Int64(Some(0))
| ScalarValue::UInt8(Some(0))
| ScalarValue::UInt16(Some(0))
| ScalarValue::UInt32(Some(0))
| ScalarValue::UInt64(Some(0)) => true,
ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true,
ScalarValue::Float32(Some(v)) if *v == 0.0 => true,
ScalarValue::Float64(Some(v)) if *v == 0.0 => true,
_ => false,
}
} else {
false
}
}
}

impl CaseBody {
Expand Down Expand Up @@ -1170,6 +1254,19 @@ impl CaseExpr {

Ok(result)
}

fn divide_by_zero_protection(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let then_expr = &self.body.when_then_expr[0].1;
let binary = then_expr
.as_any()
.downcast_ref::<BinaryExpr>()
.expect("then expression should be a binary expression");

let numerator = binary.left().evaluate(batch)?;
let divisor = binary.right().evaluate(batch)?;

safe_divide(&numerator, &divisor)
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -1268,6 +1365,7 @@ impl PhysicalExpr for CaseExpr {
EvalMethod::WithExprScalarLookupTable(lookup_table) => {
self.with_lookup_table(batch, lookup_table)
}
EvalMethod::DivideByZeroProtection => self.divide_by_zero_protection(batch),
}
}

Expand Down Expand Up @@ -1389,6 +1487,78 @@ fn replace_with_null(
Ok(with_null)
}

fn safe_divide(
numerator: &ColumnarValue,
divisor: &ColumnarValue,
) -> Result<ColumnarValue> {
if let ColumnarValue::Scalar(div_scalar) = divisor
&& is_scalar_zero(div_scalar)
{
let data_type = numerator.data_type();
return match numerator {
ColumnarValue::Array(arr) => {
Ok(ColumnarValue::Array(new_null_array(&data_type, arr.len())))
}
ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(
ScalarValue::try_new_null(&data_type)?,
)),
};
}

let num_rows = match (numerator, divisor) {
(ColumnarValue::Array(arr), _) => arr.len(),
(_, ColumnarValue::Array(arr)) => arr.len(),
_ => 1,
};

let num_array = numerator.clone().into_array(num_rows)?;
let div_array = divisor.clone().into_array(num_rows)?;

let result = safe_divide_arrays(&num_array, &div_array)?;

if matches!(numerator, ColumnarValue::Scalar(_))
&& matches!(divisor, ColumnarValue::Scalar(_))
{
Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
&result, 0,
)?))
} else {
Ok(ColumnarValue::Array(result))
}
}

fn safe_divide_arrays(numerator: &ArrayRef, divisor: &ArrayRef) -> Result<ArrayRef> {
use arrow::compute::kernels::cmp::eq;
use arrow::compute::kernels::numeric::div;

let zero = ScalarValue::new_zero(divisor.data_type())?.to_scalar()?;
let zero_mask = eq(divisor, &zero)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to make sure you use the condition from the when expression here in order to get the correct result. I added these SLTs and they were not all passing.

query I
SELECT CASE WHEN d != 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
1
NULL
-1

query I
SELECT CASE WHEN d > 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
1
NULL
NULL

query I
SELECT CASE WHEN d < 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
NULL
NULL
-1


let ones = ScalarValue::new_one(divisor.data_type())?.to_scalar()?;
let safe_divisor = zip(&zero_mask, &ones, divisor)?;

let result = div(&numerator, &safe_divisor)?;

Ok(nullif(&result, &zero_mask)?)
}

fn is_scalar_zero(scalar: &ScalarValue) -> bool {
match scalar {
ScalarValue::Int8(Some(0))
| ScalarValue::Int16(Some(0))
| ScalarValue::Int32(Some(0))
| ScalarValue::Int64(Some(0))
| ScalarValue::UInt8(Some(0))
| ScalarValue::UInt16(Some(0))
| ScalarValue::UInt32(Some(0))
| ScalarValue::UInt64(Some(0)) => true,
ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true,
ScalarValue::Float32(Some(v)) if *v == 0.0 => true,
ScalarValue::Float64(Some(v)) if *v == 0.0 => true,
_ => false,
}
}

/// Create a CASE expression
pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
Expand Down Expand Up @@ -2298,6 +2468,65 @@ mod tests {
Ok(())
}

#[test]
fn test_divide_by_zero_protection_specialization() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();

// CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
let then = binary(
lit(25.0f64),
Operator::Divide,
cast(col("a", &schema)?, &schema, Float64)?,
&schema,
)?;

let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;

assert!(
matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
"Expected DivideByZeroProtection, got {:?}",
expr.eval_method
);

let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_float64_array(&result)?;

let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
assert_eq!(expected, result);

Ok(())
}

#[test]
fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();

// CASE WHEN a > 0 THEN b / c ELSE NULL END
// Divisor (c) != checked operand (a), should NOT use specialization
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
let then = binary(
col("b", &schema)?,
Operator::Divide,
col("c", &schema)?,
&schema,
)?;

let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;

assert!(
!matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
"Should NOT use DivideByZeroProtection when divisor doesn't match"
);

Ok(())
}

fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
}
Expand Down