From 0be16f9c9716396fd56678d8e18dbde377be6421 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 31 Dec 2024 21:37:41 -0500 Subject: [PATCH] feat: handle RNGs in layers correctly --- Project.toml | 4 +-- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesReactantExt.jl | 7 ++-- lib/MLDataDevices/test/xla_tests.jl | 3 +- test/reactant/layer_tests.jl | 34 +++++++++++++++++++ 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 227398b6d..a08546d73 100644 --- a/Project.toml +++ b/Project.toml @@ -100,7 +100,7 @@ LinearAlgebra = "1.10" LossFunctions = "0.11.1, 1" LuxCore = "1.2" LuxLib = "1.3.7" -MLDataDevices = "1.6" +MLDataDevices = "1.6.6" MLUtils = "0.4.4" MPI = "0.20.19" MacroTools = "0.5.13" @@ -110,7 +110,7 @@ NNlib = "0.9.26" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.12" +Reactant = "0.2.13" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 054f1a462..87f9a2650 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -66,7 +66,7 @@ Metal = "1" OneHotArrays = "0.2.5" Preferences = "1.4" Random = "1.10" -Reactant = "0.2.12" +Reactant = "0.2.13" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index 9c48fd744..9cc1f082c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -2,6 +2,7 @@ module MLDataDevicesReactantExt using Adapt: Adapt using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type +using Random: Random using Reactant: Reactant, XLA, ConcreteRArray, ConcreteRNumber, TracedRArray, TracedRNumber @@ -15,9 +16,7 @@ end # Query Device from Array function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray}) - client = XLA.client(x.data) - device = XLA.device(x.data) - return ReactantDevice(client, device) + return ReactantDevice(XLA.client(x.data), XLA.device(x.data)) end function Internal.get_device(::Union{TracedRArray, TracedRNumber}) @@ -54,4 +53,6 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::ConcreteRArray) return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x)) end +Adapt.adapt_storage(::CPUDevice, ::Reactant.ConcreteRNG) = Random.default_rng() + end diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index 30377c828..bf39be0c7 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -39,7 +39,8 @@ using FillArrays, Zygote # Extensions device = reactant_device() aType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRArray : Array - rngType = Random.AbstractRNG + rngType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRNG : + Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa ReactantDevice diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl index 8130691cb..e0e0fb526 100644 --- a/test/reactant/layer_tests.jl +++ b/test/reactant/layer_tests.jl @@ -63,3 +63,37 @@ end end end end + +@testitem "Dropout Layers" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin + using Reactant, Lux, Random + + @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES + if mode == "amdgpu" + @warn "Skipping AMDGPU tests for Reactant" + continue + end + + dev = reactant_device(; force=true) + + if ongpu + Reactant.set_default_backend("gpu") + else + Reactant.set_default_backend("cpu") + end + + @testset for layer in (AlphaDropout, Dropout, VariationalHiddenDropout) + model = layer(0.5f0) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + x = randn(Float32, 10, 10) |> dev + + @test st.rng isa Reactant.ConcreteRNG + + hlo = @code_hlo model(x, ps, st) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + y, st2 = @jit model(x, ps, st) + @test st2.rng isa Reactant.ConcreteRNG + @test st.rng.seed != st2.rng.seed + end + end +end