Skip to content

Commit

Permalink
feat: support passing in device and client to XLA
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 3, 2024
1 parent 409eda2 commit 57de396
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.2"
version = "1.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
40 changes: 35 additions & 5 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
using Reactant: Reactant, RArray
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type
using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber

# Type aliases used for dispatch in this extension.
# `ConcreteRType` groups the concrete Reactant containers and `ReactantType` groups
# every Reactant value type handled here. `Reactant.ConcreteRNumber` is only added to
# the unions when the loaded Reactant version defines it (checked at load time via
# `@static`), so the extension also loads against versions without that type.
@static if isdefined(Reactant, :ConcreteRNumber)
const ConcreteRType = Union{ConcreteRArray, Reactant.ConcreteRNumber}
const ReactantType = Union{
RArray, TracedRArray, TracedRNumber, Reactant.ConcreteRNumber
}
else
const ConcreteRType = ConcreteRArray
const ReactantType = Union{RArray, TracedRArray, TracedRNumber}
end

# Once this extension is loaded, `XLADevice` (queried either by instance or by type)
# reports itself as both loaded and functional.
MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
Expand All @@ -13,14 +23,34 @@ function MLDataDevices.default_device_rng(::XLADevice)
end

# Query Device from Array
Internal.get_device(::RArray) = XLADevice()
# Build the `XLADevice` describing where a concrete Reactant value lives by asking XLA
# for the client and device of its underlying `data` buffer.
function Internal.get_device(x::ConcreteRType)
    buffer = x.data
    return XLADevice(XLA.client(buffer), XLA.device(buffer))
end

# Querying the device of a traced value is rejected: this method always raises an
# `ErrorException` telling the caller not to use `get_device` while compiling.
function Internal.get_device(::Union{TracedRArray, TracedRNumber})
    throw(
        ErrorException(
            "`get_device` isn't meant to be called inside `Reactant.@compile` context."
        ),
    )
end

Internal.get_device_type(::RArray) = XLADevice
# Every value covered by `ReactantType` reports `XLADevice` as its device type. Unlike
# `get_device`, this returns the type itself and does not inspect the value.
Internal.get_device_type(::ReactantType) = XLADevice

# unsafe_free!: a no-op for `XLADevice` — the array `x` is left untouched and `nothing`
# is returned.
Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing

# Device Transfer
Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
# Generic path for arrays of Reactant primitive element types that are not plain
# `Array`s: warn (once), move the data to the CPU, then adapt the CPU copy to `dev`.
# NOTE(review): if the CPU transfer yields something other than an `Array`, the second
# call would dispatch back to this method — confirm that cannot happen.
function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
    @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \
           transfer this via CPU." maxlog=1
    host_copy = Adapt.adapt_storage(CPUDevice(), x)
    return Adapt.adapt_storage(dev, host_copy)
end

# Upload a host `Array` to XLA and wrap it in a `ConcreteRArray` with the same element
# type, rank and size. The client/device stored in `dev` are used when present; a
# `missing` field falls back to XLA's default backend / default device index.
function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive})
    client = ismissing(dev.client) ? XLA.default_backend[] : dev.client
    device = if ismissing(dev.device)
        # The default device is looked up on whichever client was selected above.
        XLA.ClientGetDevice(client, XLA.default_device_idx[])
    else
        dev.device
    end
    buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing)
    return ConcreteRArray{eltype(x), ndims(x)}(buffer, size(x))
end

end
17 changes: 12 additions & 5 deletions lib/MLDataDevices/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ end
# Fieldless device types; both are subtypes of `AbstractGPUDevice`.
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end

# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end
# Accelerator device backed by XLA. It is parametric in the types of its two fields so
# they are stored concretely. Both fields default to `missing`, so `XLADevice()` means
# "no explicit client/device chosen"; `@kwdef` also allows `XLADevice(; client, device)`.
@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice
# XLA client to use, or `missing` when none was specified.
client::C = missing
# Specific device to use, or `missing` when none was specified.
device::D = missing
end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end
Expand Down Expand Up @@ -189,20 +191,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
cpu_device() = CPUDevice()

"""
xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
xla_device(;
force::Bool=false, client=missing, device=missing
) -> Union{XLADevice, CPUDevice}
Return an `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
Falls back to `CPUDevice` if `force` is `false`.
`client` and `device` are used to specify the client and particular device to use. If not
specified, then the default client and index are used.
!!! danger
This is an experimental feature and might change without deprecations
"""
function xla_device(; force::Bool=false)
function xla_device(; force::Bool=false, client=missing, device=missing)
msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
this function. Defaulting to CPU."
if loaded(XLADevice)
functional(XLADevice) && return XLADevice()
functional(XLADevice) && return XLADevice(client, device)
msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
end
force && throw(Internal.DeviceSelectionException("XLA"))
Expand Down

0 comments on commit 57de396

Please sign in to comment.