Skip to content

Commit

Permalink
Adding MPI test
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 authored and wsmoses committed Nov 20, 2023
1 parent 028134f commit 599abf9
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 22 deletions.
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
61 changes: 61 additions & 0 deletions test/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using MPI
using Enzyme
using Test

struct Context
x::Vector{Float64}
end

function halo(context)
x = context.x
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
requests = Vector{MPI.Request}()
if rank != 0
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
end
if rank != np-1
buf = @view x[end:end]
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
end
for request in requests
MPI.Wait(request)
end
return nothing
end

MPI.Init()
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
n = np*10
n1 = Int(round(rank / np * (n+np))) - rank
n2 = Int(round((rank + 1) / np * (n+np))) - rank
nl = rank == 0 ? n1+1 : n1
nr = rank == np-1 ? n2-1 : n2
nlocal = nr-nl+1
context = Context(zeros(nlocal))
fill!(context.x, Float64(rank))
halo(context)
if rank != 0
@test context.x[1] == Float64(rank-1)
end
if rank != np-1
@test context.x[end] == Float64(rank+1)
end

dcontext = Context(zeros(nlocal))
fill!(dcontext.x, Float64(rank))
autodiff(Reverse, halo, Duplicated(context, dcontext))
MPI.Barrier(MPI.COMM_WORLD)
if rank != 0
@test dcontext.x[2] == Float64(rank + rank - 1)
end
if rank != np-1
@test dcontext.x[end-1] == Float64(rank + rank + 1)
end
if !isinteractive()
MPI.Finalize()
end
53 changes: 32 additions & 21 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using Aqua
using Statistics
using LinearAlgebra
using InlineStrings
using MPI

using Enzyme_jll
@info "Testing against" Enzyme_jll.libEnzyme
Expand Down Expand Up @@ -236,7 +237,7 @@ make3() = (1.0, 2.0, 3.0)
test_scalar(x->rem(x, 1), 0.7)
test_scalar(x->rem2pi(x,RoundDown), 0.7)
test_scalar(x->fma(x,x+1,x/3), 2.3)

@test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] cos(1.0)

@test autodiff(Reverse, (x)->log(x), Active(2.0)) == ((0.5,),)
Expand Down Expand Up @@ -588,7 +589,7 @@ end

bias = Float32[0.0;;;]
res = Enzyme.autodiff(Reverse, f, Active, Active(x[1]), Const(bias))

@test bias[1][1] 0.0
@test res[1][1] cos(x[1])
end
Expand Down Expand Up @@ -931,7 +932,7 @@ end

@inline function myquantile(v::AbstractVector, p::Real; alpha)
n = length(v)

m = 1.0 + p * (1.0 - alpha - 1.0)
aleph = n*p + oftype(p, m)
j = clamp(trunc(Int, aleph), 1, n-1)
Expand All @@ -944,7 +945,7 @@ end
a = @inbounds v[j]
b = @inbounds v[j + 1]
end

return a + γ*(b-a)
end

Expand Down Expand Up @@ -1166,18 +1167,18 @@ end
@test 1.0 Enzyme.autodiff(Forward, inactive_gen, Duplicated(1E4, 1.0))[1]

function whocallsmorethan30args(R)
temp = diag(R)
R_inv = [temp[1] 0. 0. 0. 0. 0.;
0. temp[2] 0. 0. 0. 0.;
0. 0. temp[3] 0. 0. 0.;
0. 0. 0. temp[4] 0. 0.;
0. 0. 0. 0. temp[5] 0.;
temp = diag(R)
R_inv = [temp[1] 0. 0. 0. 0. 0.;
0. temp[2] 0. 0. 0. 0.;
0. 0. temp[3] 0. 0. 0.;
0. 0. 0. temp[4] 0. 0.;
0. 0. 0. 0. temp[5] 0.;
]

return sum(R_inv)
end
R = zeros(6,6)

R = zeros(6,6)
dR = zeros(6, 6)
autodiff(Reverse, whocallsmorethan30args, Active, Duplicated(R, dR))

Expand Down Expand Up @@ -1880,7 +1881,7 @@ end
end
# TODO: Add test for NoShadowException
end

function indirectfltret(a)::DataType
a[] *= 2
return Float64
Expand Down Expand Up @@ -2386,7 +2387,7 @@ end
Enzyme.API.runtimeActivity!(false)
@test res[1] 0.2
# broken as the return of an apply generic is {primal, primal}
# but since the return is abstractfloat doing the
# but since the return is abstractfloat doing the
@static if VERSION v"1.9-" && !(VERSION v"1.10-" )
@test_broken res[2] 1.0
else
Expand Down Expand Up @@ -2456,6 +2457,16 @@ end
)
@test ad_eta[1] 0.0
end
@testset "MPI" begin
testdir = @__DIR__
# Test parsing
include("mpi.jl")
mpiexec() do cmd
run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)
end
@test true
end


@testset "Tape Width" begin
struct Roo
Expand Down Expand Up @@ -2525,10 +2536,10 @@ end
Duplicated(inters, dinters),
)

@test dinters[1].k 0.1
@test dinters[1].t0 1.0
@test dinters[2].k 0.3
@test dinters[2].t0 2.0
@test dinters[1].k 0.1
@test dinters[1].t0 1.0
@test dinters[2].k 0.3
@test dinters[2].t0 2.0
end

@testset "Statistics" begin
Expand Down Expand Up @@ -2597,7 +2608,7 @@ end
y = A \ b
@test dA (-z * transpose(y))
@test db z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})
Expand All @@ -2613,7 +2624,7 @@ end

y = A \ b
@test db z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})
Expand Down

0 comments on commit 599abf9

Please sign in to comment.