From bc7ed0323a00841e637f5c28b717f0f0d0135ff2 Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Fri, 16 Jan 2026 13:48:10 +0100 Subject: [PATCH] Use the radix-4 FFT for odd powers of two This seems to cut the runtime for odd powers of two by more than a factor of two. --- benchmark/ffta_env/Project.toml | 4 +- src/algos.jl | 66 +++++++++------------------------ src/callgraph.jl | 12 +++--- 3 files changed, 25 insertions(+), 57 deletions(-) diff --git a/benchmark/ffta_env/Project.toml b/benchmark/ffta_env/Project.toml index 19c180c..7aac58c 100644 --- a/benchmark/ffta_env/Project.toml +++ b/benchmark/ffta_env/Project.toml @@ -1,6 +1,4 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" - -[extras] FFTA = "b86e33f2-c0db-4aa1-a6e0-ab43e668529e" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" diff --git a/src/algos.jl b/src/algos.jl index 578c8e1..6505f75 100644 --- a/src/algos.jl +++ b/src/algos.jl @@ -5,25 +5,22 @@ end @inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w)) function fft!(out::AbstractVector{T}, in::AbstractVector{T}, start_out::Int, start_in::Int, d::Direction, t::FFTEnum, g::CallGraph{T}, idx::Int) where T - if t === compositeFFT + if t === COMPOSITE_FFT fft_composite!(out, in, start_out, start_in, d, g, idx) else root = g[idx] - if t == dft + if t == DFT fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, _conj(root.w, d)) else - N = root.sz s_in = root.s_in s_out = root.s_out - if t === pow2FFT - fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d)) - elseif t === pow3FFT + if t === POW2RADIX4_FFT + fft_pow2_radix4!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d)) + elseif t === POW3_FFT p_120 = cispi(T(2)/3) m_120 = cispi(T(4)/3) _p_120, _m_120 = d == FFT_FORWARD ? 
(p_120, m_120) : (m_120, p_120) - fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d), _m_120, _p_120) - elseif t === pow4FFT - fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d)) + fft_pow3!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d), _m_120, _p_120) else throw(ArgumentError("kernel not implemented")) end @@ -133,9 +130,10 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int end end + """ $(TYPEDSIGNATURES) -Power of 2 FFT, in place +Radix-4 FFT for powers of 2, in place # Arguments `out`: Output vector @@ -148,45 +146,15 @@ Power of 2 FFT, in place `w`: The value `cispi(direction_sign(d) * 2 / N)` """ -function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} - if N == 2 +function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} + # If N is 2, compute the size two DFT + @inbounds if N == 2 out[start_out] = in[start_in] + in[start_in + stride_in] out[start_out + stride_out] = in[start_in] - in[start_in + stride_in] return end - m = N ÷ 2 - - fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, w*w) - fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, w*w) - - wj = one(T) - @inbounds for j in 0:m-1 - j1_out = start_out + j*stride_out - j2_out = start_out + (j+m)*stride_out - out_j = out[j1_out] - out[j1_out] = out_j + wj*out[j2_out] - out[j2_out] = out_j - wj*out[j2_out] - wj *= w - end -end - - -""" -$(TYPEDSIGNATURES) -Power of 4 FFT, in place -# Arguments -`out`: Output vector -`in`: Input vector -`N`: Size of the transform -`start_out`: Index of the first element of the output vector -`stride_out`: Stride of the output vector -`start_in`: Index of the first element of the input vector -`stride_in`: 
Stride of the input vector -`w`: The value `cispi(direction_sign(d) * 2 / N)` - -""" -function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} + # If N is 4, compute an unrolled radix-2 FFT and return minusi = -sign(imag(w))*im @inbounds if N == 4 xee = in[start_in] @@ -203,6 +171,8 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_ out[start_out + 3*stride_out] = xee_m_xeo - xoe_m_xoo return end + + # ...otherwise split the problem in four and recur m = N ÷ 4 w1 = w @@ -210,10 +180,10 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_ w3 = w*w2 w4 = w*w3 - fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4) - fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4) - fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w4) - fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w4) + fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4) + fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4) + fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w4) + fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w4) wkoe = wkeo = wkoo = one(T) diff --git a/src/callgraph.jl b/src/callgraph.jl index d5e2ad1..3340c00 100644 --- a/src/callgraph.jl +++ b/src/callgraph.jl @@ -1,6 +1,6 @@ @enum Direction FFT_FORWARD=-1 FFT_BACKWARD=1 @enum Pow24 POW2=2 POW4=1 -@enum FFTEnum compositeFFT dft pow2FFT pow3FFT pow4FFT +@enum FFTEnum COMPOSITE_FFT DFT POW3_FFT POW2RADIX4_FFT """ $(TYPEDSIGNATURES) @@ -74,20 +74,20 @@ function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, 
workspace::Vect pow = _ispow24(N) if !isnothing(pow) push!(workspace, T[]) - push!(nodes, CallGraphNode(0, 0, pow == POW2 ? pow2FFT : pow4FFT, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out, w)) return 1 end end if N % 3 == 0 if nextpow(3, N) == N push!(workspace, T[]) - push!(nodes, CallGraphNode(0, 0, pow3FFT, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, POW3_FFT, N, s_in, s_out, w)) return 1 end end if N == 1 || Primes.isprime(N) push!(workspace, T[]) - push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w)) return 1 end Ns = [first(x) for x in collect(Primes.factor(N)) for _ in 1:last(x)] @@ -104,12 +104,12 @@ function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vect N1 = N_cp[N1_idx] end N2 = N ÷ N1 - push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w)) sz = length(nodes) push!(workspace, Vector{T}(undef, N)) left_len = CallGraphNode!(nodes, N1, workspace, N2, N2*s_out) right_len = CallGraphNode!(nodes, N2, workspace, N1*s_in, 1) - nodes[sz] = CallGraphNode(1, 1 + left_len, compositeFFT, N, s_in, s_out, w) + nodes[sz] = CallGraphNode(1, 1 + left_len, COMPOSITE_FFT, N, s_in, s_out, w) return 1 + left_len + right_len end