Skip to content

Commit

Permalink
refactor: restrict Reactant to 0.2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 4, 2024
1 parent 06f1604 commit 9ded737
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.3"
Reactant = "0.2.4"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Optimisers = "0.3.3"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.1"
Reactant = "0.2.4"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2"
Reactant = "0.2.4"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down
29 changes: 11 additions & 18 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,8 @@ module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type
using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber

@static if isdefined(Reactant, :ConcreteRNumber)
const ConcreteRType = Union{ConcreteRArray, Reactant.ConcreteRNumber}
const ReactantType = Union{
RArray, TracedRArray, TracedRNumber, Reactant.ConcreteRNumber
}
else
const ConcreteRType = ConcreteRArray
const ReactantType = Union{RArray, TracedRArray, TracedRNumber}
end
using Reactant: Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRArray,
TracedRNumber

MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
Expand All @@ -23,7 +14,7 @@ function MLDataDevices.default_device_rng(::ReactantDevice)
end

# Query Device from Array
function Internal.get_device(x::ConcreteRType)
function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray})
client = XLA.client(x.data)
device = XLA.device(x.data)
return ReactantDevice(client, device)
Expand All @@ -33,24 +24,26 @@ function Internal.get_device(::Union{TracedRArray, TracedRNumber})
error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
end

Internal.get_device_type(::ReactantType) = ReactantDevice
function Internal.get_device_type(
::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber})
return ReactantDevice
end

# unsafe_free!
Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing

# Device Transfer
function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
function Adapt.adapt_storage(
dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
@warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \
transfer this via CPU." maxlog=1
return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
end

function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive})
client = dev.client === missing ? XLA.default_backend[] : dev.client
device = dev.device === missing ?
XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device
return ConcreteRArray{eltype(x), ndims(x)}(
XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing), size(x))
device = dev.device === missing ? nothing : dev.device
return ConcreteRArray(x; client, device)
end

end

0 comments on commit 9ded737

Please sign in to comment.