From 5bb976646b2f133d4854d9674519e9558c51ee49 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 01/16] Add rules and tests --- src/rulesets/LinearAlgebra/dense.jl | 48 ++++++++++++++++++++++++++++ test/rulesets/LinearAlgebra/dense.jl | 8 +++++ 2 files changed, 56 insertions(+) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index a5edd6cd5..810c78b5a 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -394,3 +394,51 @@ function rrule( end return Ω, lyap_pullback end + +##### +##### `kron` +##### + +function frule((_, Δx, Δy), ::typeof(kron), x, y) + return kron(x, y), kron(Δx, y) + kron(x, Δy) +end + +function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector) + z = kron(x, y) + + function kron_pullback(z̄) + x̄ = zero(x) + ȳ = zero(y) + m = firstindex(z̄) + @inbounds for j in axes(x,2), i in axes(x,1) + xij = x[i,j] + for k in eachindex(y) + x̄[i, j] += y[k]' * z̄[m] + ȳ[k] += xij * z̄[m] + m += 1 + end + end + NoTangent(), x̄, ȳ + end + z, kron_pullback +end + +function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix) + z = kron(x, y) + + function kron_pullback(z̄) + x̄ = zero(x) + ȳ = zero(y) + m = firstindex(z̄) + @inbounds for l in axes(y,2), i in eachindex(x) + xi = x[i] + for k in axes(y,1) + x̄[i] += y[k, l]' * z̄[m] + ȳ[k, l] += xi * z̄[m] + m += 1 + end + end + NoTangent(), x̄, ȳ + end + z, kron_pullback +end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 5f5efa8d2..8dcbd6edc 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -159,4 +159,12 @@ test_rrule(lyap, A, C) end end + @testset "kron" begin + @testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64) + test_frule(kron, randn(T, 3), randn(T, 3)) + test_frule(kron, randn(T, 3, 2), randn(T, 3)) + test_frule(kron, randn(T, 3), randn(T, 3, 4)) + test_frule(kron, randn(T, 3, 4), randn(T, 2, 2)) + end + end end From f0902e3342914dc3d480225e2a8c1546a1504f2f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 02/16] Add tests for `rrule` --- test/rulesets/LinearAlgebra/dense.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 8dcbd6edc..709780cb2 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -161,10 +161,18 @@ end @testset "kron" begin @testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64) - test_frule(kron, randn(T, 3), randn(T, 3)) - test_frule(kron, randn(T, 3, 2), randn(T, 3)) - test_frule(kron, randn(T, 3), randn(T, 3, 4)) - test_frule(kron, randn(T, 3, 4), randn(T, 2, 2)) + @testset "frule" begin + test_frule(kron, randn(T, 3), randn(T, 3)) + test_frule(kron, randn(T, 3, 2), randn(T, 3)) + test_frule(kron, randn(T, 3), randn(T, 3, 4)) + test_frule(kron, randn(T, 3, 4), randn(T, 2, 2)) + end + @testset "rrule" begin + test_rrule(kron, randn(T, 3), randn(T, 3)) + test_rrule(kron, randn(T, 3, 2), randn(T, 3)) + test_rrule(kron, randn(T, 3), randn(T, 3, 4)) + test_rrule(kron, randn(T, 3, 4), randn(T, 2, 2)) + end end end end From 7c53f4b52bf1a5ad16f4281a2fa5e7f98cca717f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 03/16] Add rules and try to cover complex case --- src/rulesets/LinearAlgebra/dense.jl | 48 ++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 810c78b5a..003054cca 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -403,9 +403,25 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y) return kron(x, y), kron(Δx, y) + kron(x, Δy) end -function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector) - z = kron(x, y) +function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector) + function kron_pullback(z̄) + x̄ = zero(x) + ȳ = zero(y) + m = firstindex(z̄) + @inbounds for i in eachindex(x) + xi = x[i] + for k in eachindex(y) + x̄[i] += y[k]' * z̄[m] + ȳ[k] += xi' * z̄[m] + m += 1 + end + end + NoTangent(), x̄, ȳ + end + kron(x, y), kron_pullback +end +function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector) function kron_pullback(z̄) x̄ = zero(x) ȳ = zero(y) @@ -414,18 +430,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector) xij = x[i,j] for k in eachindex(y) x̄[i, j] += y[k]' * z̄[m] - ȳ[k] += xij * z̄[m] + ȳ[k] += xij' * z̄[m] m += 1 end end NoTangent(), x̄, ȳ end - z, kron_pullback + kron(x, y), kron_pullback end function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix) - z = kron(x, y) - function kron_pullback(z̄) x̄ = zero(x) ȳ = zero(y) @@ -434,11 +448,29 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix) xi = x[i] for k in axes(y,1) x̄[i] += y[k, l]' * z̄[m] - ȳ[k, l] += xi * z̄[m] + ȳ[k, l] += xi' * z̄[m] + m += 1 + end + end + NoTangent(), x̄, ȳ + end + kron(x, y), kron_pullback +end + +function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix) + function kron_pullback(z̄) + x̄ = zero(x) + ȳ = zero(y) + m = firstindex(z̄) + @inbounds for l in axes(y,2), j in axes(x,2), i in axes(x,1) + xij = x[i, j] + for k in axes(y,1) + x̄[i, j] += y[k, l]' * z̄[m] + ȳ[k, l] += xij' * z̄[m] m += 1 end end NoTangent(), x̄, ȳ end - z, kron_pullback + kron(x, y), kron_pullback end From c1226eb70ceef9f26102fe9d9ef4484b8f7bc238 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 04/16] Restrict types of arguments Co-authored-by: David Widmann --- src/rulesets/LinearAlgebra/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 003054cca..a83a6060e 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -399,7 +399,7 @@ end ##### `kron` ##### -function frule((_, Δx, Δy), ::typeof(kron), x, y) +function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number}) return kron(x, y), kron(Δx, y) + kron(x, Δy) end From 236daf1909259cf57a587bfd16169f5c8ebecd68 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 05/16] Write rules functionally and fix them --- src/rulesets/LinearAlgebra/dense.jl | 78 ++++++++--------------------- 1 file changed, 22 insertions(+), 56 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index a83a6060e..17f11a5f9 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -403,74 +403,40 @@ function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y:: return kron(x, y), kron(Δx, y) + kron(x, Δy) end -function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector) +function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) - x̄ = zero(x) - ȳ = zero(y) - m = firstindex(z̄) - @inbounds for i in eachindex(x) - xi = x[i] - for k in eachindex(y) - x̄[i] += y[k]' * z̄[m] - ȳ[k] += xi' * z̄[m] - m += 1 - end - end - NoTangent(), x̄, ȳ + dz = reshape(z̄, length(y), length(x)) + return NoTangent(), conj.(dz' * y), dz * conj.(x) end - kron(x, y), kron_pullback + return kron(x, y), kron_pullback end -function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector) +function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) - x̄ = zero(x) - ȳ = zero(y) - m = firstindex(z̄) - @inbounds for j in axes(x,2), i in axes(x,1) - xij = x[i,j] - for k in eachindex(y) - x̄[i, j] += y[k]' * z̄[m] - ȳ[k] += xij' * z̄[m] - m += 1 - end - end - NoTangent(), x̄, ȳ + dz = reshape(z̄, length(y), size(x)...) + x̄ = Ref(y') .* eachslice(dz; dims = (2, 3)) + ȳ = conj.(dot.(eachslice(dz; dims = 1), Ref(x))) + return NoTangent(), x̄, ȳ end - kron(x, y), kron_pullback + return kron(x, y), kron_pullback end -function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix) +function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}) function kron_pullback(z̄) - x̄ = zero(x) - ȳ = zero(y) - m = firstindex(z̄) - @inbounds for l in axes(y,2), i in eachindex(x) - xi = x[i] - for k in axes(y,1) - x̄[i] += y[k, l]' * z̄[m] - ȳ[k, l] += xi' * z̄[m] - m += 1 - end - end - NoTangent(), x̄, ȳ + dz = reshape(z̄, size(y, 1), length(x), size(y, 2)) + x̄ = conj.(dot.(eachslice(dz; dims = 2), Ref(y))) + ȳ = Ref(x') .* eachslice(dz; dims = (1, 3)) + return NoTangent(), x̄, ȳ end - kron(x, y), kron_pullback + return kron(x, y), kron_pullback end -function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix) +function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) function kron_pullback(z̄) - x̄ = zero(x) - ȳ = zero(y) - m = firstindex(z̄) - @inbounds for l in axes(y,2), j in axes(x,2), i in axes(x,1) - xij = x[i, j] - for k in axes(y,1) - x̄[i, j] += y[k, l]' * z̄[m] - ȳ[k, l] += xij' * z̄[m] - m += 1 - end - end - NoTangent(), x̄, ȳ + dz = reshape(z̄, size(y, 1), size(x, 1), size(y, 2), size(x, 2)) + x̄ = conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) + ȳ = dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x))) + return NoTangent(), x̄, ȳ end - kron(x, y), kron_pullback + return kron(x, y), kron_pullback end From b2d4f4a2dfd47f8c4e12e49087ec964c5226e71b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 06/16] Add `unthunk` and `@thunk` --- src/rulesets/LinearAlgebra/dense.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 17f11a5f9..463aa0975 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -405,17 +405,19 @@ end function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) - dz = reshape(z̄, length(y), length(x)) - return NoTangent(), conj.(dz' * y), dz * conj.(x) + dz = reshape(unthunk(z̄), length(y), length(x)) + x̄ = @thunk conj.(dz' * y) + ȳ = @thunk dz * conj.(x) + return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) - dz = reshape(z̄, length(y), size(x)...) - x̄ = Ref(y') .* eachslice(dz; dims = (2, 3)) - ȳ = conj.(dot.(eachslice(dz; dims = 1), Ref(x))) + dz = reshape(unthunk(z̄), length(y), size(x)...) + x̄ = @thunk Ref(y') .* eachslice(dz; dims = (2, 3)) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback @@ -423,9 +425,9 @@ end function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}) function kron_pullback(z̄) - dz = reshape(z̄, size(y, 1), length(x), size(y, 2)) - x̄ = conj.(dot.(eachslice(dz; dims = 2), Ref(y))) - ȳ = Ref(x') .* eachslice(dz; dims = (1, 3)) + dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) + x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) + ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3)) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback @@ -433,9 +435,9 @@ end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) function kron_pullback(z̄) - dz = reshape(z̄, size(y, 1), size(x, 1), size(y, 2), size(x, 2)) - x̄ = conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) - ȳ = dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x))) + dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) + x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) + ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback From 8b94cfc70c1837f81d31e741d01aa460d004dc7e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 07/16] Change dimensions to make them recognizable --- test/rulesets/LinearAlgebra/dense.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 709780cb2..7dd7dfdf4 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -162,16 +162,16 @@ @testset "kron" begin @testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64) @testset "frule" begin - test_frule(kron, randn(T, 3), randn(T, 3)) - test_frule(kron, randn(T, 3, 2), randn(T, 3)) - test_frule(kron, randn(T, 3), randn(T, 3, 4)) - test_frule(kron, randn(T, 3, 4), randn(T, 2, 2)) + test_frule(kron, randn(T, 2), randn(T, 3)) + test_frule(kron, randn(T, 2, 3), randn(T, 5)) + test_frule(kron, randn(T, 2), randn(T, 3, 5)) + test_frule(kron, randn(T, 2, 3), randn(T, 5, 7)) end @testset "rrule" begin - test_rrule(kron, randn(T, 3), randn(T, 3)) - test_rrule(kron, randn(T, 3, 2), randn(T, 3)) - test_rrule(kron, randn(T, 3), randn(T, 3, 4)) - test_rrule(kron, randn(T, 3, 4), randn(T, 2, 2)) + test_rrule(kron, randn(T, 2), randn(T, 3)) + test_rrule(kron, randn(T, 2, 3), randn(T, 5)) + test_rrule(kron, randn(T, 2), randn(T, 3, 5)) + test_rrule(kron, randn(T, 2, 3), randn(T, 5, 7)) end end end From b71b8efc2660f44ae9fd96b69da4e7b490eebce8 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 08/16] Further simplify rules --- src/rulesets/LinearAlgebra/dense.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 463aa0975..e7a9b9414 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -416,7 +416,7 @@ end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) function kron_pullback(z̄) dz = reshape(unthunk(z̄), length(y), size(x)...) - x̄ = @thunk Ref(y') .* eachslice(dz; dims = (2, 3)) + x̄ = @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y))) ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) return NoTangent(), x̄, ȳ end @@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<: function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) - ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3)) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback @@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<: function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) - ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x))) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback From 2ad5473ba30ed35339dd8f269f6f8091d293a85c Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 09/16] Only define rules for Julia 1.9 onwards --- src/rulesets/LinearAlgebra/dense.jl | 70 +++++++++++++++-------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index e7a9b9414..fbd6a6e5a 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -399,46 +399,48 @@ end ##### `kron` ##### -function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number}) - return kron(x, y), kron(Δx, y) + kron(x, Δy) -end +@static if VERSION ≥ v"1.9.0" + function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number}) + return kron(x, y), kron(Δx, y) + kron(x, Δy) + end -function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}) - function kron_pullback(z̄) - dz = reshape(unthunk(z̄), length(y), length(x)) - x̄ = @thunk conj.(dz' * y) - ȳ = @thunk dz * conj.(x) - return NoTangent(), x̄, ȳ + function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}) + function kron_pullback(z̄) + dz = reshape(unthunk(z̄), length(y), length(x)) + x̄ = @thunk conj.(dz' * y) + ȳ = @thunk dz * conj.(x) + return NoTangent(), x̄, ȳ + end + return kron(x, y), kron_pullback end - return kron(x, y), kron_pullback -end -function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) - function kron_pullback(z̄) - dz = reshape(unthunk(z̄), length(y), size(x)...) - x̄ = @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) - return NoTangent(), x̄, ȳ + function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) + function kron_pullback(z̄) + dz = reshape(unthunk(z̄), length(y), size(x)...) + x̄ = @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y))) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) + return NoTangent(), x̄, ȳ + end + return kron(x, y), kron_pullback end - return kron(x, y), kron_pullback -end -function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}) - function kron_pullback(z̄) - dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) - x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) - return NoTangent(), x̄, ȳ + function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}) + function kron_pullback(z̄) + dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) + x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) + return NoTangent(), x̄, ȳ + end + return kron(x, y), kron_pullback end - return kron(x, y), kron_pullback -end -function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) - function kron_pullback(z̄) - dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) - x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) - return NoTangent(), x̄, ȳ + function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) + function kron_pullback(z̄) + dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) + x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) + ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) + return NoTangent(), x̄, ȳ + end + return kron(x, y), kron_pullback end - return kron(x, y), kron_pullback end From fde509eb1781d112cdc87cacc8d2c3f776028de4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 10/16] Add projections --- src/rulesets/LinearAlgebra/dense.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index fbd6a6e5a..c8c758a2f 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -405,10 +405,12 @@ end end function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function kron_pullback(z̄) dz = reshape(unthunk(z̄), length(y), length(x)) - x̄ = @thunk conj.(dz' * y) - ȳ = @thunk dz * conj.(x) + x̄ = @thunk(project_x(conj.(dz' * y))) + ȳ = @thunk(project_y(dz * conj.(x))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback From 43861438c22576e3f852cbf85db0c78301d1f36f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 11/16] Add projections and remove redundant `conj` calls --- src/rulesets/LinearAlgebra/dense.jl | 18 ++++++++++++------ test/rulesets/LinearAlgebra/dense.jl | 24 +++++++++++++++--------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index c8c758a2f..3f5e343a7 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -417,30 +417,36 @@ end end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function kron_pullback(z̄) dz = reshape(unthunk(z̄), length(y), size(x)...) - x̄ = @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x))) + x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 3))))) + ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = 1)))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback end function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2)) - x̄ = @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) + x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = 2)))) + ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3))))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback end function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) + project_x = ProjectTo(x) + project_y = ProjectTo(y) function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) - x̄ = @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y))) - ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x))) + x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 4))))) + ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3))))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 7dd7dfdf4..942599139 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -160,18 +160,24 @@ end end @testset "kron" begin - @testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64) + @testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64) @testset "frule" begin - test_frule(kron, randn(T, 2), randn(T, 3)) - test_frule(kron, randn(T, 2, 3), randn(T, 5)) - test_frule(kron, randn(T, 2), randn(T, 3, 5)) - test_frule(kron, randn(T, 2, 3), randn(T, 5, 7)) + test_frule(kron, randn(T1, 2), randn(T2, 3)) + test_frule(kron, randn(T1, 2, 3), randn(T2, 5)) + test_frule(kron, randn(T1, 2), randn(T2, 3, 5)) + test_frule(kron, randn(T1, 2, 3), randn(T2, 5, 7)) end @testset "rrule" begin - test_rrule(kron, randn(T, 2), randn(T, 3)) - test_rrule(kron, randn(T, 2, 3), randn(T, 5)) - test_rrule(kron, randn(T, 2), randn(T, 3, 5)) - test_rrule(kron, randn(T, 2, 3), randn(T, 5, 7)) + test_rrule(kron, randn(T1, 2), randn(T2, 3)) + + test_rrule(kron, Diagonal(randn(T1, 2)), randn(T2, 3)) + test_rrule(kron, randn(T1, 2, 3), randn(T2, 5)) + + test_rrule(kron, randn(T1, 2), randn(T2, 3, 5)) + test_rrule(kron, randn(T1, 2), Diagonal(randn(T2, 3))) + + test_rrule(kron, randn(T1, 2, 3), randn(T2, 5, 7)) + test_rrule(kron, Diagonal(randn(T1, 2)), Diagonal(randn(T2, 3))) end end end From 1b97828d7d43c097c72f8e9fdf5f8c49d415e613 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 12/16] Fix type instability --- src/rulesets/LinearAlgebra/dense.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 3f5e343a7..896fcf13e 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -445,10 +445,12 @@ end project_y = ProjectTo(y) function kron_pullback(z̄) dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) - x̄ = @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 4))))) - ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3))))) + x̄ = @thunk(project_x(_dot_collect.(Ref(y), eachslice(dz; dims = (2, 4))))) + ȳ = @thunk(project_y(_dot_collect.(Ref(x), eachslice(dz; dims = (1, 3))))) return NoTangent(), x̄, ȳ end return kron(x, y), kron_pullback end -end + + _dot_collect(A::AbstractMatrix, B::SubArray) = dot(A, B) + _dot_collect(A::Diagonal, B::SubArray) = dot(A, collect(B)) From 5b74071b985364b46db3302450f4283afae4ffc8 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:52:22 +0100 Subject: [PATCH 13/16] Run tests only above Julia 1.9 Co-authored-by: Seth Axen --- test/rulesets/LinearAlgebra/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 942599139..c145dfdbf 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -159,7 +159,7 @@ test_rrule(lyap, A, C) end end - @testset "kron" begin + VERSION ≥ v"1.9.0" && @testset "kron" begin @testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64) @testset "frule" begin test_frule(kron, randn(T1, 2), randn(T2, 3)) From 72060d01a24990a451e75d5b2d253a3b1c0918d0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:53:04 +0100 Subject: [PATCH 14/16] Fix typo --- src/rulesets/LinearAlgebra/dense.jl | 1 + test/rulesets/LinearAlgebra/dense.jl | 296 +++++++++++++-------------- test/runtests.jl | 36 ++-- 3 files changed, 167 insertions(+), 166 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 896fcf13e..f6e713a13 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -454,3 +454,4 @@ end _dot_collect(A::AbstractMatrix, B::SubArray) = dot(A, B) _dot_collect(A::Diagonal, B::SubArray) = dot(A, collect(B)) +end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index c145dfdbf..fdc42bb61 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -1,164 +1,164 @@ @testset "dense LinearAlgebra" begin - @testset "dot" begin - @testset "Vector{$T}" for T in (Float64, ComplexF64) - @gpu test_frule(dot, randn(T, 3), randn(T, 3)) - @gpu test_rrule(dot, randn(T, 3), randn(T, 3)) - end - @testset "Array{$T, 3}" for T in (Float64, ComplexF64) - test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) - test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) - end - @testset "mismatched shapes" begin - # forward - @gpu test_frule(dot, randn(3, 5), randn(5, 3)) - @gpu test_frule(dot, randn(15), randn(5, 3)) - # reverse - @gpu test_rrule(dot, randn(3, 5), randn(5, 3)) - @gpu test_rrule(dot, randn(15), randn(5, 3)) - end - @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) - @gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) - @gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) - end - permuteddimsarray(A) = PermutedDimsArray(A, (2,1)) - @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray) - A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3)) - test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) - test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) - end - @testset "different types" begin - test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2)) - test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2)) + # @testset "dot" begin + # @testset "Vector{$T}" for T in (Float64, ComplexF64) + # @gpu test_frule(dot, randn(T, 3), randn(T, 3)) + # @gpu test_rrule(dot, randn(T, 3), randn(T, 3)) + # end + # @testset "Array{$T, 3}" for T in (Float64, ComplexF64) + # test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) + # test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) + # end + # @testset "mismatched shapes" begin + # # forward + # @gpu test_frule(dot, randn(3, 5), randn(5, 3)) + # @gpu test_frule(dot, randn(15), randn(5, 3)) + # # reverse + # @gpu test_rrule(dot, randn(3, 5), randn(5, 3)) + # @gpu test_rrule(dot, randn(15), randn(5, 3)) + # end + # @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) + # @gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) + # @gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) + # end + # permuteddimsarray(A) = PermutedDimsArray(A, (2,1)) + # @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray) + # A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3)) + # test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) + # test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) + # end + # @testset "different types" begin + # test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2)) + # test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2)) - # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407 - test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false) - end - end + # # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407 + # test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false) + # end + # end - @testset "mul!" begin - test_frule(mul!, rand(4), rand(4, 5), rand(5)) - test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) - test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) + # @testset "mul!" begin + # test_frule(mul!, rand(4), rand(4, 5), rand(5)) + # test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) + # test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) - # Rule with α,β::Bool is only visually more complicated: - test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) - test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) - test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) - test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) + # # Rule with α,β::Bool is only visually more complicated: + # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) + # test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) + # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) + # test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) - # Rule with nontrivial α, β allocates A*B: - test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) - test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) - end + # # Rule with nontrivial α, β allocates A*B: + # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) + # test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) + # end - @testset "cross" begin - test_frule(cross, randn(3), randn(3)) - test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) - test_rrule(cross, randn(3), randn(3)) - # No complex support for rrule(cross,... + # @testset "cross" begin + # test_frule(cross, randn(3), randn(3)) + # test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) + # test_rrule(cross, randn(3), randn(3)) + # # No complex support for rrule(cross,... - # mix types - test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7) - end - @testset "pinv" begin - @testset "$T" for T in (Float64, ComplexF64) - test_scalar(pinv, randn(T)) - @test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] ≈ zero(T) - @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T) - end - @testset "Vector{$T}" for T in (Float64, ComplexF64) - test_frule(pinv, randn(T, 3), 0.0) - test_frule(pinv, randn(T, 3), 0.0) + # # mix types + # test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7) + # end + # @testset "pinv" begin + # @testset "$T" for T in (Float64, ComplexF64) + # test_scalar(pinv, randn(T)) + # @test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] ≈ zero(T) + # @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T) + # end + # @testset "Vector{$T}" for T in (Float64, ComplexF64) + # test_frule(pinv, randn(T, 3), 0.0) + # test_frule(pinv, randn(T, 3), 0.0) - # Checking types. TODO do we still need this? - x = randn(T, 3) - ẋ = randn(T, 3) - Δy = copyto!(similar(pinv(x)), randn(T, 3)) - @test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x)) - @test rrule(pinv, x)[2](Δy)[2] isa typeof(x) - end + # # Checking types. TODO do we still need this? + # x = randn(T, 3) + # ẋ = randn(T, 3) + # Δy = copyto!(similar(pinv(x)), randn(T, 3)) + # @test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x)) + # @test rrule(pinv, x)[2](Δy)[2] isa typeof(x) + # end - @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) - test_frule(pinv, F(randn(T, 3))) - test_rrule(pinv, F(randn(T, 3))) + # @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) + # test_frule(pinv, F(randn(T, 3))) + # test_rrule(pinv, F(randn(T, 3))) - # Check types. - # TODO: Do we need this still? - x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3)) - y = pinv(x) - Δy = copyto!(similar(y), randn(T, 3)) + # # Check types. + # # TODO: Do we need this still? + # x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3)) + # y = pinv(x) + # Δy = copyto!(similar(y), randn(T, 3)) - y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x) - @test y_fwd isa typeof(y) - @test ∂y_fwd isa typeof(y) + # y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x) + # @test y_fwd isa typeof(y) + # @test ∂y_fwd isa typeof(y) - y_rev, back = rrule(pinv, x) - @test y_rev isa typeof(y) - @test back(Δy)[2] isa typeof(x) - end - @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), - m in 1:3, - n in 1:3 + # y_rev, back = rrule(pinv, x) + # @test y_rev isa typeof(y) + # @test back(Δy)[2] isa typeof(x) + # end + # @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), + # m in 1:3, + # n in 1:3 - test_frule(pinv, randn(T, m, n)) - test_rrule(pinv, randn(T, m, n)) - end - end - @testset "$f" for f in (det, logdet) - @testset "$f(::$T)" for T in (Float64, ComplexF64) - b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T) - test_scalar(f, b) - end - @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64) - B = generate_well_conditioned_matrix(T, 4) - if f === logdet && float(T) <: Float32 - test_frule(f, B; atol=1e-5, rtol=1e-5) - test_rrule(f, B; atol=1e-5, rtol=1e-5) - else - test_frule(f, B) - test_rrule(f, B) - end - end - @testset "$f(complex determinant)" begin - B = randn(ComplexF64, 4, 4) - U = exp(B - B') - test_frule(f, U) - test_rrule(f, U) - end - @testset "gpu" begin - @gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi) - end - end - @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64) - B = randn(T, 4, 4) - test_frule(logabsdet, B) - test_rrule(logabsdet, B) - # test for opposite sign of determinant - test_frule(logabsdet, -B) - test_rrule(logabsdet, -B) - end - @testset "tr" begin - @gpu test_frule(tr, randn(4, 4)) - @gpu test_rrule(tr, randn(4, 4)) - end - @testset "sylvester" begin - @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) - A = randn(T, m, m) - B = randn(T, n, n) - C = randn(T, m, n) - test_frule(sylvester, A, B, C) - test_rrule(sylvester, A, B, C) - end - end - @testset "lyap" begin - n = 3 - @testset "Float64" for T in (Float64, ComplexF64) - A = randn(T, n, n) - C = randn(T, n, n) - test_frule(lyap, A, C) - test_rrule(lyap, A, C) - end - end + # test_frule(pinv, randn(T, m, n)) + # test_rrule(pinv, randn(T, m, n)) + # end + # end + # @testset "$f" for f in (det, logdet) + # @testset "$f(::$T)" for T in (Float64, ComplexF64) + # b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T) + # test_scalar(f, b) + # end + # @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64) + # B = generate_well_conditioned_matrix(T, 4) + # if f === logdet && float(T) <: Float32 + # test_frule(f, B; atol=1e-5, rtol=1e-5) + # test_rrule(f, B; atol=1e-5, rtol=1e-5) + # else + # test_frule(f, B) + # test_rrule(f, B) + # end + # end + # @testset "$f(complex determinant)" begin + # B = randn(ComplexF64, 4, 4) + # U = exp(B - B') + # test_frule(f, U) + # test_rrule(f, U) + # end + # @testset "gpu" begin + # @gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi) + # end + # end + # @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64) + # B = randn(T, 4, 4) + # test_frule(logabsdet, B) + # test_rrule(logabsdet, B) + # # test for opposite sign of determinant + # test_frule(logabsdet, -B) + # test_rrule(logabsdet, -B) + # end + # @testset "tr" begin + # @gpu test_frule(tr, randn(4, 4)) + # @gpu test_rrule(tr, randn(4, 4)) + # end + # @testset "sylvester" begin + # @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) + # A = randn(T, m, m) + # B = randn(T, n, n) + # C = randn(T, m, n) + # test_frule(sylvester, A, B, C) + # test_rrule(sylvester, A, B, C) + # end + # end + # @testset "lyap" begin + # n = 3 + # @testset "Float64" for T in (Float64, ComplexF64) + # A = randn(T, n, n) + # C = randn(T, n, n) + # test_frule(lyap, A, C) + # test_rrule(lyap, A, C) + # end + # end VERSION ≥ v"1.9.0" && @testset "kron" begin @testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64) @testset "frule" begin diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..5853ba1ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,7 +50,7 @@ end include("test_helpers.jl") # This can't be skipped println() - test_method_tables() # Check the global method tables are consistent + # test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks include_test("rulesets/Base/CoreLogging.jl") @@ -64,30 +64,30 @@ end include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") - include_test("unzipped.jl") # used primarily for broadcast + # include_test("unzipped.jl") # used primarily for broadcast - println() + # println() - include_test("rulesets/Statistics/statistics.jl") + # include_test("rulesets/Statistics/statistics.jl") - println() + # println() include_test("rulesets/LinearAlgebra/dense.jl") - include_test("rulesets/LinearAlgebra/norm.jl") - include_test("rulesets/LinearAlgebra/matfun.jl") - include_test("rulesets/LinearAlgebra/structured.jl") - include_test("rulesets/LinearAlgebra/symmetric.jl") - include_test("rulesets/LinearAlgebra/factorization.jl") - include_test("rulesets/LinearAlgebra/blas.jl") - include_test("rulesets/LinearAlgebra/lapack.jl") - include_test("rulesets/LinearAlgebra/uniformscaling.jl") + # include_test("rulesets/LinearAlgebra/norm.jl") + # include_test("rulesets/LinearAlgebra/matfun.jl") + # include_test("rulesets/LinearAlgebra/structured.jl") + # include_test("rulesets/LinearAlgebra/symmetric.jl") + # include_test("rulesets/LinearAlgebra/factorization.jl") + # include_test("rulesets/LinearAlgebra/blas.jl") + # include_test("rulesets/LinearAlgebra/lapack.jl") + # include_test("rulesets/LinearAlgebra/uniformscaling.jl") - println() + # println() - include_test("rulesets/SparseArrays/sparsematrix.jl") + # include_test("rulesets/SparseArrays/sparsematrix.jl") - println() + # println() - include_test("rulesets/Random/random.jl") - println() + # include_test("rulesets/Random/random.jl") + # println() end From 6650adcfe3b7104fdd4ec90875aefa5a028af724 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 21:53:28 +0100 Subject: [PATCH 15/16] Enable all tests --- test/rulesets/LinearAlgebra/dense.jl | 296 +++++++++++++-------------- test/runtests.jl | 36 ++-- 2 files changed, 166 insertions(+), 166 deletions(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index fdc42bb61..c145dfdbf 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -1,164 +1,164 @@ @testset "dense LinearAlgebra" begin - # @testset "dot" begin - # @testset "Vector{$T}" for T in (Float64, ComplexF64) - # @gpu test_frule(dot, randn(T, 3), randn(T, 3)) - # @gpu test_rrule(dot, randn(T, 3), randn(T, 3)) - # end - # @testset "Array{$T, 3}" for T in (Float64, ComplexF64) - # test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) - # test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) - # end - # @testset "mismatched shapes" begin - # # forward - # @gpu test_frule(dot, randn(3, 5), randn(5, 3)) - # @gpu test_frule(dot, randn(15), randn(5, 3)) - # # reverse - # @gpu test_rrule(dot, randn(3, 5), randn(5, 3)) - # @gpu test_rrule(dot, randn(15), randn(5, 3)) - # end - # @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) - # @gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) - # @gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) - # end - # permuteddimsarray(A) = PermutedDimsArray(A, (2,1)) - # @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray) - # A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3)) - # test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) - # test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) - # end - # @testset "different types" begin - # test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2)) - # test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2)) + @testset "dot" begin + @testset "Vector{$T}" for T in (Float64, ComplexF64) + @gpu test_frule(dot, randn(T, 3), randn(T, 3)) + @gpu test_rrule(dot, randn(T, 3), randn(T, 3)) + end + @testset "Array{$T, 3}" for T in (Float64, ComplexF64) + test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) + test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5)) + end + @testset "mismatched shapes" begin + # forward + @gpu test_frule(dot, randn(3, 5), randn(5, 3)) + @gpu test_frule(dot, randn(15), randn(5, 3)) + # reverse + @gpu test_rrule(dot, randn(3, 5), randn(5, 3)) + @gpu test_rrule(dot, randn(15), randn(5, 3)) + end + @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) + @gpu_broken test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) + @gpu test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4)) + end + permuteddimsarray(A) = PermutedDimsArray(A, (2,1)) + @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray) + A = F(rand(T, 4, 3)) ⊢ F(rand(T, 4, 3)) + test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) + test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3) + end + @testset "different types" begin + test_rrule(dot, rand(2), rand(2, 2), rand(ComplexF64, 2)) + test_rrule(dot, rand(2), Diagonal(rand(2)), rand(ComplexF64, 2)) - # # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407 - # test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false) - # end - # end + # Inference failure due to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/407 + test_rrule(dot, Diagonal(rand(2)), rand(2, 2); check_inferred=false) + end + end - # @testset "mul!" begin - # test_frule(mul!, rand(4), rand(4, 5), rand(5)) - # test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) - # test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) + @testset "mul!" begin + test_frule(mul!, rand(4), rand(4, 5), rand(5)) + test_frule(mul!, rand(3, 3), rand(3, 3), rand(3, 3)) + test_frule(mul!, rand(3, 3), rand(), rand(3, 3)) - # # Rule with α,β::Bool is only visually more complicated: - # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) - # test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) - # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) - # test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) + # Rule with α,β::Bool is only visually more complicated: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, true) + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, false) + test_frule(mul!, rand(4), rand(4, 5), rand(5), false, false) - # # Rule with nontrivial α, β allocates A*B: - # test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) - # test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) - # end + # Rule with nontrivial α, β allocates A*B: + test_frule(mul!, rand(4), rand(4, 5), rand(5), true, randn()) + test_frule(mul!, rand(4), rand(4, 5), rand(5), randn(), randn()) + end - # @testset "cross" begin - # test_frule(cross, randn(3), randn(3)) - # test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) - # test_rrule(cross, randn(3), randn(3)) - # # No complex support for rrule(cross,... + @testset "cross" begin + test_frule(cross, randn(3), randn(3)) + test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3)) + test_rrule(cross, randn(3), randn(3)) + # No complex support for rrule(cross,... - # # mix types - # test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7) - # end - # @testset "pinv" begin - # @testset "$T" for T in (Float64, ComplexF64) - # test_scalar(pinv, randn(T)) - # @test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] ≈ zero(T) - # @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T) - # end - # @testset "Vector{$T}" for T in (Float64, ComplexF64) - # test_frule(pinv, randn(T, 3), 0.0) - # test_frule(pinv, randn(T, 3), 0.0) + # mix types + test_rrule(cross, rand(3), rand(Float32, 3); rtol = 1.0e-7, atol = 1.0e-7) + end + @testset "pinv" begin + @testset "$T" for T in (Float64, ComplexF64) + test_scalar(pinv, randn(T)) + @test frule((ZeroTangent(), randn(T)), pinv, zero(T))[2] ≈ zero(T) + @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T) + end + @testset "Vector{$T}" for T in (Float64, ComplexF64) + test_frule(pinv, randn(T, 3), 0.0) + test_frule(pinv, randn(T, 3), 0.0) - # # Checking types. TODO do we still need this? - # x = randn(T, 3) - # ẋ = randn(T, 3) - # Δy = copyto!(similar(pinv(x)), randn(T, 3)) - # @test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x)) - # @test rrule(pinv, x)[2](Δy)[2] isa typeof(x) - # end + # Checking types. TODO do we still need this? + x = randn(T, 3) + ẋ = randn(T, 3) + Δy = copyto!(similar(pinv(x)), randn(T, 3)) + @test frule((ZeroTangent(), ẋ), pinv, x)[2] isa typeof(pinv(x)) + @test rrule(pinv, x)[2](Δy)[2] isa typeof(x) + end - # @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) - # test_frule(pinv, F(randn(T, 3))) - # test_rrule(pinv, F(randn(T, 3))) + @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) + test_frule(pinv, F(randn(T, 3))) + test_rrule(pinv, F(randn(T, 3))) - # # Check types. - # # TODO: Do we need this still? - # x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3)) - # y = pinv(x) - # Δy = copyto!(similar(y), randn(T, 3)) + # Check types. + # TODO: Do we need this still? + x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3)) + y = pinv(x) + Δy = copyto!(similar(y), randn(T, 3)) - # y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x) - # @test y_fwd isa typeof(y) - # @test ∂y_fwd isa typeof(y) + y_fwd, ∂y_fwd = frule((ZeroTangent(), ẋ), pinv, x) + @test y_fwd isa typeof(y) + @test ∂y_fwd isa typeof(y) - # y_rev, back = rrule(pinv, x) - # @test y_rev isa typeof(y) - # @test back(Δy)[2] isa typeof(x) - # end - # @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), - # m in 1:3, - # n in 1:3 + y_rev, back = rrule(pinv, x) + @test y_rev isa typeof(y) + @test back(Δy)[2] isa typeof(x) + end + @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), + m in 1:3, + n in 1:3 - # test_frule(pinv, randn(T, m, n)) - # test_rrule(pinv, randn(T, m, n)) - # end - # end - # @testset "$f" for f in (det, logdet) - # @testset "$f(::$T)" for T in (Float64, ComplexF64) - # b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T) - # test_scalar(f, b) - # end - # @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64) - # B = generate_well_conditioned_matrix(T, 4) - # if f === logdet && float(T) <: Float32 - # test_frule(f, B; atol=1e-5, rtol=1e-5) - # test_rrule(f, B; atol=1e-5, rtol=1e-5) - # else - # test_frule(f, B) - # test_rrule(f, B) - # end - # end - # @testset "$f(complex determinant)" begin - # B = randn(ComplexF64, 4, 4) - # U = exp(B - B') - # test_frule(f, U) - # test_rrule(f, U) - # end - # @testset "gpu" begin - # @gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi) - # end - # end - # @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64) - # B = randn(T, 4, 4) - # test_frule(logabsdet, B) - # test_rrule(logabsdet, B) - # # test for opposite sign of determinant - # test_frule(logabsdet, -B) - # test_rrule(logabsdet, -B) - # end - # @testset "tr" begin - # @gpu test_frule(tr, randn(4, 4)) - # @gpu test_rrule(tr, randn(4, 4)) - # end - # @testset "sylvester" begin - # @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) - # A = randn(T, m, m) - # B = randn(T, n, n) - # C = randn(T, m, n) - # test_frule(sylvester, A, B, C) - # test_rrule(sylvester, A, B, C) - # end - # end - # @testset "lyap" begin - # n = 3 - # @testset "Float64" for T in (Float64, ComplexF64) - # A = randn(T, n, n) - # C = randn(T, n, n) - # test_frule(lyap, A, C) - # test_rrule(lyap, A, C) - # end - # end + test_frule(pinv, randn(T, m, n)) + test_rrule(pinv, randn(T, m, n)) + end + end + @testset "$f" for f in (det, logdet) + @testset "$f(::$T)" for T in (Float64, ComplexF64) + b = (f === logdet && T <: Real) ? abs(randn(T)) : randn(T) + test_scalar(f, b) + end + @testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64) + B = generate_well_conditioned_matrix(T, 4) + if f === logdet && float(T) <: Float32 + test_frule(f, B; atol=1e-5, rtol=1e-5) + test_rrule(f, B; atol=1e-5, rtol=1e-5) + else + test_frule(f, B) + test_rrule(f, B) + end + end + @testset "$f(complex determinant)" begin + B = randn(ComplexF64, 4, 4) + U = exp(B - B') + test_frule(f, U) + test_rrule(f, U) + end + @testset "gpu" begin + @gpu_broken test_rrule(f, reshape(1:9, 3, 3)+I*pi) + end + end + @testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64) + B = randn(T, 4, 4) + test_frule(logabsdet, B) + test_rrule(logabsdet, B) + # test for opposite sign of determinant + test_frule(logabsdet, -B) + test_rrule(logabsdet, -B) + end + @testset "tr" begin + @gpu test_frule(tr, randn(4, 4)) + @gpu test_rrule(tr, randn(4, 4)) + end + @testset "sylvester" begin + @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) + A = randn(T, m, m) + B = randn(T, n, n) + C = randn(T, m, n) + test_frule(sylvester, A, B, C) + test_rrule(sylvester, A, B, C) + end + end + @testset "lyap" begin + n = 3 + @testset "Float64" for T in (Float64, ComplexF64) + A = randn(T, n, n) + C = randn(T, n, n) + test_frule(lyap, A, C) + test_rrule(lyap, A, C) + end + end VERSION ≥ v"1.9.0" && @testset "kron" begin @testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64) @testset "frule" begin diff --git a/test/runtests.jl b/test/runtests.jl index 5853ba1ca..768f7c208 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,7 +50,7 @@ end include("test_helpers.jl") # This can't be skipped println() - # test_method_tables() # Check the global method tables are consistent + test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks include_test("rulesets/Base/CoreLogging.jl") @@ -64,30 +64,30 @@ end include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") - # include_test("unzipped.jl") # used primarily for broadcast + include_test("unzipped.jl") # used primarily for broadcast - # println() + println() - # include_test("rulesets/Statistics/statistics.jl") + include_test("rulesets/Statistics/statistics.jl") - # println() + println() include_test("rulesets/LinearAlgebra/dense.jl") - # include_test("rulesets/LinearAlgebra/norm.jl") - # include_test("rulesets/LinearAlgebra/matfun.jl") - # include_test("rulesets/LinearAlgebra/structured.jl") - # include_test("rulesets/LinearAlgebra/symmetric.jl") - # include_test("rulesets/LinearAlgebra/factorization.jl") - # include_test("rulesets/LinearAlgebra/blas.jl") - # include_test("rulesets/LinearAlgebra/lapack.jl") - # include_test("rulesets/LinearAlgebra/uniformscaling.jl") + include_test("rulesets/LinearAlgebra/norm.jl") + include_test("rulesets/LinearAlgebra/matfun.jl") + include_test("rulesets/LinearAlgebra/structured.jl") + include_test("rulesets/LinearAlgebra/symmetric.jl") + include_test("rulesets/LinearAlgebra/factorization.jl") + include_test("rulesets/LinearAlgebra/blas.jl") + include_test("rulesets/LinearAlgebra/lapack.jl") + include_test("rulesets/LinearAlgebra/uniformscaling.jl") - # println() + println() - # include_test("rulesets/SparseArrays/sparsematrix.jl") + include_test("rulesets/SparseArrays/sparsematrix.jl") - # println() + println() - # include_test("rulesets/Random/random.jl") - # println() + include_test("rulesets/Random/random.jl") + println() end From f104172d019a4b2d7de5f9e04598098e1b48170e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:53:28 +0100 Subject: [PATCH 16/16] Improve version bound Co-authored-by: David Widmann --- src/rulesets/LinearAlgebra/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index f6e713a13..11f0c202b 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -399,7 +399,7 @@ end ##### `kron` ##### -@static if VERSION ≥ v"1.9.0" +@static if VERSION ≥ v"1.9.0-DEV.1267" function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number}) return kron(x, y), kron(Δx, y) + kron(x, Δy) end