Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: use KA for CPU operations
Browse files Browse the repository at this point in the history
[skip tests]
  • Loading branch information
avik-pal committed Aug 20, 2024
1 parent 4bf366c commit b57f2a1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 502 deletions.
233 changes: 2 additions & 231 deletions src/impl/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,135 +73,7 @@ end
end

function batchnorm_affine_normalize_internal!(
y::AbstractArray{<:Number, 3}, opmode::LoopedArrayOp, act::F,
x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector},
ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F}
N = size(y, 2)
γ′ = γ′ === nothing ?
similar(x, promote_type(Utils.eltype(γ), Utils.eltype(σ²), Utils.eltype(ϵ)), N) :
γ′
β′ = similar(x, promote_type(Utils.eltype(β), Utils.eltype(σ²), Utils.eltype(ϵ)), N)

compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ)

if Utils.known(Traits.fuse_cpu_activation(act))
apply_batchnorm_scale_bias_act_cpu!(y, γ′, β′, x, act)
else
apply_batchnorm_scale_bias_cpu!(y, γ′, β′, x)
activation!(y, opmode, act, y)
end

return
end

function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ)
if γ === nothing && β === nothing
@simd ivdep for J in indices((γ′, β′, μ, σ²))
@fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ))
@fastmath @inbounds β′[J] = -μ[J] * γ′[J]
end
else
@simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²))
@fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ)
@fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J]
end
end
end

function apply_batchnorm_scale_bias_act_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F}
if size(y, 1) == 1
apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ)
else
apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
end
end

@inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F}
for K in indices((x, y), 3)
@simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J])
end
end
end

@inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F}
@batch for K in indices((x, y), 3)
for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@simd ivdep for I in indices((x, y), 1)
@fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J])
end
end
end
end

@inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3}, σ::F) where {F}
for K in indices((x, y), 3)
for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@simd ivdep for I in indices((x, y), 1)
@fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J])
end
end
end
end

Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu!

function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3})
if size(y, 1) == 1
apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x)
else
apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
end
end

@inline function apply_batchnorm_scale_bias_2d_serial_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3})
for K in indices((x, y), 3)
@simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J]
end
end
end

@inline function apply_batchnorm_scale_bias_3d_threaded_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3})
@batch for K in indices((x, y), 3)
for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@simd ivdep for I in indices((x, y), 1)
@fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J]
end
end
end
end

@inline function apply_batchnorm_scale_bias_3d_serial_cpu!(
y::AbstractArray{<:Number, 3}, γ′::AbstractVector,
β′::AbstractVector, x::AbstractArray{<:Number, 3})
for K in indices((x, y), 3)
for J in indices((x, y, γ′, β′), (2, 2, 1, 1))
@simd ivdep for I in indices((x, y), 1)
@fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J]
end
end
end
end

Utils.@enzyme_reverse_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu!

function batchnorm_affine_normalize_internal!(
y::AbstractArray{<:Number, 3}, ::GPUBroadcastOp, act::F,
y::AbstractArray{<:Number, 3}, ::Union{GPUBroadcastOp, LoopedArrayOp}, act::F,
x::AbstractArray{<:Number, 3}, μ::AbstractVector, σ²::AbstractVector,
γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector},
ϵ::Real, γ′::Optional{<:AbstractVector}=nothing) where {F}
Expand Down Expand Up @@ -280,107 +152,6 @@ function CRC.rrule(
return z, ∇batchnorm_affine_normalize_internal
end

function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{<:Number, 3},
x::AbstractArray{<:Number, 3}, μ::AbstractVector,
σ²::AbstractVector, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector)
∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
∂γ = γ === nothing ? nothing : similar(γ)
∂β = β === nothing ? nothing : similar(β)

∇batchnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, ∂y, x, μ, σ², γ, ϵ, γ′)

∂γ = γ === nothing ? ∂∅ : ∂γ
∂β = β === nothing ? ∂∅ : ∂β

return ∂x, ∂μ, ∂σ², ∂γ, ∂β
end

function ∇batchnorm_affine_normalize_cpu!(
∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number},
∂σ²::AbstractVector{<:Number}, ::Nothing, ::Nothing,
∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3},
μ::AbstractVector, σ²::AbstractVector, ::Nothing, ϵ::Real, γ′::AbstractVector)
half = eltype(∂σ²)(0.5)

fill!(∂μ, 0)
fill!(∂σ², 0)

if size(∂y, 1) == 1
@fastmath @inbounds for K in indices(∂y, 3)
@simd for J in indices(∂y, 2)
idenom = γ′[J]
idenom² = idenom^2

= x[1, J, K] - μ[J]

∂x[1, J, K] = ∂y[1, J, K] * idenom
∂μ[J] -= ∂x[1, J, K]
∂σ²[J] -= ∂x[1, J, K] ** half * idenom²
end
end
else
@fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2)
idenom = γ′[J]
idenom² = idenom^2

@simd for I in indices(∂y, 1)
= x[I, J, K] - μ[J]

∂x[I, J, K] = ∂y[I, J, K] * idenom
∂μ[J] -= ∂x[I, J, K]
∂σ²[J] -= ∂x[I, J, K] ** half * idenom²
end
end
end
end

function ∇batchnorm_affine_normalize_cpu!(
∂x::AbstractArray{<:Number, 3}, ∂μ::AbstractVector{<:Number},
∂σ²::AbstractVector{<:Number}, ∂γ::AbstractVector{<:Number},
∂β::AbstractVector{<:Number}, ∂y::AbstractArray{<:Number, 3},
x::AbstractArray{<:Number, 3}, μ::AbstractVector,
σ²::AbstractVector, γ::AbstractVector, ϵ::Real, γ′::AbstractVector)
half = eltype(∂σ²)(0.5)

fill!(∂μ, 0)
fill!(∂σ², 0)
fill!(∂γ, 0)
fill!(∂β, 0)

if size(∂y, 1) == 1
@fastmath @inbounds for K in indices(∂y, 3)
@simd for J in indices(∂y, 2)
idenom = inv(sqrt(σ²[J] + ϵ))
idenom² = idenom^2

= x[1, J, K] - μ[J]

∂x[1, J, K] = ∂y[1, J, K] * γ′[J]
∂μ[J] -= ∂x[1, J, K]
∂σ²[J] -= ∂x[1, J, K] ** half * idenom²
∂γ[J] += ∂y[1, J, K] ** idenom
∂β[J] += ∂y[1, J, K]
end
end
else
@fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2)
idenom = inv(sqrt(σ²[J] + ϵ))
idenom² = idenom^2

@simd for I in indices(∂y, 1)
= x[I, J, K] - μ[J]

∂x[I, J, K] = ∂y[I, J, K] * γ′[J]
∂μ[J] -= ∂x[I, J, K]
∂σ²[J] -= ∂x[I, J, K] ** half * idenom²
∂γ[J] += ∂y[I, J, K] ** idenom
∂β[J] += ∂y[I, J, K]
end
end
end
end

function ∇batchnorm_affine_normalize(
opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{<:Number, 3},
x::AbstractArray{<:Number, 3}, μ::AbstractVector,
Expand All @@ -401,7 +172,7 @@ end

function ∇batchnorm_affine_normalize!(
∂x::AbstractArray{<:Number, 3}, ∂σ²::AbstractArray{<:Number, 3},
∂γ::Optional{<:AbstractArray{<:Number, 3}}, ::GPUBroadcastOp,
∂γ::Optional{<:AbstractArray{<:Number, 3}}, ::Union{GPUBroadcastOp, LoopedArrayOp},
∂y::AbstractArray{<:Number, 3}, x::AbstractArray{<:Number, 3}, μ::AbstractVector,
σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, γ′::AbstractVector)
backend = KA.get_backend(∂x)
Expand Down
Loading

0 comments on commit b57f2a1

Please sign in to comment.