@@ -20,6 +20,7 @@ use arrow::datatypes::{ArrowNativeTypeOp, Field, Schema};
2020use arrow:: record_batch:: RecordBatch ;
2121use arrow:: util:: test_util:: seedable_rng;
2222use criterion:: { BenchmarkId , Criterion , criterion_group, criterion_main} ;
23+ use datafusion_common:: ScalarValue ;
2324use datafusion_expr:: Operator ;
2425use datafusion_physical_expr:: expressions:: { BinaryExpr , case, col, lit} ;
2526use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
@@ -93,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) {
9394 run_benchmarks ( c, & make_batch ( 8192 , 100 ) ) ;
9495
9596 benchmark_lookup_table_case_when ( c, 8192 ) ;
97+ benchmark_divide_by_zero_protection ( c, 8192 ) ;
9698}
9799
98100fn run_benchmarks ( c : & mut Criterion , batch : & RecordBatch ) {
@@ -517,5 +519,106 @@ fn benchmark_lookup_table_case_when(c: &mut Criterion, batch_size: usize) {
517519 }
518520}
519521
522+ fn benchmark_divide_by_zero_protection ( c : & mut Criterion , batch_size : usize ) {
523+ let mut group = c. benchmark_group ( "divide_by_zero_protection" ) ;
524+
525+ for zero_percentage in [ 0.0 , 0.1 , 0.5 , 0.9 ] {
526+ let rng = & mut seedable_rng ( ) ;
527+
528+ let numerator: Int32Array =
529+ ( 0 ..batch_size) . map ( |_| Some ( rng. random :: < i32 > ( ) ) ) . collect ( ) ;
530+
531+ let divisor_values: Vec < Option < i32 > > = ( 0 ..batch_size)
532+ . map ( |_| {
533+ let roll: f32 = rng. random ( ) ;
534+ if roll < zero_percentage {
535+ Some ( 0 )
536+ } else {
537+ let mut val = rng. random :: < i32 > ( ) ;
538+ while val == 0 {
539+ val = rng. random :: < i32 > ( ) ;
540+ }
541+ Some ( val)
542+ }
543+ } )
544+ . collect ( ) ;
545+
546+ let divisor: Int32Array = divisor_values. iter ( ) . cloned ( ) . collect ( ) ;
547+ let divisor_copy: Int32Array = divisor_values. iter ( ) . cloned ( ) . collect ( ) ;
548+
549+ let schema = Arc :: new ( Schema :: new ( vec ! [
550+ Field :: new( "numerator" , numerator. data_type( ) . clone( ) , true ) ,
551+ Field :: new( "divisor" , divisor. data_type( ) . clone( ) , true ) ,
552+ Field :: new( "divisor_copy" , divisor_copy. data_type( ) . clone( ) , true ) ,
553+ ] ) ) ;
554+
555+ let batch = RecordBatch :: try_new (
556+ Arc :: clone ( & schema) ,
557+ vec ! [
558+ Arc :: new( numerator) ,
559+ Arc :: new( divisor) ,
560+ Arc :: new( divisor_copy) ,
561+ ] ,
562+ )
563+ . unwrap ( ) ;
564+
565+ let numerator_col = col ( "numerator" , & batch. schema ( ) ) . unwrap ( ) ;
566+ let divisor_col = col ( "divisor" , & batch. schema ( ) ) . unwrap ( ) ;
567+ let divisor_copy_col = col ( "divisor_copy" , & batch. schema ( ) ) . unwrap ( ) ;
568+
569+ group. bench_function (
570+ format ! (
571+ "{} rows, {}% zeros: DivideByZeroProtection" ,
572+ batch_size,
573+ ( zero_percentage * 100.0 ) as i32
574+ ) ,
575+ |b| {
576+ let when = Arc :: new ( BinaryExpr :: new (
577+ Arc :: clone ( & divisor_col) ,
578+ Operator :: Gt ,
579+ lit ( 0i32 ) ,
580+ ) ) ;
581+ let then = Arc :: new ( BinaryExpr :: new (
582+ Arc :: clone ( & numerator_col) ,
583+ Operator :: Divide ,
584+ Arc :: clone ( & divisor_col) ,
585+ ) ) ;
586+ let else_null: Arc < dyn PhysicalExpr > = lit ( ScalarValue :: Int32 ( None ) ) ;
587+ let expr =
588+ Arc :: new ( case ( None , vec ! [ ( when, then) ] , Some ( else_null) ) . unwrap ( ) ) ;
589+
590+ b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
591+ } ,
592+ ) ;
593+
594+ group. bench_function (
595+ format ! (
596+ "{} rows, {}% zeros: ExpressionOrExpression" ,
597+ batch_size,
598+ ( zero_percentage * 100.0 ) as i32
599+ ) ,
600+ |b| {
601+ let when = Arc :: new ( BinaryExpr :: new (
602+ Arc :: clone ( & divisor_copy_col) ,
603+ Operator :: Gt ,
604+ lit ( 0i32 ) ,
605+ ) ) ;
606+ let then = Arc :: new ( BinaryExpr :: new (
607+ Arc :: clone ( & numerator_col) ,
608+ Operator :: Divide ,
609+ Arc :: clone ( & divisor_col) ,
610+ ) ) ;
611+ let else_null: Arc < dyn PhysicalExpr > = lit ( ScalarValue :: Int32 ( None ) ) ;
612+ let expr =
613+ Arc :: new ( case ( None , vec ! [ ( when, then) ] , Some ( else_null) ) . unwrap ( ) ) ;
614+
615+ b. iter ( || black_box ( expr. evaluate ( black_box ( & batch) ) . unwrap ( ) ) )
616+ } ,
617+ ) ;
618+ }
619+
620+ group. finish ( ) ;
621+ }
622+
// Criterion harness wiring: declare the `benches` group around
// `criterion_benchmark` and pass that group to `criterion_main!`.
520623criterion_group ! ( benches, criterion_benchmark) ;
521624criterion_main ! ( benches) ;
0 commit comments