-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support passing in device and client to XLA (#1020)
* feat: support passing in device and client to XLA * feat: add == dispatch for XLADevice * refactor: rename XLADevice/xla_device to ReactantDevice/reactant_device * refactor: restrict Reactant to 0.2.4
- Loading branch information
Showing
19 changed files
with
140 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "LuxLib" | ||
uuid = "82251201-b29d-42c6-8e01-566dec8acb11" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "1.3.6" | ||
version = "1.3.7" | ||
|
||
[deps] | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
|
@@ -75,7 +75,7 @@ LinearAlgebra = "1.10" | |
LoopVectorization = "0.12.171" | ||
LuxCore = "1" | ||
MKL = "0.7" | ||
MLDataDevices = "1.2" | ||
MLDataDevices = "1.5" | ||
Markdown = "1.10" | ||
NNlib = "0.9.24" | ||
Octavian = "0.3.28" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "MLDataDevices" | ||
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "1.4.2" | ||
version = "1.5.0" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
|
@@ -57,7 +57,7 @@ MLUtils = "0.4.4" | |
Metal = "1" | ||
Preferences = "1.4" | ||
Random = "1.10" | ||
Reactant = "0.2" | ||
Reactant = "0.2.4" | ||
RecursiveArrayTools = "3.8" | ||
ReverseDiff = "1.15" | ||
SparseArrays = "1.10" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,49 @@ | ||
module MLDataDevicesReactantExt | ||
|
||
using Adapt: Adapt | ||
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice | ||
using Reactant: Reactant, RArray | ||
using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type | ||
using Reactant: Reactant, XLA, RArray, ConcreteRArray, ConcreteRNumber, TracedRArray, | ||
TracedRNumber | ||
|
||
MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true | ||
MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true | ||
MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true | ||
MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true | ||
|
||
# Default RNG: Forward to CPU, we will compile it | ||
function MLDataDevices.default_device_rng(::XLADevice) | ||
function MLDataDevices.default_device_rng(::ReactantDevice) | ||
return MLDataDevices.default_device_rng(CPUDevice()) | ||
end | ||
|
||
# Query Device from Array | ||
Internal.get_device(::RArray) = XLADevice() | ||
function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray}) | ||
client = XLA.client(x.data) | ||
device = XLA.device(x.data) | ||
return ReactantDevice(client, device) | ||
end | ||
|
||
function Internal.get_device(::Union{TracedRArray, TracedRNumber}) | ||
error("`get_device` isn't meant to be called inside `Reactant.@compile` context.") | ||
end | ||
|
||
Internal.get_device_type(::RArray) = XLADevice | ||
function Internal.get_device_type( | ||
::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber}) | ||
return ReactantDevice | ||
end | ||
|
||
# unsafe_free! | ||
Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing | ||
Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing | ||
|
||
# Device Transfer | ||
Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x) | ||
function Adapt.adapt_storage( | ||
dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) | ||
@warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \ | ||
transfer this via CPU." maxlog=1 | ||
return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x)) | ||
end | ||
|
||
function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive}) | ||
client = dev.client === missing ? XLA.default_backend[] : dev.client | ||
device = dev.device === missing ? nothing : dev.device | ||
return ConcreteRArray(x; client, device) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.