Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions benchmark/ffta_env/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"

[extras]
FFTA = "b86e33f2-c0db-4aa1-a6e0-ab43e668529e"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
66 changes: 18 additions & 48 deletions src/algos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,22 @@ end
@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w))

function fft!(out::AbstractVector{T}, in::AbstractVector{T}, start_out::Int, start_in::Int, d::Direction, t::FFTEnum, g::CallGraph{T}, idx::Int) where T
if t === compositeFFT
if t === COMPOSITE_FFT
fft_composite!(out, in, start_out, start_in, d, g, idx)
else
root = g[idx]
if t == dft
if t == DFT
fft_dft!(out, in, root.sz, start_out, root.s_out, start_in, root.s_in, _conj(root.w, d))
else
N = root.sz
s_in = root.s_in
s_out = root.s_out
if t === pow2FFT
fft_pow2!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d))
elseif t === pow3FFT
if t === POW2RADIX4_FFT
fft_pow2_radix4!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d))
elseif t === POW3_FFT
p_120 = cispi(T(2)/3)
m_120 = cispi(T(4)/3)
_p_120, _m_120 = d == FFT_FORWARD ? (p_120, m_120) : (m_120, p_120)
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d), _m_120, _p_120)
elseif t === pow4FFT
fft_pow4!(out, in, N, start_out, s_out, start_in, s_in, _conj(root.w, d))
fft_pow3!(out, in, root.sz, start_out, s_out, start_in, s_in, _conj(root.w, d), _m_120, _p_120)
else
throw(ArgumentError("kernel not implemented"))
end
Expand Down Expand Up @@ -133,9 +130,10 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
end
end


"""
$(TYPEDSIGNATURES)
Power of 2 FFT, in place
Radix-4 FFT for powers of 2, in place

# Arguments
`out`: Output vector
Expand All @@ -148,45 +146,15 @@ Power of 2 FFT, in place
`w`: The value `cispi(direction_sign(d) * 2 / N)`

"""
function fft_pow2!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
if N == 2
function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
# If N is 2, compute the size two DFT
@inbounds if N == 2
out[start_out] = in[start_in] + in[start_in + stride_in]
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
return
end
m = N ÷ 2

fft_pow2!(out, in, m, start_out , stride_out, start_in , stride_in*2, w*w)
fft_pow2!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*2, w*w)

wj = one(T)
@inbounds for j in 0:m-1
j1_out = start_out + j*stride_out
j2_out = start_out + (j+m)*stride_out
out_j = out[j1_out]
out[j1_out] = out_j + wj*out[j2_out]
out[j2_out] = out_j - wj*out[j2_out]
wj *= w
end
end


"""
$(TYPEDSIGNATURES)
Power of 4 FFT, in place

# Arguments
`out`: Output vector
`in`: Input vector
`N`: Size of the transform
`start_out`: Index of the first element of the output vector
`stride_out`: Stride of the output vector
`start_in`: Index of the first element of the input vector
`stride_in`: Stride of the input vector
`w`: The value `cispi(direction_sign(d) * 2 / N)`

"""
function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
# If N is 4, compute an unrolled radix-2 FFT and return
minusi = -sign(imag(w))*im
@inbounds if N == 4
xee = in[start_in]
Expand All @@ -203,17 +171,19 @@ function fft_pow4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
out[start_out + 3*stride_out] = xee_m_xeo - xoe_m_xoo
return
end

# ...othersize split the problem in four and recur
m = N ÷ 4

w1 = w
w2 = w*w1
w3 = w*w2
w4 = w*w3

fft_pow4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4)
fft_pow4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4)
fft_pow4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w4)
fft_pow4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w4)
fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w4)
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w4)
fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w4)
fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w4)

wkoe = wkeo = wkoo = one(T)

Expand Down
12 changes: 6 additions & 6 deletions src/callgraph.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@enum Direction FFT_FORWARD=-1 FFT_BACKWARD=1
@enum Pow24 POW2=2 POW4=1
@enum FFTEnum compositeFFT dft pow2FFT pow3FFT pow4FFT
@enum FFTEnum COMPOSITE_FFT DFT POW3_FFT POW2RADIX4_FFT

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -74,20 +74,20 @@ function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vect
pow = _ispow24(N)
if !isnothing(pow)
push!(workspace, T[])
push!(nodes, CallGraphNode(0, 0, pow == POW2 ? pow2FFT : pow4FFT, N, s_in, s_out, w))
push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out, w))
return 1
end
end
if N % 3 == 0
if nextpow(3, N) == N
push!(workspace, T[])
push!(nodes, CallGraphNode(0, 0, pow3FFT, N, s_in, s_out, w))
push!(nodes, CallGraphNode(0, 0, POW3_FFT, N, s_in, s_out, w))
return 1
end
end
if N == 1 || Primes.isprime(N)
push!(workspace, T[])
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w))
push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w))
return 1
end
Ns = [first(x) for x in collect(Primes.factor(N)) for _ in 1:last(x)]
Expand All @@ -104,12 +104,12 @@ function CallGraphNode!(nodes::Vector{CallGraphNode{T}}, N::Int, workspace::Vect
N1 = N_cp[N1_idx]
end
N2 = N ÷ N1
push!(nodes, CallGraphNode(0, 0, dft, N, s_in, s_out, w))
push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w))
sz = length(nodes)
push!(workspace, Vector{T}(undef, N))
left_len = CallGraphNode!(nodes, N1, workspace, N2, N2*s_out)
right_len = CallGraphNode!(nodes, N2, workspace, N1*s_in, 1)
nodes[sz] = CallGraphNode(1, 1 + left_len, compositeFFT, N, s_in, s_out, w)
nodes[sz] = CallGraphNode(1, 1 + left_len, COMPOSITE_FFT, N, s_in, s_out, w)
return 1 + left_len + right_len
end

Expand Down
Loading