Skip to content

Commit

Permalink
further adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed May 8, 2024
1 parent 24b2cf8 commit 48f3784
Showing 1 changed file with 11 additions and 25 deletions.
36 changes: 11 additions & 25 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,19 @@ generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::Abstra
spdensemul!(C, tA, tB, A, B, alpha, beta)
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, alpha::Number, beta::Number) =
spdensemul!(C, tA, 'N', A, B, alpha, beta)
# legacy methods: TODO: remove
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add.alpha, _add.beta)
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add.alpha, _add.beta)
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add.alpha, _add.beta)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
if tA == 'N'
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
if tA_uc == 'N'
_spmatmul!(C, A, wrap(B, tB), alpha, beta)
elseif tA == 'T'
elseif tA_uc == 'T'
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), alpha, beta)
elseif tA == 'C'
elseif tA_uc == 'C'
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), alpha, beta)
elseif tA in ('S', 's', 'H', 'h') && tB == 'N'
elseif tA_uc in ('S', 'H') && tB_uc == 'N'
rangefun = isuppercase(tA) ? nzrangeup : nzrangelo
diagop = tA in ('S', 's') ? identity : real
odiagop = tA in ('S', 's') ? transpose : adjoint
diagop = tA_uc == 'S' ? identity : real
odiagop = tA_uc == 'S' ? transpose : adjoint
T = eltype(C)
_mul!(rangefun, diagop, odiagop, C, A, B, T(alpha), T(beta))
else
Expand Down Expand Up @@ -123,9 +117,6 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
C
end

# TODO:remove
generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) =
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
Expand Down Expand Up @@ -328,17 +319,12 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer,
p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k
end

# TODO: remove this one method
Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2, B::SparseMatrixCSCUnion2, _add::MulAddMul)
A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end
Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2,
B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
@stable_muladdmul _generic_matmatmul!(C, tA, tB, A, B, MulAddMul(alpha, beta))
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
Bnew, tb = tB_uc in ('S', 'H') ? (wrap(B, tB), oftype(tB, 'N')) : (B, tB)
@stable_muladdmul _generic_matmatmul!(C, ta, tb, Anew, Bnew, MulAddMul(alpha, beta))
end
function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat,
B::AbstractVecOrMat, _add::MulAddMul)
Expand Down

0 comments on commit 48f3784

Please sign in to comment.