feat: handle RNGs in layers correctly
avik-pal committed Jan 1, 2025
1 parent ce9f77f commit 0be16f9
Showing 5 changed files with 43 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -100,7 +100,7 @@ LinearAlgebra = "1.10"
LossFunctions = "0.11.1, 1"
LuxCore = "1.2"
LuxLib = "1.3.7"
MLDataDevices = "1.6"
MLDataDevices = "1.6.6"
MLUtils = "0.4.4"
MPI = "0.20.19"
MacroTools = "0.5.13"
@@ -110,7 +110,7 @@ NNlib = "0.9.26"
Optimisers = "0.4.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.12"
Reactant = "0.2.13"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -66,7 +66,7 @@ Metal = "1"
OneHotArrays = "0.2.5"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2.12"
Reactant = "0.2.13"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
7 changes: 4 additions & 3 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -2,6 +2,7 @@ module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type
+using Random: Random
using Reactant: Reactant, XLA, ConcreteRArray, ConcreteRNumber, TracedRArray,
TracedRNumber

@@ -15,9 +16,7 @@ end

# Query Device from Array
function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray})
-    client = XLA.client(x.data)
-    device = XLA.device(x.data)
-    return ReactantDevice(client, device)
+    return ReactantDevice(XLA.client(x.data), XLA.device(x.data))
end

function Internal.get_device(::Union{TracedRArray, TracedRNumber})
@@ -54,4 +53,6 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::ConcreteRArray)
    return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x))
end

+Adapt.adapt_storage(::CPUDevice, ::Reactant.ConcreteRNG) = Random.default_rng()

end
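
A minimal sketch (not part of this diff, assuming a functional Reactant backend) of the round trip the new adapt method enables: an RNG that was moved to a ReactantDevice can now be moved back to the CPU and falls back to the task-local default RNG instead of erroring.

using MLDataDevices, Random, Reactant

dev  = reactant_device(; force=true)
cdev = cpu_device()

# Moving a state NamedTuple to the device converts Random.default_rng() into a
# Reactant.ConcreteRNG, matching the rngType expectation in the updated xla_tests below.
st = (; rng=Random.default_rng()) |> dev

# Moving it back now hits Adapt.adapt_storage(::CPUDevice, ::Reactant.ConcreteRNG)
# and yields a plain Julia RNG (specifically Random.default_rng()) again.
st_cpu = st |> cdev
@assert st_cpu.rng isa Random.AbstractRNG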
3 changes: 2 additions & 1 deletion lib/MLDataDevices/test/xla_tests.jl
@@ -39,7 +39,8 @@ using FillArrays, Zygote # Extensions

device = reactant_device()
aType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRArray : Array
-rngType = Random.AbstractRNG
+rngType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRNG :
+    Random.AbstractRNG

ps_xpu = ps |> device
@test get_device(ps_xpu) isa ReactantDevice
34 changes: 34 additions & 0 deletions test/reactant/layer_tests.jl
@@ -63,3 +63,37 @@ end
end
end
end

@testitem "Dropout Layers" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux, Random

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

dev = reactant_device(; force=true)

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for layer in (AlphaDropout, Dropout, VariationalHiddenDropout)
model = layer(0.5f0)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = randn(Float32, 10, 10) |> dev

@test st.rng isa Reactant.ConcreteRNG

hlo = @code_hlo model(x, ps, st)
@test contains(repr(hlo), "stablehlo.rng_bit_generator")

y, st2 = @jit model(x, ps, st)
@test st2.rng isa Reactant.ConcreteRNG
@test st.rng.seed != st2.rng.seed
end
end
end
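
A hedged usage sketch (not part of this commit, assuming a functional Reactant backend) of what the new test exercises: a dropout layer's RNG is carried in the layer state as a Reactant.ConcreteRNG, so threading the returned state between calls advances it and each call samples a fresh mask.

using Lux, Random, Reactant

dev = reactant_device(; force=true)

model  = Dropout(0.5f0)
ps, st = Lux.setup(Random.default_rng(), model) |> dev   # st.rng becomes a Reactant.ConcreteRNG
x      = randn(Float32, 10, 10) |> dev

y1, st1 = @jit model(x, ps, st)    # st1.rng carries a new seed, as the test asserts
y2, st2 = @jit model(x, ps, st1)   # passing st1 back in samples a different dropout mask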
