Skip to content

Commit

Permalink
Adding MPI test
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Nov 18, 2022
1 parent 5607194 commit b46613f
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 44 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
61 changes: 61 additions & 0 deletions test/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using MPI
using Enzyme
using Test

struct Context
x::Vector{Float64}
end

function halo(context)
x = context.x
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
requests = Vector{MPI.Request}()
if rank != 0
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
end
if rank != np-1
buf = @view x[end:end]
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
end
for request in requests
MPI.Wait(request)
end
return nothing
end

MPI.Init()
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
n = np*10
n1 = Int(round(rank / np * (n+np))) - rank
n2 = Int(round((rank + 1) / np * (n+np))) - rank
nl = rank == 0 ? n1+1 : n1
nr = rank == np-1 ? n2-1 : n2
nlocal = nr-nl+1
context = Context(zeros(nlocal))
fill!(context.x, Float64(rank))
halo(context)
if rank != 0
@test context.x[1] == Float64(rank-1)
end
if rank != np-1
@test context.x[end] == Float64(rank+1)
end

dcontext = Context(zeros(nlocal))
fill!(dcontext.x, Float64(rank))
autodiff(halo, Duplicated(context, dcontext))
MPI.Barrier(MPI.COMM_WORLD)
if rank != 0
@test dcontext.x[2] == Float64(rank + rank - 1)
end
if rank != np-1
@test dcontext.x[end-1] == Float64(rank + rank + 1)
end
if !isinteractive()
MPI.Finalize()
end
Loading

0 comments on commit b46613f

Please sign in to comment.