diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 68d43257bb..640a309c7f 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.4.2" +version = "1.5.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index 3abc8fca2c..a62f87aa10 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -1,8 +1,18 @@ module MLDataDevicesReactantExt using Adapt: Adapt -using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice -using Reactant: Reactant, RArray +using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type +using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber + +@static if isdefined(Reactant, :ConcreteRNumber) + const ConcreteRType = Union{ConcreteRArray, Reactant.ConcreteRNumber} + const ReactantType = Union{ + RArray, TracedRArray, TracedRNumber, Reactant.ConcreteRNumber + } +else + const ConcreteRType = ConcreteRArray + const ReactantType = Union{RArray, TracedRArray, TracedRNumber} +end MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true @@ -13,14 +23,34 @@ function MLDataDevices.default_device_rng(::XLADevice) end # Query Device from Array -Internal.get_device(::RArray) = XLADevice() +function Internal.get_device(x::ConcreteRType) + client = XLA.client(x.data) + device = XLA.device(x.data) + return XLADevice(client, device) +end + +function Internal.get_device(::Union{TracedRArray, TracedRNumber}) + error("`get_device` isn't meant to be called inside `Reactant.@compile` context.") +end -Internal.get_device_type(::RArray) = XLADevice +Internal.get_device_type(::ReactantType) = XLADevice # unsafe_free! Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing # Device Transfer -Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x) +function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) + @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \ + transfer this via CPU." maxlog=1 + return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x)) +end + +function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive}) + client = dev.client === missing ? XLA.default_backend[] : dev.client + device = dev.device === missing ? + XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device + return ConcreteRArray{eltype(x), ndims(x)}( + XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing), size(x)) +end end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 6440ddbe74..a3e61aab6e 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -9,8 +9,10 @@ end struct MetalDevice <: AbstractGPUDevice end struct oneAPIDevice <: AbstractGPUDevice end -# TODO: Later we might want to add the client field here? -struct XLADevice <: AbstractAcceleratorDevice end +@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice + client::C = missing + device::D = missing +end # Fallback for when we don't know the device type struct UnknownDevice <: AbstractDevice end @@ -189,20 +191,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU. cpu_device() = CPUDevice() """ - xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice} + xla_device(; + force::Bool=false, client=missing, device=missing + ) -> Union{XLADevice, CPUDevice} Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`. Falls back to `CPUDevice` if `force` is `false`. +`client` and `device` are used to specify the client and particular device to use. If not +specified, then the default client and index are used. + !!! danger This is an experimental feature and might change without deprecations """ -function xla_device(; force::Bool=false) +function xla_device(; force::Bool=false, client=missing, device=missing) msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \ this function. Defaulting to CPU." if loaded(XLADevice) - functional(XLADevice) && return XLADevice() + functional(XLADevice) && return XLADevice(client, device) msg = "`XLADevice` is loaded but not functional. Defaulting to CPU." end force && throw(Internal.DeviceSelectionException("XLA"))