Skip to content

Commit

Permalink
more robust BroadcastStyle handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Dec 29, 2023
1 parent 06aba31 commit 8a6bc82
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,9 @@ end
Adapt.adapt_structure(T, x::OneHotArray) = OneHotArray(adapt(T, _indices(x)), x.nlabels)

function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
S = Base.BroadcastStyle(T)
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
(typeof(S).name.wrapper){var"N+1"}()
typeof(S)(Val{var"N+1"}())
end

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
Expand Down

0 comments on commit 8a6bc82

Please sign in to comment.