Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU Support for Operators #9

Merged
merged 47 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
ea6bd62
Restructure FFTOP to work on GPU
nHackel Apr 9, 2024
9a2069c
Add GPU support to GradientOp
nHackel Apr 11, 2024
ae5fd8c
Added S kwarg to WaveletOp
nHackel Apr 11, 2024
51a4e0f
Add GPU support to ProdOp
nHackel Apr 11, 2024
52361c2
Add GPU support for SamplingOp
nHackel Apr 11, 2024
44dc932
Add GPU support for NormalOp
nHackel Apr 11, 2024
438f7ff
Lessen restriction on ProdOp, NormalOp
nHackel Apr 15, 2024
161d527
Add GPU support to NFFTOp
nHackel Apr 16, 2024
7ea715c
NormalOp use storage_type of parent
nHackel Apr 16, 2024
60f81a1
Add GPU support to NormalOpNFFTToeplitz
nHackel Apr 16, 2024
4ee1f31
Allow kwargs for normalOperator(...)
nHackel Apr 16, 2024
31095ce
Fix copy for NormalOp
nHackel Apr 18, 2024
62a8c8e
Add conversion to dense array to WaveletOp
nHackel Apr 18, 2024
7495aa6
Add ProdNormalOp (migration MRIReco
nHackel Apr 18, 2024
135b5c4
Fix include order for ProdOp
nHackel Apr 18, 2024
54d5fe5
Migrate DiagOp from MRIReco with GPU support
nHackel Apr 22, 2024
3e2b95d
Fix copy NFFTOp
nHackel Apr 23, 2024
e828d88
Fix DiagOp constructor and normalOp
nHackel Apr 23, 2024
8cb2bb4
Allow WeightingOp weights vec to be GPU arrays
nHackel Apr 23, 2024
f562a4a
Add operatore copy function as kwarg
nHackel Apr 24, 2024
c261108
Pass along normalOperator kwargs (for FFT flags)
nHackel Apr 24, 2024
b2489a3
Reduce allocation in DiagOp normalOperator
nHackel Apr 24, 2024
af02185
Fix tmp array construction for FFTOp on UnionAll storage vectors
nHackel Apr 25, 2024
af27ffc
Fix normalOp constructor to use eltype of storage_type
nHackel Apr 25, 2024
671bbcd
Improve WaveletOp between GPU and CPU
nHackel Apr 25, 2024
2b1331d
Fix missing ; in SamplingOp
nHackel Apr 25, 2024
6dfdd5b
Add RadonOp based on RadonKA
nHackel Apr 30, 2024
19d74d5
Fix eltype of SamplingOp
nHackel May 28, 2024
8811386
Init updating tests
nHackel Jun 3, 2024
7ae7ddb
Fix bugs in FFTOp and SamplingOp
nHackel Jun 4, 2024
82d41a3
Add tests and bugfixes for DiagOp
nHackel Jun 4, 2024
309a61a
Add RadonOp test
nHackel Jun 4, 2024
ef5bcd2
Add breakage workflow
nHackel Jun 5, 2024
46130c1
Fix branch name
nHackel Jun 5, 2024
ed3a934
Attempt to let breakage fail if tests fail
nHackel Jun 5, 2024
c278627
Improve @testset handling
nHackel Jun 5, 2024
2475c89
Test setup for CUDA
nHackel Jun 6, 2024
ee3b3fb
Use fill for res in GradOp
nHackel Jun 10, 2024
27188c6
Add LinearOperatorException for non-concrete S in Diag, Normal and Pr…
nHackel Jun 19, 2024
885b7c5
Improve GradOp performance on GPU for multiple dims
nHackel Jun 21, 2024
44994a5
Add CUDA buildkite
nHackel Jun 27, 2024
a97ed90
Try fixing buildkite
nHackel Jun 27, 2024
ef24731
Fix julia version in buildkite
nHackel Jun 27, 2024
e06506c
Add CuNFFT to CUDA buildkite
nHackel Jun 27, 2024
5e66c00
Add all extras to buildkite
nHackel Jun 27, 2024
1423c54
Use TestEnv to fix compat issues for buildkite
nHackel Jun 27, 2024
8e14f5d
Readd CuNFFT to buildkite
nHackel Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 23 additions & 35 deletions ext/LinearOperatorFFTWExt/FFTOp.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export FFTOpImpl

mutable struct FFTOpImpl{T} <: FFTOp{T}
mutable struct FFTOpImpl{T, vecT, P <: AbstractFFTs.Plan{T}, IP <: AbstractFFTs.Plan{T}} <: FFTOp{T}
nrow :: Int
ncol :: Int
symmetric :: Bool
Expand All @@ -14,10 +14,10 @@ mutable struct FFTOpImpl{T} <: FFTOp{T}
args5 :: Bool
use_prod5! :: Bool
allocated5 :: Bool
Mv5 :: Vector{T}
Mtu5 :: Vector{T}
plan
iplan
Mv5 :: vecT
Mtu5 :: vecT
plan :: P
iplan :: IP
shift::Bool
unitary::Bool
end
Expand All @@ -34,13 +34,14 @@ returns an operator which performs an FFT on Arrays of type T
* `shape::Tuple` - size of the array to transform
* (`shift=true`) - if true, fftshifts are performed
* (`unitary=true`) - if true, FFT is normalized such that it is unitary
* (`S = Vector{T}`) - type of temporary vector, change to use on GPU
* (`kwargs...`) - keyword arguments given to fft plan
"""
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where D
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
function LinearOperatorCollection.FFTOp(T::Type{<:Number}; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D

probably you want Numbers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point! The restrictions on T are a bit inconsistent across the package. I will do a pass with this change across all operators once I am done with the GPU changes, potentially I might do this in another PR


#tmpVec = cuda ? CuArray{T}(undef,shape) : Array{Complex{real(T)}}(undef, shape)
tmpVec = Array{Complex{real(T)}}(undef, shape)
plan = plan_fft!(tmpVec; flags=FFTW.MEASURE)
iplan = plan_bfft!(tmpVec; flags=FFTW.MEASURE)
tmpVec = S(undef, shape...)
plan = plan_fft!(tmpVec; kwargs...)
iplan = plan_bfft!(tmpVec; kwargs...)

if unitary
facF = T(1.0/sqrt(prod(shape)))
Expand All @@ -50,39 +51,26 @@ function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::
facB = T(1.0)
end

let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
let shape_ = shape, plan_ = plan, iplan_ = iplan, tmpVec_ = tmpVec, facF_ = facF, facB_ = facB

if shift
return FFTOpImpl{T}(prod(shape), prod(shape), false, false
, (res, x) -> fft_multiply_shift!(res, plan_, x, shape_, facF_, tmpVec_)
, nothing
, (res, x) -> fft_multiply_shift!(res, iplan_, x, shape_, facB_, tmpVec_)
, 0, 0, 0, true, false, true, T[], T[]
, plan
, iplan
, shift
, unitary)
else
return FFTOpImpl{T}(prod(shape), prod(shape), false, false
, (res, x) -> fft_multiply!(res, plan_, x, facF_, tmpVec_)
, nothing
, (res, x) -> fft_multiply!(res, iplan_, x, facB_, tmpVec_)
, 0, 0, 0, true, false, true, T[], T[]
, plan
, iplan
, shift
, unitary)
end
fun! = fft_multiply!
if shift
fun! = fft_multiply_shift!
end

return FFTOpImpl(prod(shape), prod(shape), false, false, (res, x) -> fun!(res, plan_, x, shape_, facF_, tmpVec_),
nothing, (res, x) -> fun!(res, iplan_, x, shape_, facB_, tmpVec_),
0, 0, 0, true, false, true, similar(tmpVec, 0), similar(tmpVec, 0), plan, iplan, shift, unitary)
end
end

function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, factor::T, tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
tmpVec[:] .= x
plan * tmpVec
res .= factor .* vec(tmpVec)
end

function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, shape::NTuple{D}, factor::T, tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, shape::NTuple{D}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
ifftshift!(tmpVec, reshape(x,shape))
plan * tmpVec
fftshift!(reshape(res,shape), tmpVec)
Expand All @@ -91,5 +79,5 @@ end


function Base.copy(S::FFTOpImpl)
return FFTOp(eltype(S); shape=size(S.plan), shift=S.shift, unitary=S.unitary)
return FFTOp(eltype(S); shape=size(S.plan), shift=S.shift, unitary=S.unitary, S = LinearOperators.storage_type(S)) # TODO loses kwargs...
end
2 changes: 1 addition & 1 deletion ext/LinearOperatorFFTWExt/LinearOperatorFFTWExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module LinearOperatorFFTWExt

using LinearOperatorCollection, FFTW
using LinearOperatorCollection, FFTW, FFTW.AbstractFFTs

include("FFTOp.jl")
include("DCTOp.jl")
Expand Down
Loading