Skip to content

Commit

Permalink
bugfix for dot of Hermitian{noncommutative} (#52333)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
(cherry picked from commit 53f1eb8)
  • Loading branch information
stevengj authored and KristofferC committed Dec 12, 2023
1 parent ea88b8c commit e94785f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ function triu(A::Symmetric, k::Integer=0)
end
end

for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:Hermitian, :adjoint, :real)]
for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Union{Real,Complex}}), :adjoint, :real)]
@eval begin
function dot(A::$T, B::$T)
n = size(A, 2)
Expand Down
16 changes: 16 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ module TestSymmetric

using Test, LinearAlgebra, Random

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
using .Main.Quaternions

Random.seed!(1010)

@testset "Pauli σ-matrices: " for σ in map(Hermitian,
Expand Down Expand Up @@ -462,6 +467,17 @@ end
end
end

# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
@testset "dot Hermitian quaternion #52318" begin
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
@test A == Hermitian(A) && B == Hermitian(B)
@test dot(A, B) dot(Hermitian(A), Hermitian(B))
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
@test A == Symmetric(A) && B == Symmetric(B)
@test dot(A, B) dot(Symmetric(A), Symmetric(B))
end

#Issue #7647: test xsyevr, xheevr, xstevr drivers.
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
(Symmetric(diagm(0 => 1.0:3.0)),
Expand Down
5 changes: 4 additions & 1 deletion test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3))
Base.abs(q::Quaternion) = sqrt(abs2(q))
Base.real(::Type{Quaternion{T}}) where {T} = T
Base.real(q::Quaternion) = q.s
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))
Expand All @@ -33,7 +34,9 @@ Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s)
Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r)
Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity
Base.:(*)(q::Quaternion, r::Bool) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) # remove method ambiguity
Base.:(*)(r::Real, q::Quaternion) = q * r
Base.:(*)(r::Bool, q::Quaternion) = q * r # remove method ambiguity
Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w))
Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))

Expand Down

0 comments on commit e94785f

Please sign in to comment.