diff --git a/docs/src/api.md b/docs/src/api.md index 2b87105..c9b7b98 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,4 +21,6 @@ AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftshift AbstractFFTs.ifftshift +AbstractFFTs.fftfreq +AbstractFFTs.rfftfreq ``` diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index b4a29e8..d31cde3 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -3,7 +3,7 @@ module AbstractFFTs export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, - fftshift, ifftshift + fftshift, ifftshift, Frequencies, fftfreq, rfftfreq include("definitions.jl") diff --git a/src/definitions.jl b/src/definitions.jl index 07f2d79..2f10d0c 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -3,7 +3,7 @@ using LinearAlgebra using LinearAlgebra: BlasReal import Base: show, summary, size, ndims, length, eltype, - *, inv, \ + *, inv, \, size, step, getindex, iterate # DFT plan where the inputs are an array of eltype T abstract type Plan{T} end @@ -396,6 +396,54 @@ function ifftshift(x,dim) circshift(x, s) end +############################################################################## + + +struct Frequencies{T<:Number} <: AbstractVector{T} + n_nonnegative::Int + n::Int + multiplier::T + + Frequencies(n_nonnegative::Int, n::Int, multiplier::T) where {T<:Number} = begin + 1 ≤ n_nonnegative ≤ n || throw(ArgumentError("Condition 1 ≤ n_nonnegative ≤ n isn't satisfied.")) + return new{T}(n_nonnegative, n, multiplier) + end +end + +unsafe_getindex(x::Frequencies, i::Int) = + (i-1+ifelse(i <= x.n_nonnegative, 0, -x.n))*x.multiplier +@inline function Base.getindex(x::Frequencies, i::Int) + @boundscheck Base.checkbounds(x, i) + unsafe_getindex(x, i) +end + +function Base.iterate(x::Frequencies, i::Int=1) + i > x.n ? nothing : (unsafe_getindex(x,i), i + 1) +end +Base.size(x::Frequencies) = (x.n,) +Base.step(x::Frequencies) = x.multiplier + +""" + fftfreq(n, fs=1) +Return the discrete Fourier transform (DFT) sample frequencies for a DFT of length `n`. The returned +`Frequencies` object is an `AbstractVector` containing the frequency +bin centers at every sample point. `fs` is the sample rate of the +input signal. +""" +fftfreq(n::Int, fs::Number=1) = Frequencies((n+1) >> 1, n, fs/n) + +""" + rfftfreq(n, fs=1) +Return the discrete Fourier transform (DFT) sample frequencies for a real DFT of length `n`. +The returned `Frequencies` object is an `AbstractVector` +containing the frequency bin centers at every sample point. `fs` +is the sample rate of the input signal. +""" +rfftfreq(n::Int, fs::Number=1) = Frequencies((n >> 1)+1, (n >> 1)+1, fs/n) + +fftshift(x::Frequencies) = (x.n_nonnegative-x.n:x.n_nonnegative-1)*x.multiplier + + ############################################################################## """ diff --git a/test/runtests.jl b/test/runtests.jl index 5acd839..020aa68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,6 +81,25 @@ end @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2) == [5 6 4; 2 3 1] end +@testset "FFT Frequencies" begin + # N even + @test fftfreq(8) == [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125] + @test rfftfreq(8) == [0.0, 0.125, 0.25, 0.375, 0.5] + @test fftshift(fftfreq(8)) == -0.5:0.125:0.375 + + # N odd + @test fftfreq(5) == [0.0, 0.2, 0.4, -0.4, -0.2] + @test rfftfreq(5) == [0.0, 0.2, 0.4] + @test fftshift(fftfreq(5)) == -0.4:0.2:0.4 + + # Sampling Frequency + @test fftfreq(5, 2) == [0.0, 0.4, 0.8, -0.8, -0.4] + # <:Number type compatibility + @test eltype(fftfreq(5, ComplexF64(2))) == ComplexF64 + + @test_throws ArgumentError Frequencies(12, 10, 1) +end + @testset "normalization" begin # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for