@@ -21,9 +21,7 @@ use std::any::Any;
2121
2222use super :: power:: PowerFunc ;
2323
24- use crate :: utils:: {
25- calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
26- } ;
24+ use crate :: utils:: calculate_binary_math;
2725use arrow:: array:: { Array , ArrayRef } ;
2826use arrow:: datatypes:: {
2927 DataType , Decimal32Type , Decimal64Type , Decimal128Type , Decimal256Type , Float16Type ,
@@ -44,7 +42,7 @@ use datafusion_expr::{
4442} ;
4543use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
4644use datafusion_macros:: user_doc;
47- use num_traits:: Float ;
45+ use num_traits:: { Float , ToPrimitive } ;
4846
4947#[ user_doc(
5048 doc_section( label = "Math Functions" ) ,
@@ -104,109 +102,109 @@ impl LogFunc {
104102 }
105103}
106104
107- /// Binary function to calculate logarithm of Decimal32 `value` using `base` base
108- /// Returns error if base is invalid
109- fn log_decimal32 ( value : i32 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
110- if !base. is_finite ( ) || base. trunc ( ) != base {
111- return Err ( ArrowError :: ComputeError ( format ! (
112- "Log cannot use non-integer base: {base}"
113- ) ) ) ;
114- }
115- if ( base as u32 ) < 2 {
116- return Err ( ArrowError :: ComputeError ( format ! (
117- "Log base must be greater than 1: {base}"
118- ) ) ) ;
119- }
120-
121- // Match f64::log behaviour
122- if value <= 0 {
123- return Ok ( f64:: NAN ) ;
124- }
105+ /// Checks if the base is valid for the efficient integer logarithm algorithm.
106+ #[ inline]
107+ fn is_valid_integer_base ( base : f64 ) -> bool {
108+ base. trunc ( ) == base && base >= 2.0 && base <= u32:: MAX as f64
109+ }
125110
126- if scale < 0 {
127- let actual_value = ( value as f64 ) * 10.0_f64 . powi ( -( scale as i32 ) ) ;
128- Ok ( actual_value. log ( base) )
129- } else {
130- let unscaled_value = decimal32_to_i32 ( value, scale) ?;
131- if unscaled_value <= 0 {
132- return Ok ( f64:: NAN ) ;
133- }
134- let log_value: u32 = unscaled_value. ilog ( base as i32 ) ;
135- Ok ( log_value as f64 )
111+ /// Calculate logarithm for Decimal32 values.
112+ /// For integer bases >= 2 with non-negative scale, uses the efficient u32 ilog algorithm.
113+ /// Otherwise falls back to f64 computation.
114+ fn log_decimal32 ( value : i32 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
115+ if is_valid_integer_base ( base)
116+ && scale >= 0
117+ && let Some ( unscaled) = unscale_to_u32 ( value, scale)
118+ {
119+ return if unscaled > 0 {
120+ Ok ( unscaled. ilog ( base as u32 ) as f64 )
121+ } else {
122+ Ok ( f64:: NAN )
123+ } ;
136124 }
125+ decimal_to_f64 ( value, scale) . map ( |v| v. log ( base) )
137126}
138127
139- /// Binary function to calculate logarithm of Decimal64 `value` using `base` base
140- /// Returns error if base is invalid
128+ /// Calculate logarithm for Decimal64 values.
129+ /// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm.
130+ /// Otherwise falls back to f64 computation.
141131fn log_decimal64 ( value : i64 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
142- if !base . is_finite ( ) || base. trunc ( ) != base {
143- return Err ( ArrowError :: ComputeError ( format ! (
144- "Log cannot use non-integer base: {base}"
145- ) ) ) ;
146- }
147- if ( base as u32 ) < 2 {
148- return Err ( ArrowError :: ComputeError ( format ! (
149- "Log base must be greater than 1: {base}"
150- ) ) ) ;
132+ if is_valid_integer_base ( base)
133+ && scale >= 0
134+ && let Some ( unscaled ) = unscale_to_u64 ( value , scale )
135+ {
136+ return if unscaled > 0 {
137+ Ok ( unscaled . ilog ( base as u64 ) as f64 )
138+ } else {
139+ Ok ( f64 :: NAN )
140+ } ;
151141 }
142+ decimal_to_f64 ( value, scale) . map ( |v| v. log ( base) )
143+ }
152144
153- if value <= 0 {
154- return Ok ( f64:: NAN ) ;
145+ /// Calculate logarithm for Decimal128 values.
146+ /// For integer bases >= 2 with non-negative scale, uses the efficient u128 ilog algorithm.
147+ /// Otherwise falls back to f64 computation.
148+ fn log_decimal128 ( value : i128 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
149+ if is_valid_integer_base ( base)
150+ && scale >= 0
151+ && let Some ( unscaled) = unscale_to_u128 ( value, scale)
152+ {
153+ return if unscaled > 0 {
154+ Ok ( unscaled. ilog ( base as u128 ) as f64 )
155+ } else {
156+ Ok ( f64:: NAN )
157+ } ;
155158 }
159+ decimal_to_f64 ( value, scale) . map ( |v| v. log ( base) )
160+ }
156161
157- if scale < 0 {
158- let actual_value = ( value as f64 ) * 10.0_f64 . powi ( -( scale as i32 ) ) ;
159- Ok ( actual_value. log ( base) )
160- } else {
161- let unscaled_value = decimal64_to_i64 ( value, scale) ?;
162- if unscaled_value <= 0 {
163- return Ok ( f64:: NAN ) ;
164- }
165- let log_value: u32 = unscaled_value. ilog ( base as i64 ) ;
166- Ok ( log_value as f64 )
167- }
162+ /// Unscale a Decimal32 value to u32.
163+ #[ inline]
164+ fn unscale_to_u32 ( value : i32 , scale : i8 ) -> Option < u32 > {
165+ let value_u32 = u32:: try_from ( value) . ok ( ) ?;
166+ let divisor = 10u32 . checked_pow ( scale as u32 ) ?;
167+ Some ( value_u32 / divisor)
168168}
169169
170- /// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
171- /// Returns error if base is invalid
172- fn log_decimal128 ( value : i128 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
173- if !base. is_finite ( ) || base. trunc ( ) != base {
174- return Err ( ArrowError :: ComputeError ( format ! (
175- "Log cannot use non-integer base: {base}"
176- ) ) ) ;
177- }
178- if ( base as u32 ) < 2 {
179- return Err ( ArrowError :: ComputeError ( format ! (
180- "Log base must be greater than 1: {base}"
181- ) ) ) ;
182- }
170+ /// Unscale a Decimal64 value to u64.
171+ #[ inline]
172+ fn unscale_to_u64 ( value : i64 , scale : i8 ) -> Option < u64 > {
173+ let value_u64 = u64:: try_from ( value) . ok ( ) ?;
174+ let divisor = 10u64 . checked_pow ( scale as u32 ) ?;
175+ Some ( value_u64 / divisor)
176+ }
183177
184- if value <= 0 {
185- // Reflect f64::log behaviour
186- return Ok ( f64:: NAN ) ;
187- }
178+ /// Unscale a Decimal128 value to u128.
179+ #[ inline]
180+ fn unscale_to_u128 ( value : i128 , scale : i8 ) -> Option < u128 > {
181+ let value_u128 = u128:: try_from ( value) . ok ( ) ?;
182+ let divisor = 10u128 . checked_pow ( scale as u32 ) ?;
183+ Some ( value_u128 / divisor)
184+ }
188185
189- if scale < 0 {
190- let actual_value = ( value as f64 ) * 10.0_f64 . powi ( -( scale as i32 ) ) ;
191- Ok ( actual_value. log ( base) )
192- } else {
193- let unscaled_value = decimal128_to_i128 ( value, scale) ?;
194- if unscaled_value <= 0 {
195- return Ok ( f64:: NAN ) ;
196- }
197- let log_value: u32 = unscaled_value. ilog ( base as i128 ) ;
198- Ok ( log_value as f64 )
199- }
186+ /// Convert a scaled decimal value to f64.
187+ #[ inline]
188+ fn decimal_to_f64 < T : ToPrimitive + Copy > ( value : T , scale : i8 ) -> Result < f64 , ArrowError > {
189+ let value_f64 = value. to_f64 ( ) . ok_or_else ( || {
190+ ArrowError :: ComputeError ( "Cannot convert value to f64" . to_string ( ) )
191+ } ) ?;
192+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
193+ Ok ( value_f64 / scale_factor)
200194}
201195
202- /// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
203- /// Returns error if base is invalid or if value is out of bounds of Decimal128
204196fn log_decimal256 ( value : i256 , scale : i8 , base : f64 ) -> Result < f64 , ArrowError > {
197+ // Try to convert to i128 for the optimized path
205198 match value. to_i128 ( ) {
206- Some ( value) => log_decimal128 ( value, scale, base) ,
207- None => Err ( ArrowError :: NotYetImplemented ( format ! (
208- "Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
209- ) ) ) ,
199+ Some ( v) => log_decimal128 ( v, scale, base) ,
200+ None => {
201+ // For very large Decimal256 values, use f64 computation
202+ let value_f64 = value. to_f64 ( ) . ok_or_else ( || {
203+ ArrowError :: ComputeError ( format ! ( "Cannot convert {value} to f64" ) )
204+ } ) ?;
205+ let scale_factor = 10f64 . powi ( scale as i32 ) ;
206+ Ok ( ( value_f64 / scale_factor) . log ( base) )
207+ }
210208 }
211209}
212210
@@ -1160,7 +1158,8 @@ mod tests {
11601158 }
11611159
11621160 #[ test]
1163- fn test_log_decimal128_wrong_base ( ) {
1161+ fn test_log_decimal128_invalid_base ( ) {
1162+ // Invalid base (-2.0) should return NaN, matching f64::log behavior
11641163 let arg_fields = vec ! [
11651164 Field :: new( "b" , DataType :: Float64 , false ) . into( ) ,
11661165 Field :: new( "x" , DataType :: Decimal128 ( 38 , 0 ) , false ) . into( ) ,
@@ -1175,16 +1174,26 @@ mod tests {
11751174 return_field : Field :: new ( "f" , DataType :: Float64 , true ) . into ( ) ,
11761175 config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
11771176 } ;
1178- let result = LogFunc :: new ( ) . invoke_with_args ( args) ;
1179- assert ! ( result. is_err( ) ) ;
1180- assert_eq ! (
1181- "Arrow error: Compute error: Log base must be greater than 1: -2" ,
1182- result. unwrap_err( ) . to_string( ) . lines( ) . next( ) . unwrap( )
1183- ) ;
1177+ let result = LogFunc :: new ( )
1178+ . invoke_with_args ( args)
1179+ . expect ( "should not error on invalid base" ) ;
1180+
1181+ match result {
1182+ ColumnarValue :: Array ( arr) => {
1183+ let floats = as_float64_array ( & arr)
1184+ . expect ( "failed to convert result to a Float64Array" ) ;
1185+ assert_eq ! ( floats. len( ) , 1 ) ;
1186+ assert ! ( floats. value( 0 ) . is_nan( ) ) ;
1187+ }
1188+ ColumnarValue :: Scalar ( _) => {
1189+ panic ! ( "Expected an array value" )
1190+ }
1191+ }
11841192 }
11851193
11861194 #[ test]
1187- fn test_log_decimal256_error ( ) {
1195+ fn test_log_decimal256_large ( ) {
1196+ // Large Decimal256 values that don't fit in i128 now use f64 fallback
11881197 let arg_field = Field :: new ( "a" , DataType :: Decimal256 ( 38 , 0 ) , false ) . into ( ) ;
11891198 let args = ScalarFunctionArgs {
11901199 args : vec ! [
@@ -1198,11 +1207,26 @@ mod tests {
11981207 return_field : Field :: new ( "f" , DataType :: Float64 , true ) . into ( ) ,
11991208 config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
12001209 } ;
1201- let result = LogFunc :: new ( ) . invoke_with_args ( args) ;
1202- assert ! ( result. is_err( ) ) ;
1203- assert_eq ! (
1204- result. unwrap_err( ) . to_string( ) . lines( ) . next( ) . unwrap( ) ,
1205- "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
1206- ) ;
1210+ let result = LogFunc :: new ( )
1211+ . invoke_with_args ( args)
1212+ . expect ( "should handle large Decimal256 via f64 fallback" ) ;
1213+
1214+ match result {
1215+ ColumnarValue :: Array ( arr) => {
1216+ let floats = as_float64_array ( & arr)
1217+ . expect ( "failed to convert result to a Float64Array" ) ;
1218+ assert_eq ! ( floats. len( ) , 1 ) ;
1219+ // The f64 fallback may lose some precision for very large numbers,
1220+ // but we verify we get a reasonable positive result (not NaN/infinity)
1221+ let log_result = floats. value ( 0 ) ;
1222+ assert ! (
1223+ log_result. is_finite( ) && log_result > 0.0 ,
1224+ "Expected positive finite log result, got {log_result}"
1225+ ) ;
1226+ }
1227+ ColumnarValue :: Scalar ( _) => {
1228+ panic ! ( "Expected an array value" )
1229+ }
1230+ }
12071231 }
12081232}
0 commit comments