feat: support passing in device and client to XLA (#1020)
* feat: support passing in device and client to XLA

* feat: add == dispatch for XLADevice

* refactor: rename XLADevice/xla_device to ReactantDevice/reactant_device

* refactor: restrict Reactant to 0.2.4
avik-pal authored Nov 5, 2024
1 parent 8bfa628 commit dc2885f
Showing 19 changed files with 140 additions and 88 deletions.
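In short, the device object can now carry an explicit XLA client and device instead of always using the defaults. A minimal usage sketch based on the diff below — the way a client handle is obtained here (`Reactant.XLA.default_backend[]`) is only an illustration and assumes Reactant is loaded and functional:

```julia
using Reactant, Lux, MLDataDevices

# Default behaviour: use the default client and device index.
dev = reactant_device()

# New in this commit: pin a specific client (and optionally a device).
# Both fields default to `missing`, meaning "use the default".
client = Reactant.XLA.default_backend[]
pinned_dev = ReactantDevice(client, missing)

x = randn(Float32, 2, 32)
x_ra = pinned_dev(x)                  # transferred to a ConcreteRArray on that client
get_device(x_ra) isa ReactantDevice   # the client/device can be queried back from the array
```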
6 changes: 3 additions & 3 deletions Project.toml
@@ -87,8 +87,8 @@ GPUArraysCore = "0.1.6, 0.2"
 LinearAlgebra = "1.10"
 LossFunctions = "0.11.1"
 LuxCore = "1"
-LuxLib = "1.3.4"
-MLDataDevices = "1.3"
+LuxLib = "1.3.7"
+MLDataDevices = "1.5"
 MLUtils = "0.4.4"
 MPI = "0.20.19"
 MacroTools = "0.5.13"
@@ -98,7 +98,7 @@ NNlib = "0.9.24"
 Optimisers = "0.3.3"
 Preferences = "1.4.3"
 Random = "1.10"
-Reactant = "0.2.3"
+Reactant = "0.2.4"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
 SIMDTypes = "0.1"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -54,7 +54,7 @@ Optimisers = "0.3.3"
 Pkg = "1.10"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.1"
+Reactant = "0.2.4"
 StableRNGs = "1"
 StaticArrays = "1"
 WeightInitializers = "1"
2 changes: 1 addition & 1 deletion docs/src/api/Accelerator_Support/MLDataDevices.md
@@ -18,7 +18,7 @@ MLDataDevices.gpu_backend!
 ```@docs
 MLDataDevices.cpu_device
 MLDataDevices.gpu_device
-MLDataDevices.xla_device
+MLDataDevices.reactant_device
 ```

 ## Miscellaneous
6 changes: 3 additions & 3 deletions docs/src/index.md
@@ -142,21 +142,21 @@ Run the following to access a device:
 using Reactant, Lux
 Reactant.set_default_backend("cpu") # default

-const dev = xla_device()
+const dev = reactant_device()
 ```

 ```julia [GPU Backend]
 using Reactant, Lux
 Reactant.set_default_backend("gpu")

-const dev = xla_device()
+const dev = reactant_device()
 ```

 ```julia [TPU Backend]
 using Reactant, Lux
 Reactant.set_default_backend("tpu")

-const dev = xla_device()
+const dev = reactant_device()
 ```

 :::
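As with `gpu_device`, the object returned by `reactant_device` is callable and composes with `|>`; a short sketch (assuming Reactant is functional, otherwise it falls back to `CPUDevice`):

```julia
using Reactant, Lux
Reactant.set_default_backend("cpu")

const dev = reactant_device()

x = rand(Float32, 4, 16)
x_ra = x |> dev                                  # ConcreteRArray on the selected backend
ps_ra = (; weight=rand(Float32, 4, 4)) |> dev    # nested structures are moved recursively
```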
8 changes: 4 additions & 4 deletions docs/src/manual/compiling_lux_models.md
@@ -45,11 +45,11 @@ x = randn(Float32, 2, 32)
 y = x .^ 2
 ```

-We will use [`xla_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to
+We will use [`reactant_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to
 `Reactant`.

 ```@example compile_lux_model
-const xdev = xla_device()
+const xdev = reactant_device()
 x_ra = x |> xdev
 y_ra = y |> xdev
@@ -66,7 +66,7 @@ pred_lux, _ = model(x, ps, Lux.testmode(st))

 To run it using `XLA` we need to compile the model. We can do this using the
 `Reactant.@compile` macro. Note that the inputs need to be moved to the device using
-[`xla_device`](@ref) first.
+[`reactant_device`](@ref) first.

 ```@example compile_lux_model
 model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
@@ -122,7 +122,7 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme)
 Now that we saw the low-level API let's see how to train the model without any of this
 boilerplate. Simply follow the following steps:

-1. Create a device using `xla_device`. Remember to load `Reactant.jl` before doing this.
+1. Create a device using `reactant_device`. Remember to load `Reactant.jl` before doing this.
 2. Similar to other device functions move the model, parameters, states and data to the
    device. Note that you might want to use [`DeviceIterator`](@ref) to move the data
    loader to the device with an iterator.
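Condensing the steps listed above into code, a hedged sketch of the high-level training setup might look like the following (model, optimiser, and loss choices are illustrative; `AutoEnzyme`/`MSELoss` follow the manual's training workflow and may require extra packages such as ADTypes/Enzyme depending on the Lux version):

```julia
using Lux, Reactant, Optimisers, Random

# 1. Create the device (load Reactant.jl first).
const xdev = reactant_device()

# 2. Move the model state and the data to the device.
model = Chain(Dense(2 => 32, tanh), Dense(32 => 1))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x = randn(Float32, 2, 32)
y = x .^ 2
x_ra, y_ra = x |> xdev, y |> xdev

# 3. Train as usual; with Reactant loaded the step is compiled through XLA.
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x_ra, y_ra), train_state)
```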
4 changes: 2 additions & 2 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.3.6"
+version = "1.3.7"

 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -75,7 +75,7 @@ LinearAlgebra = "1.10"
 LoopVectorization = "0.12.171"
 LuxCore = "1"
 MKL = "0.7"
-MLDataDevices = "1.2"
+MLDataDevices = "1.5"
 Markdown = "1.10"
 NNlib = "0.9.24"
 Octavian = "0.3.28"
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/Impl.jl
@@ -19,7 +19,7 @@ using Random: Random, AbstractRNG, rand!
 using Statistics: Statistics, mean, var

 using LuxCore: LuxCore
-using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice,
+using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, ReactantDevice,
                      AbstractGPUDevice, AbstractDevice
 using NNlib: NNlib, ConvDims
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/conv.jl
@@ -74,7 +74,7 @@ end

 conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims)

-function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}},
+function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, ReactantDevice}},
         x′, weight′, cdims::ConvDims)
     x, weight = get_conv_input_weight(x′, weight′)
     return NNlib.conv(x, weight, cdims)
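The method above dispatches `conv` on the device type resolved from its inputs. A minimal illustration of that resolution step (simplified; not the actual LuxLib internals):

```julia
using MLDataDevices: get_device_type, CPUDevice

x = randn(Float32, 8, 8, 3, 4)   # image-like input (WHCN layout)
w = randn(Float32, 3, 3, 3, 16)  # convolution weights

# `get_device_type` on a tuple of arrays resolves their common device type;
# the `conv` method above uses it to pick the backend-specific code path.
get_device_type((x, w)) === CPUDevice   # plain Arrays resolve to the CPU device type
```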
4 changes: 2 additions & 2 deletions lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.4.2"
+version = "1.5.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -57,7 +57,7 @@ MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
 Random = "1.10"
-Reactant = "0.2"
+Reactant = "0.2.4"
 RecursiveArrayTools = "3.8"
 ReverseDiff = "1.15"
 SparseArrays = "1.10"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/README.md
@@ -10,7 +10,7 @@ Currently we provide support for the following backends:
 3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs.
 4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)**
 5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)**
-6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)**
+6. `ReactantDevice`: `Reactant.jl` for XLA Support. **(Experimental)**

 ## Updating to v1.0
4 changes: 2 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,10 +1,10 @@
 module MLDataDevicesMLUtilsExt

 using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                     MetalDevice, oneAPIDevice, XLADevice, DeviceIterator
+                     MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader

-for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice)
+for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
     @eval function (D::$(dev))(dataloader::DataLoader)
         if dataloader.parallel
             if dataloader.buffer
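With this extension, calling a device object on a `DataLoader` wraps it in a `DeviceIterator`, so batches are moved to the device as they are consumed. A sketch (data is illustrative; `reactant_device()` falls back to `CPUDevice` if Reactant is not functional):

```julia
using MLDataDevices, MLUtils, Reactant

dev = reactant_device()

X = rand(Float32, 2, 128)
Y = rand(Float32, 1, 128)
dl = DataLoader((X, Y); batchsize=16, shuffle=true)

for (x, y) in dev(dl)   # returns a DeviceIterator over the loader
    # each (x, y) batch has already been transferred to `dev` here
end
```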
41 changes: 32 additions & 9 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -1,26 +1,49 @@
 module MLDataDevicesReactantExt

 using Adapt: Adapt
-using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
-using Reactant: Reactant, RArray
+using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type
+using Reactant: Reactant, XLA, RArray, ConcreteRArray, ConcreteRNumber, TracedRArray,
+                TracedRNumber

-MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
-MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
+MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
+MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true

 # Default RNG: Forward to CPU, we will compile it
-function MLDataDevices.default_device_rng(::XLADevice)
+function MLDataDevices.default_device_rng(::ReactantDevice)
     return MLDataDevices.default_device_rng(CPUDevice())
 end

 # Query Device from Array
-Internal.get_device(::RArray) = XLADevice()
+function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray})
+    client = XLA.client(x.data)
+    device = XLA.device(x.data)
+    return ReactantDevice(client, device)
+end
+
+function Internal.get_device(::Union{TracedRArray, TracedRNumber})
+    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
+end

-Internal.get_device_type(::RArray) = XLADevice
+function Internal.get_device_type(
+        ::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber})
+    return ReactantDevice
+end

 # unsafe_free!
-Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing
+Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing

 # Device Transfer
-Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
+function Adapt.adapt_storage(
+        dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
+    @warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \
+           transfer this via CPU." maxlog=1
+    return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
+end
+
+function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive})
+    client = dev.client === missing ? XLA.default_backend[] : dev.client
+    device = dev.device === missing ? nothing : dev.device
+    return ConcreteRArray(x; client, device)
+end

 end
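A sketch of what the new extension methods enable: querying the owning client/device back from a Reactant array, and transferring onto that same device (the exact client/device handles depend on the local Reactant setup):

```julia
using MLDataDevices, Reactant

x = Reactant.to_rarray(rand(Float32, 4, 4))   # ConcreteRArray on the default client/device

dev = get_device(x)                  # ReactantDevice carrying x's client and device
dev isa ReactantDevice               # true

# Moving a plain Array with that device should place it on the same client/device,
# so the round-tripped device compares equal under the new `==` method.
y = dev(rand(Float32, 4, 4))
get_device(y) == dev
```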
5 changes: 3 additions & 2 deletions lib/MLDataDevices/src/MLDataDevices.jl
@@ -17,11 +17,12 @@ include("internal.jl")

 export gpu_backend!, supported_gpu_backends, reset_gpu_device!
 export default_device_rng
-export gpu_device, cpu_device, xla_device
+export gpu_device, cpu_device
+export xla_device, reactant_device

 export CPUDevice
 export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
-export XLADevice
+export XLADevice, ReactantDevice
 export get_device, get_device_type

 export DeviceIterator
8 changes: 4 additions & 4 deletions lib/MLDataDevices/src/internal.jl
@@ -5,7 +5,7 @@ using Preferences: load_preference
 using Random: AbstractRNG

 using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                       MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
+                       MetalDevice, oneAPIDevice, ReactantDevice, UnknownDevice,
                        supported_gpu_backends, GPU_DEVICES, loaded, functional

 for dev in (CPUDevice, MetalDevice, oneAPIDevice)
@@ -27,11 +27,11 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
         get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg)
     end
 end
-get_device_name(::XLADevice) = "XLA"
-get_triggerpkg_name(::XLADevice) = "Reactant"
+get_device_name(::ReactantDevice) = "XLA"
+get_triggerpkg_name(::ReactantDevice) = "Reactant"

 for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
-          MetalDevice, oneAPIDevice, XLADevice)
+          MetalDevice, oneAPIDevice, ReactantDevice)
     @eval get_device_id(::$(T)) = nothing
 end
54 changes: 41 additions & 13 deletions lib/MLDataDevices/src/public.jl
@@ -9,8 +9,29 @@ end
 struct MetalDevice <: AbstractGPUDevice end
 struct oneAPIDevice <: AbstractGPUDevice end

-# TODO: Later we might want to add the client field here?
-struct XLADevice <: AbstractAcceleratorDevice end
+@kwdef struct ReactantDevice{C, D} <: AbstractAcceleratorDevice
+    client::C = missing
+    device::D = missing
+end
+
+function Base.:(==)(x::ReactantDevice, y::ReactantDevice)
+    if x.client !== missing
+        y.client === missing && return false
+        x.client.client != y.client.client && return false
+    else
+        y.client !== missing && return false
+    end
+    if x.device !== missing
+        y.device === missing && return false
+        x.device.device != y.device.device && return false
+    else
+        y.device !== missing && return false
+    end
+    return true
+end
+
+# XXX: Deprecate in v2
+const XLADevice = ReactantDevice

 # Fallback for when we don't know the device type
 struct UnknownDevice <: AbstractDevice end
@@ -189,27 +210,34 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
 cpu_device() = CPUDevice()

 """
-    xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
+    reactant_device(;
+        force::Bool=false, client=missing, device=missing
+    ) -> Union{ReactantDevice, CPUDevice}

-Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
+Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is `true`.
 Falls back to `CPUDevice` if `force` is `false`.

+`client` and `device` are used to specify the client and particular device to use. If not
+specified, then the default client and index are used.
+
 !!! danger

     This is an experimental feature and might change without deprecations
 """
-function xla_device(; force::Bool=false)
-    msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
-           this function. Defaulting to CPU."
-    if loaded(XLADevice)
-        functional(XLADevice) && return XLADevice()
-        msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
+function reactant_device(; force::Bool=false, client=missing, device=missing)
+    msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
+           calling this function. Defaulting to CPU."
+    if loaded(ReactantDevice)
+        functional(ReactantDevice) && return ReactantDevice(client, device)
+        msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU."
     end
     force && throw(Internal.DeviceSelectionException("XLA"))
     @warn msg maxlog=1
     return cpu_device()
 end

+Base.@deprecate xla_device(; kwargs...) reactant_device(; kwargs...)
+
 """
     default_device_rng(::AbstractDevice)
@@ -312,8 +340,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice}
         @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
     T === CPUDevice &&
         @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
-    T === XLADevice &&
-        @warn "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting."
+    T === ReactantDevice &&
+        @warn "Setting device for `ReactantDevice` hasn't been implemented yet. Ignoring the device setting."
     return
 end

@@ -366,7 +394,7 @@ end
 Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
 Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng

-for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
+for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
     @eval begin
         function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG)
             return default_device_rng(to)