From e6f6993eaedbd500bcbb3a1a0d64605561f3b4d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jul 2024 20:55:40 -0700 Subject: [PATCH] fix: soa/aos handling for multigate --- ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 485e4ac495..74ff2b47c3 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -31,6 +31,20 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) return ArrayInterface.aos_to_soa(reverse(x; dims)) end +# multigate: avoid soa formation +@inline function Lux._gate(x::TrackedArray{T, R, 1}, h::Int, n::Int) where {T, R} + return x[Lux._gate( h, n)] +end +@inline function Lux._gate(x::AbstractVector{<:TrackedReal}, h::Int, n::Int) + return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n))) +end +@inline function Lux._gate(x::TrackedArray{T, R, 2}, h::Int, n::Int) where {T, R} + return x[Lux._gate(h, n), :] +end +@inline function Lux._gate(x::AbstractMatrix{<:TrackedReal}, h::Int, n::Int) + return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n), :)) +end + @inline function Lux.__convert_eltype(::Type{T}, x::AbstractArray{<:TrackedReal}) where {T} @warn "`Lux.__convert_eltype` doesn't support converting element types of ReverseDiff \ `TrackedReal` arrays. Currently this is a no-op." maxlog=1