From 8cb30cad56241c8d36578ad551b13e25f40d42cf Mon Sep 17 00:00:00 2001 From: Milan Date: Fri, 2 Sep 2022 17:19:26 +0100 Subject: [PATCH] sz and osz included in plans --- src/fft.jl | 65 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/src/fft.jl b/src/fft.jl index 9907e1a..8a3a8b4 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -207,8 +207,8 @@ generic_idct(a::AbstractArray{T}) where {T <: AbstractFloat} = real(generic_idct for f in (:dct, :dct!, :idct, :idct!) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray{<:AbstractFloats}) = $pf(x) * x - $f(x::AbstractArray{<:AbstractFloats}, region) = $pf(x, region) * x + $f(x::AbstractArray{<:AbstractFloats}) = $pf(x, size(x)) * x + $f(x::AbstractArray{<:AbstractFloats}, region) = $pf(x, size(x), region) * x end end @@ -217,20 +217,25 @@ abstract type DummyPlan{T} <: Plan{T} end for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiDCTPlan) # All plans need an initially undefined pinv field @eval begin - mutable struct $P{T,inplace,G} <: DummyPlan{T} + mutable struct $P{T,inplace,N,G} <: DummyPlan{T} + sz::NTuple{N,Int} + osz::NTuple{N,Int} region::G # region (iterable) of dims that are transformed pinv::DummyPlan{T} - $P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region) + $P{T,inplace,N,G}(sz::NTuple{N,Integer}, region::G) where {T<:AbstractFloats, inplace, N, G} = new(sz,sz,region) end end end for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan) @eval begin - mutable struct $P{T,inplace,G} <: DummyPlan{T} - n::Integer + mutable struct $P{T,inplace,N,G} <: DummyPlan{T} + sz::NTuple{N,Int} + osz::NTuple{N,Int} region::G # region (iterable) of dims that are transformed pinv::DummyPlan{T} - $P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region) + $P{T,inplace,N,G}( sz::NTuple{N,Integer}, + osz::NTuple{N,Integer}, + region::G) where {T<:AbstractFloats, inplace, N, G} = new(sz, osz, region) end end end @@ -238,14 +243,14 @@ end for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan), (:DummyDCTPlan,:DummyiDCTPlan)) @eval begin - plan_inv(p::$Plan{T,inplace,G}) where {T,inplace,G} = $iPlan{T,inplace,G}(p.region) - plan_inv(p::$iPlan{T,inplace,G}) where {T,inplace,G} = $Plan{T,inplace,G}(p.region) + plan_inv(p::$Plan{T,inplace,N,G}) where {T,inplace,N,G} = $iPlan{T,inplace,N,G}(p.sz, p.osz, p.region) + plan_inv(p::$iPlan{T,inplace,N,G}) where {T,inplace,N,G} = $Plan{T,inplace,N,G}(p.sz, p.osz, p.region) end end # Specific for rfft, irfft and brfft: -plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,inplace,G}(p.n, p.region) -plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,inplace,G}(p.n, p.region) +plan_inv(p::DummyirFFTPlan{T,inplace,N,G}) where {T,inplace,N,G} = DummyrFFTPlan{T,inplace,N,G}(p.sz, p.osz, p.region) +plan_inv(p::DummyrFFTPlan{T,inplace,N,G}) where {T,inplace,N,G} = DummyirFFTPlan{T,inplace,N,G}(p.sz, p.osz, p.region) @@ -256,26 +261,28 @@ for (Plan,ff,ff!) in ((:DummyFFTPlan,:generic_fft,:generic_fft!), (:DummyDCTPlan,:generic_dct,:generic_dct!), (:DummyiDCTPlan,:generic_idct,:generic_idct!)) @eval begin - *(p::$Plan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff!(x, p.region) - *(p::$Plan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff(x, p.region) + + *(p::$Plan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff!(x, p.sz, p.region) + *(p::$Plan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = $ff(x, p.sz, p.region) + function mul!(C::StridedVector, p::$Plan, x::StridedVector) - C[:] = $ff(x, p.region) + C[:] = $ff(x, p.sz[1], p.region) C end end end # Specific for irfft and brfft: -*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.n, p.region) -*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.n, p.region) +*(p::DummyirFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft!(x, p.sz, p.region) +*(p::DummyirFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_irfft(x, p.sz, p.region) function mul!(C::StridedVector, p::DummyirFFTPlan, x::StridedVector) - C[:] = generic_irfft(x, p.n, p.region) + C[:] = generic_irfft(x, p.sz[1], p.region) C end -*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.n, p.region) -*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.n, p.region) +*(p::DummybrFFTPlan{T,true}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft!(x, p.sz, p.region) +*(p::DummybrFFTPlan{T,false}, x::StridedArray{T,N}) where {T<:AbstractFloats,N} = generic_brfft(x, p.sz, p.region) function mul!(C::StridedVector, p::DummybrFFTPlan, x::StridedVector) - C[:] = generic_brfft(x, p.n, p.region) + C[:] = generic_brfft(x, p.sz[1], p.region) C end @@ -286,23 +293,23 @@ end # This is the reason for using StridedArray below. We also have to carefully # distinguish between real and complex arguments. -plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false,typeof(region)}(region) -plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true,typeof(region)}(region) +plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false,typeof(region)}(size(x),region) +plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true,typeof(region)}(size(x),region) -plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false,typeof(region)}(region) -plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true,typeof(region)}(region) +plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false,typeof(region)}(size(x),region) +plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true,typeof(region)}(size(x),region) # The ifft plans are automatically provided in terms of the bfft plans above. # plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false,typeof(region)}(region) # plan_ifft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},true,typeof(region)}(region) -plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false,typeof(region)}(region) -plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true,typeof(region)}(region) +plan_dct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,false,typeof(region)}(size(x),region) +plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan{T,true,typeof(region)}(size(x),region) -plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region) -plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region) +plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(size(x),region) +plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(size(x),region) -plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false,typeof(region)}(length(x), region) +plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false,typeof(region)}(size(x), region) plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false,typeof(region)}(n, region) # A plan for irfft is created in terms of a plan for brfft.