diff --git a/Project.toml b/Project.toml index 96ca11cd38..4308bbcc27 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ ChainRulesCore = "1" EnzymeCore = "0.8.8" Enzyme_jll = "0.0.168" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" +GPUArraysCore = "0.1.6, 0.2" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" diff --git a/ext/EnzymeGPUArraysCoreExt.jl b/ext/EnzymeGPUArraysCoreExt.jl index 9f1edbf8fd..12e3417338 100644 --- a/ext/EnzymeGPUArraysCoreExt.jl +++ b/ext/EnzymeGPUArraysCoreExt.jl @@ -3,14 +3,6 @@ module EnzymeGPUArraysCoreExt using GPUArraysCore using Enzyme -@inline function Enzyme.onehot(x::AbstractGPUArray) - Enzyme.onehot_internal(zerosetfn, x, 0, length(x)) -end - -@inline function Enzyme.onehot(x::AbstractGPUArray, start::Int, endl::Int) - Enzyme.onehot_internal(zerosetfn, x, start-1, endl-start+1) -end - function Enzyme.zerosetfn(x::AbstractGPUArray, i::Int) res = zero(x) @allowscalar @inbounds res[i] = 1 @@ -22,4 +14,25 @@ function Enzyme.zerosetfn!(x::AbstractGPUArray, i::Int, val) return end +@inline function Enzyme.onehot(x::AbstractGPUArray) + # Enzyme.onehot_internal(Enzyme.zerosetfn, x, 0, length(x)) + N = length(x) + ntuple(Val(N)) do i + Base.@_inline_meta + res = zero(x) + @allowscalar @inbounds res[i] = 1 + return res + end +end + +@inline function onehot(x::AbstractArray, start::Int, endl::Int) + # Enzyme.onehot_internal(Enzyme.zerosetfn, x, start-1, endl-start+1) + ntuple(Val(endl - start + 1)) do i + Base.@_inline_meta + res = zero(x) + @allowscalar @inbounds res[i + start - 1] = 1 + return res + end +end + end # module diff --git a/src/sugar.jl b/src/sugar.jl index 05a4668f0b..ab2069d492 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -12,7 +12,7 @@ function zerosetfn!(x, i::Int, val) nothing end -@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array} +@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:AbstractArray} ir = GPUCompiler.JuliaContext() do ctx Base.@_inline_meta diff --git a/test/Project.toml b/test/Project.toml index 818e0ac708..fbc6d754fe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/ext/jlarrays.jl b/test/ext/jlarrays.jl new file mode 100644 index 0000000000..760000ba75 --- /dev/null +++ b/test/ext/jlarrays.jl @@ -0,0 +1,11 @@ +using Enzyme, Test, JLArrays + +function jlres(x) + 2 * collect(x) +end + +@testset "JLArrays" begin + # TODO fix activity of jlarray + # Enzyme.jacobian(Forward, jlres, JLArray([3.0, 5.0])) + # Enzyme.jacobian(Reverse, jlres, JLArray([3.0, 5.0])) +end diff --git a/test/runtests.jl b/test/runtests.jl index c9cde02c28..7f6e4700f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3778,5 +3778,6 @@ include("ext/logexpfunctions.jl") include("ext/bfloat16s.jl") end +include("ext/jlarrays.jl") include("ext/sparsearrays.jl") include("ext/staticarrays.jl")