1717
1818mod literal_lookup_table;
1919
20- use super :: { Column , Literal } ;
20+ use super :: { CastExpr , Column , Literal } ;
2121use crate :: PhysicalExpr ;
22- use crate :: expressions:: { lit, try_cast} ;
22+ use crate :: expressions:: { BinaryExpr , lit, try_cast} ;
2323use arrow:: array:: * ;
2424use arrow:: compute:: kernels:: zip:: zip;
2525use arrow:: compute:: {
@@ -33,6 +33,7 @@ use datafusion_common::{
3333 internal_datafusion_err, internal_err,
3434} ;
3535use datafusion_expr:: ColumnarValue ;
36+ use datafusion_expr_common:: operator:: Operator ;
3637use indexmap:: { IndexMap , IndexSet } ;
3738use std:: borrow:: Cow ;
3839use std:: hash:: Hash ;
@@ -81,6 +82,14 @@ enum EvalMethod {
8182 ///
8283 /// See [`LiteralLookupTable`] for more details
8384 WithExprScalarLookupTable ( LiteralLookupTable ) ,
85+
86+ /// This is a specialization for divide-by-zero protection pattern:
87+ /// CASE WHEN y > 0 THEN x / y ELSE NULL END
88+ /// CASE WHEN y != 0 THEN x / y ELSE NULL END
89+ ///
/// Instead of evaluating the full CASE expression, the division is performed directly and
/// returns NULL when the divisor is zero.
92+ DivideByZeroProtection ,
8493}
8594
8695/// Implementing hash so we can use `derive` on [`EvalMethod`].
@@ -647,6 +656,20 @@ impl CaseExpr {
647656 return Ok ( EvalMethod :: WithExpression ( body. project ( ) ?) ) ;
648657 }
649658
659+ // Check for divide-by-zero protection pattern:
660+ // CASE WHEN y > 0 THEN x / y ELSE NULL END
661+ if body. when_then_expr . len ( ) == 1 && body. else_expr . is_none ( ) {
662+ let ( when_expr, then_expr) = & body. when_then_expr [ 0 ] ;
663+
664+ if let Some ( checked_operand) = Self :: extract_non_zero_operand ( when_expr)
665+ && let Some ( ( _numerator, divisor) ) =
666+ Self :: extract_division_operands ( then_expr)
667+ && divisor. eq ( & checked_operand)
668+ {
669+ return Ok ( EvalMethod :: DivideByZeroProtection ) ;
670+ }
671+ }
672+
650673 Ok (
651674 if body. when_then_expr . len ( ) == 1
652675 && is_cheap_and_infallible ( & ( body. when_then_expr [ 0 ] . 1 ) )
@@ -681,6 +704,67 @@ impl CaseExpr {
681704 pub fn else_expr ( & self ) -> Option < & Arc < dyn PhysicalExpr > > {
682705 self . body . else_expr . as_ref ( )
683706 }
707+
708+ /// Extract the operand being checked for non-zero from a comparison expression.
709+ /// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`.
710+ fn extract_non_zero_operand (
711+ expr : & Arc < dyn PhysicalExpr > ,
712+ ) -> Option < Arc < dyn PhysicalExpr > > {
713+ let binary = expr. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) ?;
714+
715+ match binary. op ( ) {
716+ // y > 0 or y != 0
717+ Operator :: Gt | Operator :: NotEq if Self :: is_literal_zero ( binary. right ( ) ) => {
718+ Some ( Arc :: clone ( binary. left ( ) ) )
719+ }
720+ // 0 < y or 0 != y
721+ Operator :: Lt | Operator :: NotEq if Self :: is_literal_zero ( binary. left ( ) ) => {
722+ Some ( Arc :: clone ( binary. right ( ) ) )
723+ }
724+ _ => None ,
725+ }
726+ }
727+
728+ /// Extract (numerator, divisor) from a division expression.
729+ fn extract_division_operands (
730+ expr : & Arc < dyn PhysicalExpr > ,
731+ ) -> Option < ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) > {
732+ let binary = expr. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) ?;
733+
734+ if binary. op ( ) == & Operator :: Divide {
735+ let divisor =
736+ if let Some ( cast) = binary. right ( ) . as_any ( ) . downcast_ref :: < CastExpr > ( ) {
737+ Arc :: clone ( cast. expr ( ) )
738+ } else {
739+ Arc :: clone ( binary. right ( ) )
740+ } ;
741+ Some ( ( Arc :: clone ( binary. left ( ) ) , divisor) )
742+ } else {
743+ None
744+ }
745+ }
746+
747+ /// Check if an expression is a literal zero value
748+ fn is_literal_zero ( expr : & Arc < dyn PhysicalExpr > ) -> bool {
749+ if let Some ( lit) = expr. as_any ( ) . downcast_ref :: < Literal > ( ) {
750+ match lit. value ( ) {
751+ ScalarValue :: Int8 ( Some ( 0 ) )
752+ | ScalarValue :: Int16 ( Some ( 0 ) )
753+ | ScalarValue :: Int32 ( Some ( 0 ) )
754+ | ScalarValue :: Int64 ( Some ( 0 ) )
755+ | ScalarValue :: UInt8 ( Some ( 0 ) )
756+ | ScalarValue :: UInt16 ( Some ( 0 ) )
757+ | ScalarValue :: UInt32 ( Some ( 0 ) )
758+ | ScalarValue :: UInt64 ( Some ( 0 ) ) => true ,
759+ ScalarValue :: Float16 ( Some ( v) ) if v. to_f32 ( ) == 0.0 => true ,
760+ ScalarValue :: Float32 ( Some ( v) ) if * v == 0.0 => true ,
761+ ScalarValue :: Float64 ( Some ( v) ) if * v == 0.0 => true ,
762+ _ => false ,
763+ }
764+ } else {
765+ false
766+ }
767+ }
684768}
685769
686770impl CaseBody {
@@ -1170,6 +1254,19 @@ impl CaseExpr {
11701254
11711255 Ok ( result)
11721256 }
1257+
1258+ fn divide_by_zero_protection ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
1259+ let then_expr = & self . body . when_then_expr [ 0 ] . 1 ;
1260+ let binary = then_expr
1261+ . as_any ( )
1262+ . downcast_ref :: < BinaryExpr > ( )
1263+ . expect ( "then expression should be a binary expression" ) ;
1264+
1265+ let numerator = binary. left ( ) . evaluate ( batch) ?;
1266+ let divisor = binary. right ( ) . evaluate ( batch) ?;
1267+
1268+ safe_divide ( & numerator, & divisor)
1269+ }
11731270}
11741271
11751272impl PhysicalExpr for CaseExpr {
@@ -1268,6 +1365,7 @@ impl PhysicalExpr for CaseExpr {
12681365 EvalMethod :: WithExprScalarLookupTable ( lookup_table) => {
12691366 self . with_lookup_table ( batch, lookup_table)
12701367 }
1368+ EvalMethod :: DivideByZeroProtection => self . divide_by_zero_protection ( batch) ,
12711369 }
12721370 }
12731371
@@ -1389,6 +1487,78 @@ fn replace_with_null(
13891487 Ok ( with_null)
13901488}
13911489
1490+ fn safe_divide (
1491+ numerator : & ColumnarValue ,
1492+ divisor : & ColumnarValue ,
1493+ ) -> Result < ColumnarValue > {
1494+ if let ColumnarValue :: Scalar ( div_scalar) = divisor
1495+ && is_scalar_zero ( div_scalar)
1496+ {
1497+ let data_type = numerator. data_type ( ) ;
1498+ return match numerator {
1499+ ColumnarValue :: Array ( arr) => {
1500+ Ok ( ColumnarValue :: Array ( new_null_array ( & data_type, arr. len ( ) ) ) )
1501+ }
1502+ ColumnarValue :: Scalar ( _) => Ok ( ColumnarValue :: Scalar (
1503+ ScalarValue :: try_new_null ( & data_type) ?,
1504+ ) ) ,
1505+ } ;
1506+ }
1507+
1508+ let num_rows = match ( numerator, divisor) {
1509+ ( ColumnarValue :: Array ( arr) , _) => arr. len ( ) ,
1510+ ( _, ColumnarValue :: Array ( arr) ) => arr. len ( ) ,
1511+ _ => 1 ,
1512+ } ;
1513+
1514+ let num_array = numerator. clone ( ) . into_array ( num_rows) ?;
1515+ let div_array = divisor. clone ( ) . into_array ( num_rows) ?;
1516+
1517+ let result = safe_divide_arrays ( & num_array, & div_array) ?;
1518+
1519+ if matches ! ( numerator, ColumnarValue :: Scalar ( _) )
1520+ && matches ! ( divisor, ColumnarValue :: Scalar ( _) )
1521+ {
1522+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: try_from_array (
1523+ & result, 0 ,
1524+ ) ?) )
1525+ } else {
1526+ Ok ( ColumnarValue :: Array ( result) )
1527+ }
1528+ }
1529+
1530+ fn safe_divide_arrays ( numerator : & ArrayRef , divisor : & ArrayRef ) -> Result < ArrayRef > {
1531+ use arrow:: compute:: kernels:: cmp:: eq;
1532+ use arrow:: compute:: kernels:: numeric:: div;
1533+
1534+ let zero = ScalarValue :: new_zero ( divisor. data_type ( ) ) ?. to_scalar ( ) ?;
1535+ let zero_mask = eq ( divisor, & zero) ?;
1536+
1537+ let ones = ScalarValue :: new_one ( divisor. data_type ( ) ) ?. to_scalar ( ) ?;
1538+ let safe_divisor = zip ( & zero_mask, & ones, divisor) ?;
1539+
1540+ let result = div ( & numerator, & safe_divisor) ?;
1541+
1542+ Ok ( nullif ( & result, & zero_mask) ?)
1543+ }
1544+
1545+ fn is_scalar_zero ( scalar : & ScalarValue ) -> bool {
1546+ match scalar {
1547+ ScalarValue :: Int8 ( Some ( 0 ) )
1548+ | ScalarValue :: Int16 ( Some ( 0 ) )
1549+ | ScalarValue :: Int32 ( Some ( 0 ) )
1550+ | ScalarValue :: Int64 ( Some ( 0 ) )
1551+ | ScalarValue :: UInt8 ( Some ( 0 ) )
1552+ | ScalarValue :: UInt16 ( Some ( 0 ) )
1553+ | ScalarValue :: UInt32 ( Some ( 0 ) )
1554+ | ScalarValue :: UInt64 ( Some ( 0 ) ) => true ,
1555+ ScalarValue :: Float16 ( Some ( v) ) if v. to_f32 ( ) == 0.0 => true ,
1556+ ScalarValue :: Float32 ( Some ( v) ) if * v == 0.0 => true ,
1557+ ScalarValue :: Float64 ( Some ( v) ) if * v == 0.0 => true ,
1558+ _ => false ,
1559+ }
1560+ }
1561+
13921562/// Create a CASE expression
13931563pub fn case (
13941564 expr : Option < Arc < dyn PhysicalExpr > > ,
@@ -2298,6 +2468,65 @@ mod tests {
22982468 Ok ( ( ) )
22992469 }
23002470
/// A guard on `a` combined with a division by `CAST(a AS Float64)` must select
/// the `DivideByZeroProtection` evaluation method and produce the expected
/// quotients.
#[test]
fn test_divide_by_zero_protection_specialization() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // WHEN a > 0
    let guard = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    // THEN 25.0 / cast(a, float64)
    let divisor = cast(col("a", &schema)?, &schema, Float64)?;
    let quotient = binary(lit(25.0f64), Operator::Divide, divisor, &schema)?;

    // No ELSE branch: the specialization only applies without one.
    let case_expr = CaseExpr::try_new(None, vec![(guard, quotient)], None)?;

    assert!(
        matches!(case_expr.eval_method, EvalMethod::DivideByZeroProtection),
        "Expected DivideByZeroProtection, got {:?}",
        case_expr.eval_method
    );

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&evaluated)?;

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
2504+
/// When the guarded column differs from the divisor, the specialization must
/// not be selected.
#[test]
fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a > 0 THEN b / c ELSE NULL END
    // The guard checks `a` but the divisor is `c`, so the pattern does not match.
    let guard = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let numerator = col("b", &schema)?;
    let divisor = col("c", &schema)?;
    let quotient = binary(numerator, Operator::Divide, divisor, &schema)?;

    let case_expr = CaseExpr::try_new(None, vec![(guard, quotient)], None)?;

    let specialized =
        matches!(case_expr.eval_method, EvalMethod::DivideByZeroProtection);
    assert!(
        !specialized,
        "Should NOT use DivideByZeroProtection when divisor doesn't match"
    );

    Ok(())
}
2529+
23012530 fn make_col ( name : & str , index : usize ) -> Arc < dyn PhysicalExpr > {
23022531 Arc :: new ( Column :: new ( name, index) )
23032532 }
0 commit comments