From 88f0cbc833f37e97cb67069e4bc0b0298ace6f78 Mon Sep 17 00:00:00 2001 From: nHackel Date: Fri, 19 Jul 2024 12:05:05 +0000 Subject: [PATCH] Workaround for promote_type for GPU arrays --- src/LinearOperatorCollection.jl | 31 +++++++++++++++++++++++++++++++ src/NormalOp.jl | 3 +-- src/ProdOp.jl | 3 +-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/LinearOperatorCollection.jl b/src/LinearOperatorCollection.jl index 370d4d8..a4e2ada 100644 --- a/src/LinearOperatorCollection.jl +++ b/src/LinearOperatorCollection.jl @@ -108,4 +108,35 @@ include("SamplingOp.jl") include("NormalOp.jl") include("DiagOp.jl") +function promote_storage_types(A, B) + A_type = storage_type(A) + B_type = storage_type(B) + S = promote_type(A_type, B_type) + if !isconcretetype(S) + # Find common eltype + elType = promote_type(eltype(A), eltype(B)) + if !isconcretetype(elType) + throw(LinearOperatorException("Storage types cannot be promoted to a concrete type")) + end + + # Same base type + A_base = Base.typename(A_type).wrapper + B_base = Base.typename(B_type).wrapper + if A_base != B_base + throw(LinearOperatorException("Storage types cannot be promoted to a common base type")) + end + + # LinearOperators only accepts DataTypes, so we cant just do A_base{elType}, since that might be a UnionAll + # Check if either A_type or B_type have the fitting eltype + if eltype(A_type) == elType + S = A_type + elseif eltype(B_type) == elType + S = B_type + else + throw(LinearOperatorException("Storage types cannot be promoted to a common eltype")) + end + end + return S +end + end diff --git a/src/NormalOp.jl b/src/NormalOp.jl index 9b09fc7..14655fe 100644 --- a/src/NormalOp.jl +++ b/src/NormalOp.jl @@ -52,8 +52,7 @@ end LinearOperators.storage_type(op::NormalOpImpl) = typeof(op.Mv5) function NormalOpImpl(parent, weights) - S = promote_type(storage_type(parent), storage_type(weights)) - isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type")) + S = promote_storage_types(parent, weights) tmp = S(undef, size(parent, 1)) return NormalOpImpl(parent, weights, tmp) end diff --git a/src/ProdOp.jl b/src/ProdOp.jl index 24e7866..84c48a1 100644 --- a/src/ProdOp.jl +++ b/src/ProdOp.jl @@ -36,8 +36,7 @@ composition/product of two Operators. Differs with * since it can handle normal function ProdOp(A, B) nrow = size(A, 1) ncol = size(B, 2) - S = promote_type(LinearOperators.storage_type(A), LinearOperators.storage_type(B)) - isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type")) + S = promote_storage_types(A, B) tmp_ = S(undef, size(B, 1)) function produ!(res, x::AbstractVector{T}, tmp) where T<:Union{Real,Complex}