From 6acef745915aa6c7ac73fb109ad451c688056328 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Tue, 9 Aug 2022 14:59:08 -0400 Subject: [PATCH 1/2] add array-quantity multiplication/divison rules --- Project.toml | 1 + src/UnitfulChainRules.jl | 5 ++- src/arraymath.jl | 66 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 src/arraymath.jl diff --git a/Project.toml b/Project.toml index aaa7ef2..c6bb3ec 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [compat] diff --git a/src/UnitfulChainRules.jl b/src/UnitfulChainRules.jl index 6a610fe..7419fc2 100644 --- a/src/UnitfulChainRules.jl +++ b/src/UnitfulChainRules.jl @@ -2,8 +2,9 @@ module UnitfulChainRules using Unitful using Unitful: Quantity, Units, NoDims, FreeUnits -using ChainRulesCore: NoTangent, @scalar_rule, @thunk +using ChainRulesCore import ChainRulesCore: rrule, frule, ProjectTo +using LinearAlgebra const REALCOMPLEX = Union{Real, Complex} @@ -18,4 +19,6 @@ include("./trig.jl") # sin, cos, tan, etc for degrees include("./math.jl") # other math +include("./arraymath.jl") # Simple scalar-array math + end # module diff --git a/src/arraymath.jl b/src/arraymath.jl new file mode 100644 index 0000000..4b13325 --- /dev/null +++ b/src/arraymath.jl @@ -0,0 +1,66 @@ +const CommutativeMulQuantity = Quantity{T,D,U} where {T<:Union{Real,Complex}, D, U} +const CommMulVal = Union{Real, Complex, CommutativeMulQuantity} + +# Reference: https://github.com/JuliaDiff/ChainRules.jl/blob/148fa8875725a19cf658405609fa1a56671d0cbd/src/rulesets/Base/arraymath.jl + +# Defines *, / for the pairs where: +# 1. The scalar is a commutative/mul quantity and the array is real, complex, or a comm/mul quantity +# 2. The scalar is a commutative/mul number and the array is a comm/mul quantity +# We have to be careful defining this so that we always have a Quantity in the signature, otherwise +# we overwrite methods from ChainRules.jl +for (s_type,a_type) in ( + (:CommutativeMulQuantity, :(<:CommMulVal)), + (:(Union{Real,Complex}), :(<:CommutativeMulQuantity)) + ) + @eval function rrule( + ::typeof(*), A::$(s_type), B::AbstractArray{$(a_type)} + ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function times_pullback(ȳ) + Ȳ = unthunk(ȳ) + return ( + NoTangent(), + @thunk(project_A(dot(Ȳ, B)')), + InplaceableThunk( + X̄ -> mul!(X̄, conj(A), Ȳ, true, true), + @thunk(project_B(A' * Ȳ)), + ) + ) + end + return A * B, times_pullback + end + + @eval function rrule( + ::typeof(*), B::AbstractArray{$(a_type)}, A::$(s_type) + ) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function times_pullback(ȳ) + Ȳ = unthunk(ȳ) + return ( + NoTangent(), + InplaceableThunk( + X̄ -> mul!(X̄, conj(A), Ȳ, true, true), + @thunk(project_B(A' * Ȳ)), + ), + @thunk(project_A(dot(Ȳ, B)')), + ) + end + return A * B, times_pullback + end + + @eval function rrule(::typeof(/), A::AbstractArray{$(a_type)}, b::$(s_type)) + Y = A/b + function slash_pullback_scalar(ȳ) + Ȳ = unthunk(ȳ) + Athunk = InplaceableThunk( + dA -> dA .+= Ȳ ./ conj(b), + @thunk(Ȳ / conj(b)), + ) + bthunk = @thunk(-dot(A,Ȳ) / conj(b^2)) + return (NoTangent(), Athunk, bthunk) + end + return Y, slash_pullback_scalar + end +end From 53f588412404606ada5a5bd024c4afdf033ad446 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Tue, 9 Aug 2022 15:20:16 -0400 Subject: [PATCH 2/2] add basic array-scalar math tests --- test/arraymath.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 2 files changed, 50 insertions(+) create mode 100644 test/arraymath.jl diff --git a/test/arraymath.jl b/test/arraymath.jl new file mode 100644 index 0000000..d3c1e3a --- /dev/null +++ b/test/arraymath.jl @@ -0,0 +1,46 @@ +using Unitful +using UnitfulChainRules + +using Zygote + +using Random +rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451) + +@testset "Array-Scalar Multiplication" begin + for (a_unit, s_unit) in ((1.0,oneunit(1.0u"m")), (oneunit(1.0u"m"), 1.0)) + A = randn(rng, 5) * a_unit + s = randn(rng) * s_unit + + @testset "A * s ($a_unit, $s_unit)" begin + Ω, pb = Zygote.pullback(*, A, s) + + @test Ω ≈ A * s + @test all(first(pb(one.(Ω))) .≈ s) + @test last(pb(one.(Ω))) ≈ sum(A) + end + + @testset "s * A ($s_unit, $a_unit)" begin + + Ω, pb = Zygote.pullback(*, s, A) + + @test Ω ≈ s * A + @test all(last(pb(one.(Ω))) .≈ s) + @test first(pb(one.(Ω))) ≈ sum(A) + end + end +end + +@testset "Array-Scalar Division" begin + for (a_unit, s_unit) in ((1.0,oneunit(1.0u"m")), (oneunit(1.0u"m"), 1.0)) + @testset "($a_unit, $s_unit) division" begin + A = randn(rng, 5) * a_unit + s = randn(rng) * s_unit + + Ω, pb = Zygote.pullback(/, A, s) + + @test Ω ≈ A / s + @test all(first(pb(one.(Ω))) .≈ inv(s)) + @test last(pb(one.(Ω))) ≈ -sum(A)/s^2 + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 2fe38f6..cc277b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,4 +16,8 @@ end @safetestset "Math" begin include("./math.jl") +end + +@safetestset "Array Math" begin + include("./arraymath.jl") end \ No newline at end of file