Skip to content

Commit

Permalink
Lazy install and load accelerator packages
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 17, 2024
1 parent 563d193 commit ead9773
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 26 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand All @@ -143,4 +141,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ReTestItems
using ReTestItems, Pkg, Tests

const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
@info "Running tests for group: $LUX_TEST_GROUP"
Expand All @@ -11,8 +11,6 @@ else
end

# Distributed Tests
using Pkg, Test

if LUX_TEST_GROUP == "all" || LUX_TEST_GROUP == "distributed"
Pkg.add(["MPI", "NCCL"])
using MPI
Expand Down
24 changes: 15 additions & 9 deletions test/setup_modes.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
using Lux, LuxCUDA, LuxAMDGPU
using Lux, GPUArraysCore, Pkg

CUDA.allowscalar(false)
GPUArraysCore.allowscalar(false)

const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All")

if BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA"
Pkg.add("LuxCUDA")
using LuxCUDA
end

if BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU"
Pkg.add("LuxAMDGPU")
using LuxAMDGPU
end

cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU"
cuda_testing() = (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && LuxCUDA.functional()
function amdgpu_testing()
Expand All @@ -12,14 +22,10 @@ end

const MODES = begin
# Mode, Array Type, Device Function, GPU?
cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
amdgpu_mode = ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true)

modes = []
cpu_testing() && push!(modes, cpu_mode)
cuda_testing() && push!(modes, cuda_mode)
amdgpu_testing() && push!(modes, amdgpu_mode)
cpu_testing() && push!(modes, ("CPU", Array, LuxCPUDevice(), false))
cuda_testing() && push!(modes, ("CUDA", CuArray, LuxCUDADevice(), true))
amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true))

modes
end
12 changes: 1 addition & 11 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,8 @@ get_stable_rng(seed=12345) = StableRNG(seed)

__display(args...) = (println(); display(args...))

# AMDGPU Specifics
function _rocRAND_functional()
try
get_default_rng("AMDGPU")
return true
catch
return false
end
end

export @jet, @test_gradients, check_approx
export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
get_stable_rng, __display, _rocRAND_functional
get_stable_rng, __display

end

0 comments on commit ead9773

Please sign in to comment.