From c5dbe030af390599848830ff43a5dffc04be69e2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 6 May 2022 23:37:47 -0400 Subject: [PATCH] `normalize` not just vectors (#602) * normalise arrays not just vectors * versions --- Project.toml | 2 +- src/rulesets/LinearAlgebra/norm.jl | 4 ++-- test/rulesets/LinearAlgebra/norm.jl | 12 +++++++++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 05fde5b0b..881f63716 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.4" +version = "1.29.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 2cc69355a..fbe23761f 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -257,7 +257,7 @@ end ##### `normalize` ##### -function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) +function rrule(::typeof(normalize), x::AbstractArray{<:Number}, p::Real) nrm, inner_pullback = rrule(norm, x, p) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) @@ -273,7 +273,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real) return y, normalize_pullback end -function rrule(::typeof(normalize), x::AbstractVector{<:Number}) +function rrule(::typeof(normalize), x::AbstractArray{<:Number}) nrm = LinearAlgebra.norm2(x) Ty = typeof(first(x) / nrm) y = copyto!(similar(x, Ty), x) diff --git a/test/rulesets/LinearAlgebra/norm.jl b/test/rulesets/LinearAlgebra/norm.jl index 3551583fa..8c1eda6f3 100644 --- a/test/rulesets/LinearAlgebra/norm.jl +++ b/test/rulesets/LinearAlgebra/norm.jl @@ -182,15 +182,21 @@ end # =================================== @testset "normalize" begin - @testset "x::Vector{$T}" for T in (Float64, ComplexF64) + @testset "x::Array{$T}" for T in (Float64, ComplexF64) x = randn(T, 3) test_rrule(normalize, x) @test rrule(normalize, x)[2](ZeroTangent()) === (NoTangent(), ZeroTangent()) + + test_rrule(normalize, rand(T, 3, 4)) + test_rrule(normalize, adjoint(rand(T, 5))) end - @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64), - p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable + @testset "x::Array{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5) + # skip p=0, since FD is unstable x = randn(T, 3) test_rrule(normalize, x, p) @test rrule(normalize, x, p)[2](ZeroTangent()) === (NoTangent(), ZeroTangent(), ZeroTangent()) + + test_rrule(normalize, rand(T, 3, 4), p) + test_rrule(normalize, adjoint(rand(T, 5)), p) end end