diff --git a/src/api/instancenorm.jl b/src/api/instancenorm.jl index 58db6e63..15878552 100644 --- a/src/api/instancenorm.jl +++ b/src/api/instancenorm.jl @@ -43,8 +43,8 @@ end function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, training::TrainingType, - σ::F=identity, momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F} + rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity, + momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F} assert_valid_instancenorm_arguments(x) y, rμₙ, rσ²ₙ = instancenorm_impl( diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index 83d82d2c..9afc4cde 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -132,10 +132,10 @@ CRC.@non_differentiable get_norm_reshape_dims(::Any...) # Entry Points ## InstanceNorm -function instancenorm(x::AbstractArray{xT, N}, rμ::Optional{<:AbstractVector}, - rσ²::Optional{<:AbstractVector}, γ::Optional{<:AbstractVector}, - β::Optional{<:AbstractVector}, training::StaticBool, - momentum, epsilon, act::F) where {xT, N, F} +function instancenorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::StaticBool, + act::F, momentum, epsilon) where {xT, N, F} y, rμₙ, rσ²ₙ = normalization( x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) return y, safe_vec(rμₙ), safe_vec(rσ²ₙ) diff --git a/test/normalization/instancenorm_tests.jl b/test/normalization/instancenorm_tests.jl index 4e12c197..848b25ba 100644 --- a/test/normalization/instancenorm_tests.jl +++ b/test/normalization/instancenorm_tests.jl @@ -53,9 +53,10 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) - lfn = (x, sc, b, rm, rv, act, ϵ) -> sum(first(instancenorm( - x, sc, b, rm, rv, Val(true), act, T(0.1), ϵ))) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, rm, rv, act, epsilon)) isa Any + lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( + x, sc, b, rm, rv, Val(true), act, m, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any end @test y isa aType{T, length(sz)}