diff --git a/test/Project.toml b/test/Project.toml index 7f90c1c1b27..797d4c384ef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,7 +9,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 00000000000..29887644722 --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,61 @@ +using MPI +using Enzyme +using Test + +struct Context + x::Vector{Float64} +end + +function halo(context) + x = context.x + np = MPI.Comm_size(MPI.COMM_WORLD) + rank = MPI.Comm_rank(MPI.COMM_WORLD) + requests = Vector{MPI.Request}() + if rank != 0 + buf = @view x[1:1] + push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0)) + push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0)) + end + if rank != np-1 + buf = @view x[end:end] + push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0)) + push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0)) + end + for request in requests + MPI.Wait(request) + end + return nothing +end + +MPI.Init() +np = MPI.Comm_size(MPI.COMM_WORLD) +rank = MPI.Comm_rank(MPI.COMM_WORLD) +n = np*10 +n1 = Int(round(rank / np * (n+np))) - rank +n2 = Int(round((rank + 1) / np * (n+np))) - rank +nl = rank == 0 ? n1+1 : n1 +nr = rank == np-1 ? n2-1 : n2 +nlocal = nr-nl+1 +context = Context(zeros(nlocal)) +fill!(context.x, Float64(rank)) +halo(context) +if rank != 0 + @test context.x[1] == Float64(rank-1) +end +if rank != np-1 + @test context.x[end] == Float64(rank+1) +end + +dcontext = Context(zeros(nlocal)) +fill!(dcontext.x, Float64(rank)) +autodiff(Reverse, halo, Duplicated(context, dcontext)) +MPI.Barrier(MPI.COMM_WORLD) +if rank != 0 + @test dcontext.x[2] == Float64(rank + rank - 1) +end +if rank != np-1 + @test dcontext.x[end-1] == Float64(rank + rank + 1) +end +if !isinteractive() + MPI.Finalize() +end diff --git a/test/runtests.jl b/test/runtests.jl index a796880650b..4f0b2f03e4e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,7 @@ using ForwardDiff using Aqua using Statistics using LinearAlgebra -using InlineStrings +using MPI using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme @@ -236,7 +236,7 @@ make3() = (1.0, 2.0, 3.0) test_scalar(x->rem(x, 1), 0.7) test_scalar(x->rem2pi(x,RoundDown), 0.7) test_scalar(x->fma(x,x+1,x/3), 2.3) - + @test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] ≈ cos(1.0) @test autodiff(Reverse, (x)->log(x), Active(2.0)) == ((0.5,),) @@ -588,7 +588,7 @@ end bias = Float32[0.0;;;] res = Enzyme.autodiff(Reverse, f, Active, Active(x[1]), Const(bias)) - + @test bias[1][1] ≈ 0.0 @test res[1][1] ≈ cos(x[1]) end @@ -931,7 +931,7 @@ end @inline function myquantile(v::AbstractVector, p::Real; alpha) n = length(v) - + m = 1.0 + p * (1.0 - alpha - 1.0) aleph = n*p + oftype(p, m) j = clamp(trunc(Int, aleph), 1, n-1) @@ -944,7 +944,7 @@ end a = @inbounds v[j] b = @inbounds v[j + 1] end - + return a + γ*(b-a) end @@ -1166,18 +1166,18 @@ end @test 1.0 ≈ Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1] function whocallsmorethan30args(R) - temp = diag(R) - R_inv = [temp[1] 0. 0. 0. 0. 0.; - 0. temp[2] 0. 0. 0. 0.; - 0. 0. temp[3] 0. 0. 0.; - 0. 0. 0. temp[4] 0. 0.; - 0. 0. 0. 0. temp[5] 0.; + temp = diag(R) + R_inv = [temp[1] 0. 0. 0. 0. 0.; + 0. temp[2] 0. 0. 0. 0.; + 0. 0. temp[3] 0. 0. 0.; + 0. 0. 0. temp[4] 0. 0.; + 0. 0. 0. 0. temp[5] 0.; ] - + return sum(R_inv) end - - R = zeros(6,6) + + R = zeros(6,6) dR = zeros(6, 6) autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR)) @@ -1845,7 +1845,7 @@ end end # TODO: Add test for NoShadowException end - + function indirectfltret(a)::DataType a[] *= 2 return Float64 @@ -2313,7 +2313,7 @@ end Enzyme.API.runtimeActivity!(false) @test res[1] ≈ 0.2 # broken as the return of an apply generic is {primal, primal} - # but since the return is abstractfloat doing the + # but since the return is abstractfloat doing the @static if VERSION ≥ v"1.9-" && !(VERSION ≥ v"1.10-" ) @test_broken res[2] ≈ 1.0 else @@ -2383,6 +2383,16 @@ end ) @test ad_eta[1] ≈ 0.0 end +@testset "MPI" begin + testdir = @__DIR__ + # Test parsing + include("mpi.jl") + mpiexec() do cmd + run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`) + end + @test true +end + @testset "Tape Width" begin struct Roo @@ -2452,10 +2462,10 @@ end Duplicated(inters, dinters), ) - @test dinters[1].k ≈ 0.1 - @test dinters[1].t0 ≈ 1.0 - @test dinters[2].k ≈ 0.3 - @test dinters[2].t0 ≈ 2.0 + @test dinters[1].k ≈ 0.1 + @test dinters[1].t0 ≈ 1.0 + @test dinters[2].k ≈ 0.3 + @test dinters[2].t0 ≈ 2.0 end @testset "Statistics" begin @@ -2524,7 +2534,7 @@ end y = A \ b @test dA ≈ (-z * transpose(y)) @test db ≈ z - + db = zero(b) forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) @@ -2540,7 +2550,7 @@ end y = A \ b @test db ≈ z - + dA = zero(A) forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})