diff --git a/Project.toml b/Project.toml index 4a9d38f0..2e5024e6 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,6 @@ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -64,7 +63,7 @@ Pkg = "1.10" Preferences = "1.4" Random = "1.10" ReTestItems = "1.23.1" -Reexport = "1" +Reexport = "1.2.2" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" StableRNGs = "1" @@ -90,6 +89,7 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -98,4 +98,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/src/LuxLib.jl b/src/LuxLib.jl index 3a118263..423b312a 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -14,14 +14,11 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, AbstractGPUDevice, AbstractDevice using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Random: Random, AbstractRNG, rand! -using Reexport: @reexport using StaticArraysCore: StaticArraysCore, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce -@reexport using NNlib - const CRC = ChainRulesCore const KA = KernelAbstractions diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl index 3ee54836..7e15c7f2 100644 --- a/test/common_ops/dense_tests.jl +++ b/test/common_ops/dense_tests.jl @@ -118,7 +118,7 @@ end end @testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays + using StaticArrays, NNlib x = @SArray rand(2, 4) weight = @SArray rand(3, 2) diff --git a/test/others/qa_tests.jl b/test/others/qa_tests.jl index b00fa347..11c6d5a4 100644 --- a/test/others/qa_tests.jl +++ b/test/others/qa_tests.jl @@ -1,5 +1,5 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore, EnzymeCore + using Aqua, ChainRulesCore, EnzymeCore, NNlib using EnzymeCore: EnzymeRules Aqua.test_all(LuxLib; ambiguities=false, piracies=false) diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index c0486ac6..6c869371 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -2,7 +2,7 @@ import Reexport: @reexport using LuxLib, MLDataDevices, DispatchDoctor -@reexport using LuxTestUtils, StableRNGs, Test, Zygote, Enzyme +@reexport using LuxTestUtils, StableRNGs, Test, Zygote, Enzyme, NNlib import LuxTestUtils: @jet, @test_gradients, check_approx LuxTestUtils.jet_target_modules!(["LuxLib"])