Skip to content

Commit

Permalink
fix: soa/aos handling for multigate
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 10, 2024
1 parent 639f813 commit e6f6993
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# multigate: avoid soa formation
@inline function Lux._gate(x::TrackedArray{T, R, 1}, h::Int, n::Int) where {T, R}
return x[Lux._gate( h, n)]
end
@inline function Lux._gate(x::AbstractVector{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n)))
end
@inline function Lux._gate(x::TrackedArray{T, R, 2}, h::Int, n::Int) where {T, R}
return x[Lux._gate(h, n), :]
end
@inline function Lux._gate(x::AbstractMatrix{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n), :))
end

@inline function Lux.__convert_eltype(::Type{T}, x::AbstractArray{<:TrackedReal}) where {T}
@warn "`Lux.__convert_eltype` doesn't support converting element types of ReverseDiff \
`TrackedReal` arrays. Currently this is a no-op." maxlog=1
Expand Down

0 comments on commit e6f6993

Please sign in to comment.