GPU Support for Operators #9
Conversation
""" | ||
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where D | ||
function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D |
Suggested change:

```diff
-function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
+function LinearOperatorCollection.FFTOp(T::Type{<:Number}; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
```
Probably you want `Number`s here?
Yes, good point! The restrictions on T are a bit inconsistent across the package. I will do a pass with this change across all operators once I am done with the GPU changes; potentially I might do this in another PR.
This PR adds GPU support for (some of) the operators. I will keep an updated list below of which operators were ported and how they were changed.
Operators:
Changes:
LinearOperators has a keyword argument `S` which describes the "storage_type" of the operator. This has to be adapted from the default `Vector{T}` to a `CuArray` or another GPU array type to make an operator work on the GPU. It is usually just the `typeof(...)` of the vector to which the operator will be applied. I've started adding this keyword argument to all operators/constructors defined in this package to stay consistent, though not all operators will be able to work on the GPU. The `S` kwarg does not seem to be applicable to operators that are compositions of existing ones, such as the `ProdOp` and `NormalOp`.
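As a minimal sketch of how `S` would be passed in practice, assuming CUDA.jl is available (the shape and element type are just examples; the `FFTOp` signature is the one shown in the diff above):

```julia
using CUDA, LinearOperatorCollection

shape = (64, 64)
x = CUDA.rand(ComplexF32, prod(shape))   # vector the operator will be applied to

# Default: CPU operator, S falls back to Array{Complex{real(T)}}
op_cpu = FFTOp(ComplexF32; shape = shape)

# GPU operator: S is simply the storage type of x; any additional kwargs
# would be forwarded to the underlying FFT plan call
op_gpu = FFTOp(ComplexF32; shape = shape, S = typeof(x))

y = op_gpu * x                           # result stays on the GPU
```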
- `FFTOp`: Did not require any specific CUDA dependency so far. I could just refactor the struct a little bit to allow for the `S` kwarg and keep track of this information. One issue is that the CUDA FFT does not allow FFTW flags to be set (the operator used `FFTW.MEASURE` previously). My workaround for now is to give the constructor `kwargs...`, which it passes on to the plan call. Then the caller can decide if a flag should be used or not. Alternatively, we dispatch on `S` and, depending on that, load a different plan/different plan arguments.
- `WaveletOp`: Wavelets.jl does not seem to work on a GPU. The operator now carries two dense arrays to which it assigns any "non-dense" arguments and then does the transformation, i.e. it computes on the CPU and then turns the result back into a GPU array (a rough sketch of this round trip follows after this list).
- `GradientOp`: I've added an extension for GPUArrays.jl and added new dispatches on GPUArrays for the `grad!` methods of this operator (see the extension sketch after this list).
- `ProdOp`: Did not receive an `S` kwarg; instead the operator was restricted to work on operators that implement `storage_type` (see the sketch after this list). At the moment, `LinearOperator(gpuArray)` does not derive a correct storage type. I have an open PR that should change that (atm only for CUDA).
- `SamplingOp`: Added an `S` kwarg, however this requires a change downstream in LinearOperators.jl, since the `opRestriction` this operator is (partially) based on works on its own with GPUs, but cannot be combined with other operators to work on the GPU. I've added a PR to LinearOperators.jl. I've also removed a superfluous (I think) `opEye`.
- `WeightingOp`: Works out of the box; did not (yet) add a `WeightingOp(...; kwargs...)` to "accept"/ignore a potential `S` kwarg.
- `NormalOp`: Similar to `ProdOp`. I have also slightly rearranged the call order of the constructors. I think this is an operator that is being inspected by MRIReco.jl, so I might need to revisit the API again when I adapt the interface there. If the limitation to LinearOperators is an issue, we could also overload the `storage_type` call on the matrices themselves; I've added this as an option to my PR in LinearOperators.jl. I've also reused the `WeightingOp` here; one could go a step further and collapse the whole `NormalOp` into a `ProdOp`.
- `NFFTOp`: Similar to `FFTOp`, though the GPU seems to accept the same kwargs as the CPU version, so I did not change it.
- `NFFTToeplitzNormalOp`: Similar to `FFTOp`; this time I had to remove the FFTW flags.
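For the `WaveletOp` round trip mentioned above, here is a minimal sketch of the idea, assuming Wavelets.jl; the function name `cpu_roundtrip_dwt!` and the buffer handling are illustrative, not the actual implementation in this PR:

```julia
using Wavelets

# Illustrative only: copy a (possibly GPU-resident) vector into a dense CPU buffer,
# run the Wavelets.jl transform there, and copy the coefficients back into `res`,
# which may again live on the GPU.
function cpu_roundtrip_dwt!(res::AbstractVector, buf::Vector, wt, x::AbstractVector)
  copyto!(buf, x)      # GPU -> CPU (plain copy if x is already a dense CPU array)
  y = dwt(buf, wt)     # Wavelets.jl operates on dense CPU arrays
  copyto!(res, y)      # CPU -> GPU
  return res
end
```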
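For the `GradientOp` item, a rough sketch of what a GPUArrays.jl extension with a dedicated dispatch could look like; the module name, the `grad!` signature, and the forward-difference body are assumptions for illustration and will differ from the actual code in this PR:

```julia
# Hypothetical extension module (e.g. ext/LinearOperatorCollectionGPUArraysExt.jl)
module LinearOperatorCollectionGPUArraysExt

using LinearOperatorCollection, GPUArrays

# Assumed signature: forward difference along dimension `dim`, written as a single
# broadcast over views so it runs as one GPU kernel instead of a scalar CPU loop.
function LinearOperatorCollection.grad!(res::AbstractGPUArray{T}, img::AbstractGPUArray{T},
                                        shape::NTuple{N,Int}, dim::Int) where {T,N}
  img_r = reshape(img, shape)
  res_r = reshape(res, ntuple(i -> i == dim ? shape[i] - 1 : shape[i], N))
  front = ntuple(i -> i == dim ? (1:shape[i]-1) : Colon(), N)
  back  = ntuple(i -> i == dim ? (2:shape[i])   : Colon(), N)
  res_r .= view(img_r, front...) .- view(img_r, back...)
  return res
end

end # module
```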
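Finally, for the `ProdOp`/`NormalOp` restriction to operators that implement `storage_type`, a sketch of how a composition could derive a common storage type from its children; `combined_storage` is a hypothetical helper, not part of this PR, while `LinearOperators.storage_type` is the existing LinearOperators.jl query:

```julia
using LinearOperators

# Illustrative helper: derive one storage type for a composition such as A * B,
# assuming both operators implement LinearOperators.storage_type.
function combined_storage(A, B)
  SA = LinearOperators.storage_type(A)
  SB = LinearOperators.storage_type(B)
  S  = promote_type(SA, SB)
  isconcretetype(S) ||
    throw(ArgumentError("storage types $SA and $SB of the composed operators are incompatible"))
  return S
end
```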