Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding MPI test
Browse files Browse the repository at this point in the history
michel2323 committed Jan 31, 2023
1 parent a556f2e commit 5e5f560
Showing 3 changed files with 118 additions and 44 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
61 changes: 61 additions & 0 deletions test/mpi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using MPI
using Enzyme
using Test

struct Context
x::Vector{Float64}
end

function halo(context)
x = context.x
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
requests = Vector{MPI.Request}()
if rank != 0
buf = @view x[1:1]
push!(requests, MPI.Isend(x[2:2], MPI.COMM_WORLD; dest=rank-1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank-1, tag=0))
end
if rank != np-1
buf = @view x[end:end]
push!(requests, MPI.Isend(x[end-1:end-1], MPI.COMM_WORLD; dest=rank+1, tag=0))
push!(requests, MPI.Irecv!(buf, MPI.COMM_WORLD; source=rank+1, tag=0))
end
for request in requests
MPI.Wait(request)
end
return nothing
end

MPI.Init()
np = MPI.Comm_size(MPI.COMM_WORLD)
rank = MPI.Comm_rank(MPI.COMM_WORLD)
n = np*10
n1 = Int(round(rank / np * (n+np))) - rank
n2 = Int(round((rank + 1) / np * (n+np))) - rank
nl = rank == 0 ? n1+1 : n1
nr = rank == np-1 ? n2-1 : n2
nlocal = nr-nl+1
context = Context(zeros(nlocal))
fill!(context.x, Float64(rank))
halo(context)
if rank != 0
@test context.x[1] == Float64(rank-1)
end
if rank != np-1
@test context.x[end] == Float64(rank+1)
end

dcontext = Context(zeros(nlocal))
fill!(dcontext.x, Float64(rank))
autodiff(halo, Duplicated(context, dcontext))
MPI.Barrier(MPI.COMM_WORLD)
if rank != 0
@test dcontext.x[2] == Float64(rank + rank - 1)
end
if rank != np-1
@test dcontext.x[end-1] == Float64(rank + rank + 1)
end
if !isinteractive()
MPI.Finalize()
end
100 changes: 56 additions & 44 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -14,9 +14,11 @@ using FiniteDifferences
using ForwardDiff
using Statistics
using LinearAlgebra
using MPI

using Enzyme_jll
@info "Testing against" Enzyme_jll.libEnzyme
@testset "Testing Enzyme.jl" begin

# Test against FiniteDifferences
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
@@ -25,8 +27,8 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs..
else
@test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...)
end
rm = ∂x

rm = ∂x
if typeof(x) <: Integer
x = Float64(x)
end
@@ -66,19 +68,19 @@ f0(x) = 1.0 + x

@test forward(Active(2.0)) == (nothing,)
@test pullback(Active(2.0), 1.0, nothing) == (1.0,)

function mul2(x)
x[1] * x[2]
end
d = Duplicated([3.0, 5.0], [0.0, 0.0])

forward, pullback = Enzyme.Compiler.thunk(mul2, nothing, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1))
res = forward(d)
@test typeof(res[1]) == NamedTuple{(Symbol("1"), Symbol("2")), Tuple{Float64, Float64}}
pullback(d, 1.0, res[1])
@test d.dval[1] 5.0
@test d.dval[2] 3.0
@test d.dval[2] 3.0

d = Duplicated([3.0, 5.0], [0.0, 0.0])
forward, pullback = Enzyme.Compiler.thunk(vrec, nothing, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1))
res = forward(Const(Int(1)), d)
@@ -225,8 +227,8 @@ end
autodiff(Reverse, arsum, Active, Duplicated(inp, dinp))
@test inp Float64[1.0, 2.0]
@test dinp Float64[1.0, 1.0]
@test autodiff(Forward, arsum, Duplicated(inp, dinp))[1] 2.0

@test autodiff(Forward, arsum, Duplicated(inp, dinp))[1] 2.0
end

@testset "Advanced array tests" begin
@@ -238,7 +240,7 @@ end
autodiff(Reverse, arsum2, Active, Duplicated(inp, dinp))
@test inp Float64[1.0, 2.0]
@test dinp Float64[1.0, 1.0]

@test autodiff(Forward, arsum2, Duplicated(inp, dinp))[1] 2.0
end

@@ -255,12 +257,12 @@ end
@test autodiff(Reverse, f_dict, Duplicated(params, dparams), Active(5.0)) == (10.0,)
@test dparams[:var] == 5.0


mutable struct MD
v::Float64
d::Dict{Symbol, MD}
end

# TODO without Float64 on return
# there is a potential phi bug
function sum_rec(d::Dict{Symbol,MD})::Float64
@@ -275,20 +277,20 @@ end
par = Dict{Symbol, MD}()
par[:var] = MD(10.0, Dict{Symbol, MD}())
par[:sub] = MD(2.0, Dict{Symbol, MD}(:a=>MD(3.0, Dict{Symbol, MD}())))

dpar = Dict{Symbol, MD}()
dpar[:var] = MD(0.0, Dict{Symbol, MD}())
dpar[:sub] = MD(0.0, Dict{Symbol, MD}(:a=>MD(0.0, Dict{Symbol, MD}())))

# TODO
# autodiff(Reverse, sum_rec, Duplicated(par, dpar))
# @show par, dpar, sum_rec(par)
# @test dpar[:var].v ≈ 1.0
# @test dpar[:sub].v ≈ 1.0
# @test dpar[:sub].d[:a].v ≈ 1.0
# @test dpar[:var].v ≈ 1.0
# @test dpar[:sub].v ≈ 1.0
# @test dpar[:sub].d[:a].v ≈ 1.0
end

let
let
function loadsin(xp)
x = @inbounds xp[1]
@inbounds xp[1] = 0.0
@@ -385,7 +387,7 @@ end
end
return mean(a)
end

@test Enzyme.autodiff(Reverse, gc_copy, Active, Active(5.0))[1] 10
@test Enzyme.autodiff(Forward, gc_copy, Duplicated(5.0, 1.0))[1] 10
end
@@ -856,7 +858,7 @@ end
dx = [1.0]

Enzyme.autodiff(Reverse, invtest, Duplicated(x, dx))

@test 10.0 x[1]
@test 5.0 dx[1]
end
@@ -906,7 +908,7 @@ end
out[] = x*x
nothing
end

out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(10.0)
@@ -973,7 +975,7 @@ end
for i in 1:10
@test 1.0 fo[i]
end

@test_throws ErrorException autodiff(Forward, x->x, Active(2.1))
end

@@ -1001,12 +1003,12 @@ end
shadow_a_in = shadow_a_out

autodiff(Reverse, f!, Const, Duplicated(a_out, shadow_a_out), Duplicated(a_in, shadow_a_in))

@test shadow_a_in Float64[0.0, 1.0, 1.0, 2.0]
@test shadow_a_out Float64[0.0, 1.0, 1.0, 2.0]

autodiff(Forward, f!, Const, Duplicated(a_out, shadow_a_out), Duplicated(a_in, shadow_a_in))

@test shadow_a_in Float64[1.0, 1.0, 2.0, 2.0]
@test shadow_a_out Float64[1.0, 1.0, 2.0, 2.0]
end
@@ -1020,7 +1022,7 @@ end
end
@test 1.0 autodiff(Reverse, f_undef, false, Active(2.14))[1]
@test_throws Base.UndefVarError autodiff(Reverse, f_undef, true, Active(2.14))

@test 1.0 autodiff(Forward, f_undef, false, Duplicated(2.14, 1.0))[1]
@test_throws Base.UndefVarError autodiff(Forward, f_undef, true, Duplicated(2.14, 1.0))
end
@@ -1038,7 +1040,7 @@ end

@test 0.0 autodiff(Reverse, tobedifferentiated, true, Active(2.1))[1]
@test 0.0 autodiff(Forward, tobedifferentiated, true, Duplicated(2.1, 1.0))[1]

function tobedifferentiated2(cond, a)::Float64
if cond
a + t
@@ -1141,7 +1143,7 @@ end
if i == 1
continue
end
if knots[i] == last_knot
if knots[i] == last_knot
@warn knots[i]
@inbounds knots[i] *= knots[i]
else
@@ -1186,9 +1188,9 @@ end
@inbounds F2[1] * F2[2]
end
autodiff(Reverse, copytest, Duplicated(F, dF))
@test F [1.234, 5.678]
@test F [1.234, 5.678]
@test dF [3.0, 2.0]

@test 31.0 autodiff(Forward, copytest, Duplicated([2.0, 3.0], [7.0, 5.0]))[1]
end

@@ -1287,7 +1289,7 @@ end

GC.@preserve x y dx dy begin
autodiff(foo,
Duplicated(Base.unsafe_convert(Ptr{Cvoid}, x), Base.unsafe_convert(Ptr{Cvoid}, dx)),
Duplicated(Base.unsafe_convert(Ptr{Cvoid}, x), Base.unsafe_convert(Ptr{Cvoid}, dx)),
Duplicated(Base.unsafe_convert(Ptr{Cvoid}, y), Base.unsafe_convert(Ptr{Cvoid}, dy)))
end
end
@@ -1303,7 +1305,7 @@ end
# x = x::Float64
# 2 * x
# end

# function gf2(v::MyType, fld, fld2)
# x = getfield(v, fld)
# y = getfield(v, fld2)
@@ -1316,14 +1318,14 @@ end
# Enzyme.autodiff(gf, Active, Duplicated(x, dx), Const(:x))
# @test x.x ≈ 3.0
# @test dx.x ≈ 2.0

# x = MyType(3.0)
# dx = MyType(0.0)

# Enzyme.autodiff(gf2, Active, Duplicated(x, dx), Const(:x), Const(:x))
# @test x.x ≈ 3.0
# @test dx.x ≈ 2.0
#
#
# x = MyType(3.0)
# dx = MyType(0.0)
# dx2 = MyType(0.0)
@@ -1365,7 +1367,7 @@ end
@show x, dx, y, dy
@test dx [5.2, 7.3]
@test dy [2.5, 3.7]

f_exc(x) = sum(x*x)
y = [[1.0, 2.0] [3.0,4.0]]
f_x = zero.(y)
@@ -1442,7 +1444,7 @@ end
end
y, = Enzyme.autodiff(double_push,Active(1.0))
@test y == 1.0

function aloss(a, arr)
for i in 1:2500
push!(arr, a)
@@ -1470,7 +1472,7 @@ end
@test bres[1][1] 6.0
@test bres[1][2] 12.0
@test bres[1][3] 18.0

bres = autodiff(Forward, square, BatchDuplicatedNoNeed, BatchDuplicated(3.0 + 7.0im, (1.0+0im, 2.0+0im, 3.0+0im)))
@test bres[1][1] 6.0 + 14.0im
@test bres[1][2] 12.0 + 28.0im
@@ -1528,7 +1530,7 @@ end
[v[2], v[1]*v[1], v[1]*v[1]*v[1]]
end

jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1))
jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
@@ -1543,7 +1545,7 @@ end
@test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])
@test jac == ForwardDiff.jacobian(inout, [2.0, 3.0])

jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2))
jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
@@ -1573,7 +1575,7 @@ end
J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, Val(5))
J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, Val(5))
J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, Val(5))

J_f_1(A, x) = Enzyme.jacobian(Forward, θ -> f_test_1(A, θ), x)
J_f_2(A, x) = Enzyme.jacobian(Forward, θ -> f_test_2(A, θ), x)
J_f_3(u, A, x) = Enzyme.jacobian(Forward, θ -> f_test_3!(u, A, θ), x)
@@ -1597,7 +1599,7 @@ end
1.0 0.0 0.0 0.0 1.0 0.0;
1.0 0.0 0.0 0.0 0.0 1.0;
]

# Function fails verification in test/CI
# @test J_f_1(A, x) == [
# 1.0 1.0 0.0 0.0 0.0 0.0;
@@ -1680,7 +1682,7 @@ end

autodiff(rs, Duplicated(data, ddata))
@test ddata [3.0, 5.0, 2.0, 2.0]

data = Float64[1.,2.,3.,4.]
ddata = ones(4)
autodiff(Forward, rs, Duplicated(data, ddata))
@@ -1716,19 +1718,19 @@ end
end

dw = Enzyme.autodiff(loss, Active, Active(1.0), Const(x), Const(false))

@test x [3.0]
@test dw[1] 3.0

c = ones(3)
inner(e) = c .+ e
fres = Enzyme.autodiff(Enzyme.Forward, inner, Duplicated{Vector{Float64}}, Duplicated([0., 0., 0.], [1., 1., 1.]))[1]
@test c [1.0, 1.0, 1.0]
@test fres [1.0, 1.0, 1.0]
@test c [1.0, 1.0, 1.0]
@test fres [1.0, 1.0, 1.0]
end

@testset "Large dynamic tape" begin

function ldynloss(X, Y, ps, bs)
ll = 0.0f0
for (x, y) in zip(X, Y)
@@ -1751,3 +1753,13 @@ end
end

end
@testset "MPI" begin
testdir = @__DIR__
# Test parsing
include("mpi.jl")
mpiexec() do cmd
run(`$cmd -n 2 $(Base.julia_cmd()) --project=$testdir $testdir/mpi.jl`)
end
@test true
end
end # Enzyme.jl testset

0 comments on commit 5e5f560

Please sign in to comment.