Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FFTA = "b86e33f2-c0db-4aa1-a6e0-ab43e668529e"

[compat]
Documenter = "1"
175 changes: 82 additions & 93 deletions src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,32 @@ abstract type FFTAPlan{T,N} <: AbstractFFTs.Plan{T} end

struct FFTAInvPlan{T,N} <: FFTAPlan{T,N} end

struct FFTAPlan_cx{T,N} <: FFTAPlan{T,N}
callgraph::NTuple{N, CallGraph{T}}
region::Union{Int,AbstractVector{<:Int}}
struct FFTAPlan_cx{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
pinv::FFTAInvPlan{T}
pinv::FFTAInvPlan{T,N}
end
function FFTAPlan_cx{T,N}(
cg::NTuple{N,CallGraph{T}}, r::R,
dir::Direction, pinv::FFTAInvPlan{T,N}
) where {T,N,R<:Union{Int,AbstractVector{Int}}}
FFTAPlan_cx{T,N,R}(cg, r, dir, pinv)
end

struct FFTAPlan_re{T,N} <: FFTAPlan{T,N}
callgraph::NTuple{N, CallGraph{T}}
region::Union{Int,AbstractVector{<:Int}}
struct FFTAPlan_re{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
pinv::FFTAInvPlan{T}
pinv::FFTAInvPlan{T,N}
flen::Int
end
function FFTAPlan_re{T,N}(
cg::NTuple{N,CallGraph{T}}, r::R,
dir::Direction, pinv::FFTAInvPlan{T,N}, flen::Int
) where {T,N,R<:Union{Int,AbstractVector{Int}}}
FFTAPlan_re{T,N,R}(cg, r, dir, pinv, flen)
end

Base.size(p::FFTAPlan_cx, i::Int) = i <= length(p.callgraph) ? first(p.callgraph[i].nodes).sz : 1
function Base.size(p::FFTAPlan_re{<:Any,1}, i::Int)
Expand All @@ -40,80 +52,72 @@ function Base.size(p::FFTAPlan_re{<:Any,2}, i::Int)
end
Base.size(p::FFTAPlan{<:Any,N}) where N = ntuple(Base.Fix1(size, p), Val{N}())

Base.complex(p::FFTAPlan_re{T,N}) where {T,N} = FFTAPlan_cx{T,N}(p.callgraph, p.region, p.dir, p.pinv)
Base.complex(p::FFTAPlan_re{T,N,R}) where {T,N,R} = FFTAPlan_cx{T,N,R}(p.callgraph, p.region, p.dir, p.pinv)

function AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex, N}
FFTN = length(region)
if FFTN == 1
g = CallGraph{T}(size(x,region[]))
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_cx{T,FFTN}((g,), region, FFT_FORWARD, pinv)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{T}(size(x,region[1]))
g2 = CallGraph{T}(size(x,region[2]))
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_cx{T,FFTN}((g1,g2), region, FFT_FORWARD, pinv)
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
end
AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} =
_plan_fft(x, region, FFT_FORWARD; kwargs...)

AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} =
_plan_fft(x, region, FFT_BACKWARD; kwargs...)

function AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex,N}
function _plan_fft(x::AbstractArray{T,N}, region::R, dir::Direction; kwargs...) where {T<:Complex,N,R}
FFTN = length(region)
if FFTN == 1
g = CallGraph{T}(size(x,region[]))
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_cx{T,FFTN}((g,), region, FFT_BACKWARD, pinv)
R1 = Int(region[])
g = CallGraph{T}(size(x, R1))
pinv = FFTAInvPlan{T,1}()
return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{T}(size(x,region[1]))
g2 = CallGraph{T}(size(x,region[2]))
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_cx{T,FFTN}((g1,g2), region, FFT_BACKWARD, pinv)
g1 = CallGraph{T}(size(x, region[1]))
g2 = CallGraph{T}(size(x, region[2]))
pinv = FFTAInvPlan{T,2}()
return FFTAPlan_cx{T,2,R}((g1, g2), region, dir, pinv)
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
end

function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real,N}
function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Real,N,R}
FFTN = length(region)
if FFTN == 1
n = size(x, region[])
R1 = Int(region[])
n = size(x, R1)
# For even length problems, we solve the real problem with
# two n/2 complex FFTs followed by a butterfly. For odd size
# problems, we just solve the problem as a single complex
nn = iseven(n) ? n >> 1 : n
g = CallGraph{Complex{T}}(nn)
pinv = FFTAInvPlan{Complex{T},FFTN}()
return FFTAPlan_re{Complex{T},FFTN}(tuple(g), region, FFT_FORWARD, pinv, n)
pinv = FFTAInvPlan{Complex{T},1}()
return FFTAPlan_re{Complex{T},1,Int}((g,), R1, FFT_FORWARD, pinv, n)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{Complex{T}}(size(x,region[1]))
g2 = CallGraph{Complex{T}}(size(x,region[2]))
pinv = FFTAInvPlan{Complex{T},FFTN}()
return FFTAPlan_re{Complex{T},FFTN}(tuple(g1,g2), region, FFT_FORWARD, pinv, size(x,region[1]))
g1 = CallGraph{Complex{T}}(size(x, region[1]))
g2 = CallGraph{Complex{T}}(size(x, region[2]))
pinv = FFTAInvPlan{Complex{T},2}()
return FFTAPlan_re{Complex{T},2,R}((g1, g2), region, FFT_FORWARD, pinv, size(x, region[1]))
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
end

function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region; kwargs...)::FFTAPlan_re{T} where {T,N}
function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; kwargs...) where {T,N,R}
FFTN = length(region)
if FFTN == 1
# For even length problems, we solve the real problem with
# two n/2 complex FFTs followed by a butterfly. For odd size
# problems, we just solve the problem as a single complex
R1 = Int(region[])
nn = iseven(len) ? len >> 1 : len
g = CallGraph{T}(nn)
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_re{T,FFTN}((g,), region, FFT_BACKWARD, pinv, len)
pinv = FFTAInvPlan{T,1}()
return FFTAPlan_re{T,1,Int}((g,), R1, FFT_BACKWARD, pinv, len)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{T}(len)
g2 = CallGraph{T}(size(x,region[2]))
pinv = FFTAInvPlan{T,FFTN}()
return FFTAPlan_re{T,FFTN}((g1,g2), region, FFT_BACKWARD, pinv, len)
g2 = CallGraph{T}(size(x, region[2]))
pinv = FFTAInvPlan{T,2}()
return FFTAPlan_re{T,2,R}((g1, g2), region, FFT_BACKWARD, pinv, len)
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
Expand Down Expand Up @@ -175,7 +179,7 @@ function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,2}, x::Abstr
R2 = CartesianIndices(size(x)[p.region[1]+1:p.region[2]-1])
R3 = CartesianIndices(size(x)[p.region[2]+1:end])
y_tmp = similar(y, axes(y)[p.region])
rows,cols = size(x)[p.region]
rows, cols = size(x)[p.region]
# Introduce function barrier here since the variables used in the loop ranges aren't inferred. This
# is partly because the region field of the plan is abstractly typed but even if that wasn't the case,
# it might be a bit tricky to construct the Rxs in an inferred way.
Expand All @@ -195,7 +199,7 @@ function _mul_loop!(
)
for I3 in R3, I2 in R2, I1 in R1
for k in 1:cols
@views fft!(y_tmp[:,k], x[I1,:,I2,k,I3], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
@views fft!(y_tmp[:,k], x[I1,:,I2,k,I3], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
end

for k in 1:rows
Expand All @@ -212,7 +216,7 @@ function Base.:*(p::FFTAPlan_cx{T,1}, x::AbstractVector{T}) where {T<:Complex}
y
end

function Base.:*(p::FFTAPlan_cx{T,N1}, x::AbstractArray{T,N2}) where {T<:Complex, N1, N2}
function Base.:*(p::FFTAPlan_cx{T,N1}, x::AbstractArray{T,N2}) where {T<:Complex,N1,N2}
y = similar(x)
LinearAlgebra.mul!(y, p, x)
y
Expand Down Expand Up @@ -256,38 +260,29 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R
# Construct the result by first constructing the elements of the
# real and imaginary part, followed by the usual radix-2 assembly,
# see eq (9)
@inbounds begin
y1 = y[1]
y[1] = real(y1) + imag(y1)
y[end] = real(y1) - imag(y1)
for j in 2:((m >> 1) + 1)
yj = y[j]
yjr, yji = real(yj), imag(yj)
ymj = y[m - j + 2]
ymjr, ymji = real(ymj), imag(ymj)
XX = complex(
(yjr + ymjr) * T(0.5),
(yji - ymji) * T(0.5),
)
XY = complex(
(ymji + yji) * T(0.5),
(ymjr - yjr) * T(0.5),
)
y[j] = XX + wj*XY
y[m - j + 2] = conj(XX - wj*XY)
wj *= w
end
y1 = y[1]
y[1] = real(y1) + imag(y1)
y[end] = real(y1) - imag(y1)

@inbounds for j in 2:((m >> 1) + 1)
yj = y[j]
ymj = y[m-j+2]
XX = T(0.5) * ( yj + conj(ymj))
XY = T(0.5) * (-yj + conj(ymj)) * im
y[j] = XX + wj * XY
y[m-j+2] = conj(XX - wj * XY)
wj *= w
end
return y
else
# when the problem cannot be split in two equal size chunks we
# convert the problem to a complex fft and truncate the redundant
# part of the result vector
x_c = similar(x, Complex{T})
copy!(x_c, x)
y = similar(x_c)
copyto!(x_c, x)
LinearAlgebra.mul!(y, complex(p), x_c)
return y[1:end÷2 + 1]
return y[1:end÷2+1]
end
end
throw(ArgumentError("only FFT_FORWARD supported for real vectors"))
Expand All @@ -308,16 +303,10 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex}
(real(x[1]) - real(x[end]))
)
for j in 2:((m >> 1) + 1)
XX = x[j] + conj(x[m - j + 2])
XY = wj*(x[j] - conj(x[m - j + 2]))
x_tmp[j] = complex(
real(XX) - imag(XY),
real(XY) + imag(XX)
)
x_tmp[m - j + 2] = complex(
real(XX) + imag(XY),
real(XY) - imag(XX)
)
XX = x[j] + conj(x[m-j+2])
XY = wj * (x[j] - conj(x[m-j+2]))
x_tmp[j] = XX + im * XY
x_tmp[m-j+2] = conj(XX - im * XY)
wj *= w
end
y_c = complex(p) * x_tmp
Expand All @@ -328,8 +317,8 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex}
end
else
x_tmp = similar(x, n)
x_tmp[1:end÷2 + 1] .= x
x_tmp[end÷2 + 2:end] .= iseven(n) ? conj.(x[end-1:-1:2]) : conj.(x[end:-1:2])
x_tmp[1:end÷2+1] .= x
x_tmp[end÷2+2:end] .= iseven(n) ? conj.(x[end-1:-1:2]) : conj.(x[end:-1:2])
y = similar(x_tmp)
LinearAlgebra.mul!(y, complex(p), x_tmp)
return real(y)
Expand All @@ -340,32 +329,32 @@ end

#### 1D plan ND array
##### Forward
function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractArray{T,N}) where {T<:Real, N}
function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractArray{T,N}) where {T<:Real,N}
Base.require_one_based_indexing(x)
if p.dir === FFT_FORWARD
return mapslices(Base.Fix1(*, p), x; dims = p.region[1])
return mapslices(Base.Fix1(*, p), x; dims=p.region[1])
end
throw(ArgumentError("only FFT_FORWARD supported for real arrays"))
end

##### Backward
function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractArray{T,N}) where {T<:Complex, N}
function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractArray{T,N}) where {T<:Complex,N}
Base.require_one_based_indexing(x)
if p.flen ÷ 2 + 1 != size(x, p.region[])
throw(DimensionMismatch("real 1D plan has size $(p.flen). Dimension of input array along region $(p.region[]) should have size $(size(p, p.region[]) ÷ 2 + 1), but has size $(size(x, p.region[]))"))
end
if p.dir === FFT_BACKWARD
return mapslices(Base.Fix1(*, p), x; dims = p.region[1])
return mapslices(Base.Fix1(*, p), x; dims=p.region[1])
end
throw(ArgumentError("only FFT_BACKWARD supported for complex arrays"))
end

#### 2D plan ND array
##### Forward
function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,N}) where {T<:Real, N}
function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,N}) where {T<:Real,N}
Base.require_one_based_indexing(x)
if p.dir === FFT_FORWARD
half_1 = 1:(p.flen ÷ 2 + 1)
half_1 = 1:(p.flen÷2+1)
x_c = similar(x, Complex{T})
copy!(x_c, x)
y = similar(x_c)
Expand All @@ -376,13 +365,13 @@ function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,N}) where {T<:
end

##### Backward
function Base.:*(p::FFTAPlan_re{T,2}, x::AbstractArray{T,N}) where {T<:Complex, N}
function Base.:*(p::FFTAPlan_re{T,2}, x::AbstractArray{T,N}) where {T<:Complex,N}
Base.require_one_based_indexing(x)
if size(p, 1) ÷ 2 + 1 != size(x, p.region[1])
throw(DimensionMismatch("real 2D plan has size $(size(p)). First transform dimension of input array should have size ($(size(p, 1) ÷ 2 + 1)), but has size $(size(x, p.region[1]))"))
end
if p.dir === FFT_BACKWARD
res_size = ntuple(i->ifelse(i==p.region[1], p.flen, size(x,i)), ndims(x))
res_size = ntuple(i -> ifelse(i == p.region[1], p.flen, size(x, i)), Val(N))
# for the inverse transformation we have to reconstruct the full array
half_1 = 1:(p.flen ÷ 2 + 1)
half_2 = half_1[end]+1:p.flen
Expand Down
2 changes: 1 addition & 1 deletion test/argument_checking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
end
end

@testset "mismatch besteen input and output arrays" begin
@testset "mismatch between input and output arrays" begin
@testset "1D plan 1D array" begin
x1 = complex(randn(3))
y1 = similar(x1, length(x1) + 1)
Expand Down
4 changes: 2 additions & 2 deletions test/onedim/complex_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using FFTA, Test
end

@testset "1D plan, 1D array. Size: $n" for n in 1:64
x = complex.(randn(n), randn(n))
x = randn(ComplexF64, n)
# Assuming that fft works since it is tested independently
y = fft(x)

Expand All @@ -23,7 +23,7 @@ end
end

@testset "1D plan, ND array. Size: $n" for n in 1:64
x = complex.(randn(n, n + 1, n + 2), randn(n, n + 1, n + 2))
x = randn(ComplexF64, n, n + 1, n + 2)

@testset "against 1D array with mapslices, r=$r" for r in 1:3
@test bfft(x, r) == mapslices(bfft, x; dims = r)
Expand Down
4 changes: 2 additions & 2 deletions test/onedim/complex_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using FFTA, Test
end

@testset "1D plan, 1D array. Size: $n" for n in 1:64
x = complex.(randn(n), randn(n))
x = randn(ComplexF64, n)

@testset "against naive implementation" begin
@test naive_1d_fourier_transform(x, FFTA.FFT_FORWARD) ≈ fft(x)
Expand All @@ -23,7 +23,7 @@ end
end

@testset "1D plan, ND array. Size: $n" for n in 1:64
x = complex.(randn(n, n + 1, n + 2), randn(n, n + 1, n + 2))
x = randn(ComplexF64, n, n + 1, n + 2)

@testset "against 1D array with mapslices, r=$r" for r in 1:3
@test fft(x, r) == mapslices(fft, x; dims = r)
Expand Down
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Test, Random, FFTA

macro test_allocations(args)
if Base.VERSION >= v"1.9"
:(@allocations($(esc(args))))
@static if Base.VERSION >= v"1.9"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own reference, what's the effect of this change? I know basic macros, and I'm guessing that the point is to ensure the @allocations symbol doesn't appear in code run for Julia <1.9, but I was under the impression that the revised code amounts to the same thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@static causes the if statement to be evaluated only once during macro expansion, so after that the macro would be the code in only one of the branches. I suspected weird "action-at-a-distance" machinery was involved in allocating so I wanted to minimise extraneous code paths.

That was also the motivation for escaping the @allocations macro.
How much this actually works, not clear, but it's not harmful or incorrect and the tests pass with this change.

ex = Expr(:macrocall, Symbol("@allocations"), __source__, args)
return esc(ex)
else
:(0)
end
Expand Down
Loading
Loading