diff --git a/Project.toml b/Project.toml index e2843f3241..3b08f972d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.13" +version = "0.5.14" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -33,6 +33,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -44,6 +45,7 @@ LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] LuxFluxTransformExt = "Flux" LuxLuxAMDGPUExt = "LuxAMDGPU" LuxLuxCUDAExt = "LuxCUDA" +LuxMetalExt = "Metal" LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" @@ -66,6 +68,7 @@ LuxDeviceUtils = "0.1" LuxLib = "0.3" MacroTools = "0.5" Markdown = "<0.0.1, 1" +Metal = "0.5" Optimisers = "0.2, 0.3" PackageExtensionCompat = "1" Random = "<0.0.1, 1" @@ -87,6 +90,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/LuxMetalExt.jl b/ext/LuxMetalExt.jl new file mode 100644 index 0000000000..fb97682378 --- /dev/null +++ b/ext/LuxMetalExt.jl @@ -0,0 +1,23 @@ +module LuxMetalExt + +using Lux, LuxLib, Metal, Random + +@inline function Lux._init_hidden_state(rng::AbstractRNG, rnn, x::MtlArray) + return MtlArray(rnn.init_state(rng, rnn.out_dims, size(x, 2))) +end + +@inline function Lux._conv(x::SubArray{T, N, <:MtlArray}, weight, cdims) where {T, N} + return conv(copy(x), weight, cdims) +end + +@inline function Lux._conv_transpose(x::SubArray{T, N, <:MtlArray}, weight, + cdims) where {T, N} + return ∇conv_data(copy(x), weight, cdims) +end + +@inline function Lux._eachslice(x::MtlArray, ::Val{dims}) where {dims} + # FIXME: This is not efficient but Metal doesn't deal with views well + return [copy(selectdim(x, dims, i)) for i in axes(x, dims)] +end + +end