Tweak Distributed API (#3314)
* Remove enable_overlapping from distributed kwarg

* Import Distributed after Solvers

* Don't switch import yet

* Implement default rank distribution

* Update validation scripts

* Update for new API

* More cleanup of distributed examples

* Restore 2D partitioning to hydrostatic turbulence example

* Implement Partition for Distributed architectures, messy state

* Fix issue with distributed tests for new API

* Add a comment about correctness checking

* Fix further tests

* Update validation/distributed_simulations/distributed_geostrophic_adjustment.jl

Co-authored-by: Simone Silvestri <[email protected]>

* Update validation/distributed_simulations/distributed_geostrophic_adjustment.jl

Co-authored-by: Simone Silvestri <[email protected]>

* Tests fixed

* bugfix

* try now

* fixed tests

* bugfix

* fixed tests

* Change synched to synchronized

* Update src/DistributedComputations/distributed_fields.jl

Co-authored-by: Navid C. Constantinou <[email protected]>

* Remove comment

* Apply suggestions from code review

---------

Co-authored-by: Simone Silvestri <[email protected]>
Co-authored-by: Simone Silvestri <[email protected]>
Co-authored-by: Navid C. Constantinou <[email protected]>
4 people authored Oct 10, 2023
1 parent 4f26afb commit c823cb9
Showing 24 changed files with 401 additions and 329 deletions.
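
At a glance, the user-facing change is that the `ranks` and `enable_overlapped_computation` keywords of `Distributed` are replaced by a `Partition` object and a `synchronized_communication` flag, with a default rank distribution when no partition is given. A minimal sketch of the new constructor usage, assuming Oceananigans and MPI.jl are available and the script is launched on four MPI ranks (e.g. `mpiexec -n 4 julia script.jl`); the topology below is illustrative:

    using MPI
    using Oceananigans
    using Oceananigans.DistributedComputations

    MPI.Init()

    # Split the domain into 2 ranks in x and 2 ranks in y (4 ranks total).
    arch = Distributed(CPU();
                       topology = (Periodic, Periodic, Bounded),
                       partition = Partition(2, 2))

    # Omitting `partition` falls back to the default distribution,
    # Partition(MPI.Comm_size(communicator)).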
3 changes: 1 addition & 2 deletions src/DistributedComputations/DistributedComputations.jl
@@ -1,7 +1,7 @@
module DistributedComputations

export
Distributed, child_architecture, reconstruct_global_grid,
Distributed, Partition, child_architecture, reconstruct_global_grid,
inject_halo_communication_boundary_conditions,
DistributedFFTBasedPoissonSolver

@@ -18,6 +18,5 @@ include("halo_communication_bcs.jl")
include("distributed_fields.jl")
include("halo_communication.jl")
include("distributed_fft_based_poisson_solver.jl")
include("interleave_communication_and_computation.jl")

end # module
171 changes: 122 additions & 49 deletions src/DistributedComputations/distributed_architectures.jl
@@ -4,17 +4,87 @@ using CUDA: ndevices, device!

import Oceananigans.Architectures: device, arch_array, array_type, child_architecture
import Oceananigans.Grids: zeros
import Oceananigans.Utils: sync_device!

struct Distributed{A, M, R, I, ρ, C, γ, T} <: AbstractArchitecture
child_architecture :: A
local_rank :: R
local_index :: I
ranks :: ρ
connectivity :: C
communicator :: γ
mpi_requests :: M
mpi_tag :: T
import Oceananigans.Utils: sync_device!, tupleit

#####
##### Partitioning
#####

# Form 3-tuple
regularize_size(sz::Tuple{<:Int, <:Int}, Rx, Ry) = (sz[1], sz[2], 1)

# Infer whether x- or y-dimension is indicated
function regularize_size(N::Int, Rx, Ry)
if Rx == 1
return (1, N, 1)
elseif Ry == 1
return (N, 1, 1)
else
throw(ArgumentError("We can't interpret the 1D size $N for 2D partitions!"))
end
end
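
# For example (illustrative): with a 1D partition in y, (Rx, Ry) = (1, 4),
# a scalar size N is interpreted as N cells in y, so
# regularize_size(N, 1, 4) == (1, N, 1), while with (Rx, Ry) = (4, 1) it
# yields (N, 1, 1).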

regularize_size(sz::Array{<:Int, 1}) = regularize_size(sz[1])
regularize_size(sz::Array{<:Int, 2}) = regularize_size((sz[1], sz[2]))
regularize_size(sz::Tuple{<:Int}) = regularize_size(sz[1])

# For 3D partitions:
# regularize_size(sz::Tuple{<:Int, <:Int, <:Int}, Rx, Ry) = sz

struct Partition{S, Rx, Ry, Rz}
sizes :: S
function Partition{Rx, Ry, Rz}() where {Rx, Ry, Rz}
new{Nothing, Rx, Ry, Rz}(nothing)
end
function Partition(sizes::S) where S
Rx = size(sizes, 1)
Ry = size(sizes, 2)
Rz = size(sizes, 3)
return new{S, Rx, Ry, Rz}(sizes)
end
end

"""
Partition(Rx::Number, Ry::Number=1, Rz::Number=1)

Return `Partition` representing the division of a domain into
`Rx` parts in `x`, `Ry` parts in `y`, and `Rz` parts in `z`,
where `x`, `y`, and `z` are the first, second, and third dimensions,
respectively.
"""
Partition(Rx::Number, Ry::Number=1, Rz::Number=1) = Partition{Rx, Ry, Rz}()
Partition(lengths::Array{Int, 1}) = Partition(reshape(lengths, length(lengths), 1))
Base.size(::Partition{<:Any, Rx, Ry, Rz}) where {Rx, Ry, Rz} = (Rx, Ry, Rz)
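
# A few illustrative constructions using the methods above:
#
#   Partition(4)        # four ranks in x: size == (4, 1, 1)
#   Partition(1, 4)     # four ranks in y: size == (1, 4, 1)
#   Partition(2, 2)     # 2 × 2 ranks in x and y: size == (2, 2, 1)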

struct Distributed{A, S, Δ, R, ρ, I, C, γ, M, T} <: AbstractArchitecture
child_architecture :: A
partition :: Δ
ranks :: R
local_rank :: ρ
local_index :: I
connectivity :: C
communicator :: γ
mpi_requests :: M
mpi_tag :: T

Distributed{S}(child_architecture :: A,
partition :: Δ,
ranks :: R,
local_rank :: ρ,
local_index :: I,
connectivity :: C,
communicator :: γ,
mpi_requests :: M,
mpi_tag :: T) where {S, A, Δ, R, ρ, I, C, γ, M, T} =
new{A, S, Δ, R, ρ, I, C, γ, M, T}(child_architecture,
partition,
ranks,
local_rank,
local_index,
connectivity,
communicator,
mpi_requests,
mpi_tag)
end

#####
@@ -23,14 +93,13 @@ end

"""
Distributed(child_architecture = CPU();
topology,
ranks,
devices = nothing,
communicator = MPI.COMM_WORLD)
topology,
partition,
devices = nothing,
communicator = MPI.COMM_WORLD)
Constructor for a distributed architecture that uses MPI for communications.
Positional arguments
=================
@@ -42,14 +111,12 @@ Keyword arguments
- `topology` (required): the topology we want the grid to have. It is used to establish connectivity.
- `synchronized_communication`: if true, always use synchronized communication through ranks
- `ranks` (required): A 3-tuple `(Rx, Ry, Rz)` specifying the total processors in the `x`,
`y` and `z` direction. NOTE: support for distributed z direction is
limited, so `Rz = 1` is strongly suggested.
- enable_overlapped_computation: if `true` the prognostic halo communication will be overlapped
with tendency calculations, and the barotropic halo communication
with the implicit vertical solver (defaults to `true`)
- `devices`: `GPU` device linked to local rank. The GPU will be assigned based on the
local node rank as such `devices[node_rank]`. Make sure to run `--ntasks-per-node` <= `--gres=gpu`.
If `nothing`, the devices will be assigned automatically based on the available resources
@@ -58,56 +125,56 @@ Keyword arguments
if not for testing or developing. Change at your own risk!
"""
function Distributed(child_architecture = CPU();
topology,
ranks,
devices = nothing,
enable_overlapped_computation = true,
communicator = MPI.COMM_WORLD)

MPI.Initialized() || error("Must call MPI.Init() before constructing a MultiCPU.")

validate_tupled_argument(ranks, Int, "ranks")
topology,
communicator = MPI.COMM_WORLD,
devices = nothing,
synchronized_communication = false,
partition = Partition(MPI.Comm_size(communicator)))

if !(MPI.Initialized())
@info "MPI has not been initialized, so we are calling MPI.Init()".
MPI.Init()
end

ranks = size(partition)
Rx, Ry, Rz = ranks
total_ranks = Rx*Ry*Rz

total_ranks = Rx * Ry * Rz
mpi_ranks = MPI.Comm_size(communicator)
local_rank = MPI.Comm_rank(communicator)

# TODO: make this error refer to `partition` (user input) rather than `ranks`
if total_ranks != mpi_ranks
throw(ArgumentError("ranks=($Rx, $Ry, $Rz) [$total_ranks total] inconsistent " *
"with number of MPI ranks: $mpi_ranks."))
end

local_rank = MPI.Comm_rank(communicator)
local_index = rank2index(local_rank, Rx, Ry, Rz)
local_connectivity = RankConnectivity(local_index, ranks, topology)

A = typeof(child_architecture)
R = typeof(local_rank)
I = typeof(local_index)
ρ = typeof(ranks)
C = typeof(local_connectivity)
γ = typeof(communicator)

# Assign CUDA device if on GPUs
if child_architecture isa GPU
local_comm = MPI.Comm_split_type(communicator, MPI.COMM_TYPE_SHARED, local_rank)
node_rank = MPI.Comm_rank(local_comm)
isnothing(devices) ? device!(node_rank % ndevices()) : device!(devices[node_rank+1])
end

mpi_requests = enable_overlapped_computation ? MPI.Request[] : nothing

M = typeof(mpi_requests)
T = typeof(Ref(0))

return Distributed{A, M, R, I, ρ, C, γ, T}(child_architecture, local_rank, local_index, ranks, local_connectivity, communicator, mpi_requests, Ref(0))
mpi_requests = MPI.Request[]

return Distributed{synchronized_communication}(child_architecture,
partition,
ranks,
local_rank,
local_index,
local_connectivity,
communicator,
mpi_requests,
Ref(0))
end
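
# Illustrative sketch of the keyword changes above (assumes two GPUs per
# node with device ids 0 and 1; the topology is a placeholder):
#
#   arch = Distributed(GPU();
#                      topology = (Periodic, Periodic, Bounded),
#                      partition = Partition(2),
#                      synchronized_communication = true,
#                      devices = (0, 1))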

const DistributedCPU = Distributed{CPU}
const DistributedGPU = Distributed{GPU}

const BlockingDistributed = Distributed{<:Any, <:Nothing}
const SynchronizedDistributed = Distributed{<:Any, true}

#####
##### All the architectures
@@ -121,9 +188,15 @@ array_type(arch::Distributed) = array_type(child_architecture(arch))
sync_device!(arch::Distributed) = sync_device!(arch.child_architecture)

cpu_architecture(arch::DistributedCPU) = arch
cpu_architecture(arch::DistributedGPU) =
Distributed(CPU(), arch.local_rank, arch.local_index, arch.ranks,
arch.connectivity, arch.communicator, arch.mpi_requests, arch.mpi_tag)
cpu_architecture(arch::DistributedGPU) = Distributed(CPU(),
arch.partition,
arch.ranks,
arch.local_rank,
arch.local_index,
arch.connectivity,
arch.communicator,
arch.mpi_requests,
arch.mpi_tag)

#####
##### Converting between index and MPI rank taking k as the fast index
src/DistributedComputations/distributed_fft_based_poisson_solver.jl
@@ -6,7 +6,6 @@ import FFTW
import Oceananigans.Solvers: poisson_eigenvalues, solve!
import Oceananigans.Architectures: architecture


struct DistributedFFTBasedPoissonSolver{P, F, L, λ, S, I}
plan :: P
global_grid :: F
@@ -207,3 +206,4 @@ end
i, j, k = @index(Global, NTuple)
@inbounds ϕ[i, j, k] = real(ϕc[k, j, i])
end

28 changes: 26 additions & 2 deletions src/DistributedComputations/distributed_fields.jl
@@ -1,7 +1,7 @@
import Oceananigans.Fields: Field, FieldBoundaryBuffers, location, set!
import Oceananigans.BoundaryConditions: fill_halo_regions!

using Oceananigans.Fields: validate_field_data, indices, validate_boundary_conditions, validate_indices
using Oceananigans.Fields: validate_field_data, indices, validate_boundary_conditions, validate_indices, recv_from_buffers!

function Field((LX, LY, LZ)::Tuple, grid::DistributedGrid, data, old_bcs, indices::Tuple, op, status)
arch = architecture(grid)
Expand Down Expand Up @@ -31,4 +31,28 @@ function set!(u::DistributedField, f::Function)
end

return u
end
end

"""
synchronize_communication!(field)

Complete the halo passing of `field` among processors.
"""
function synchronize_communication!(field)
arch = architecture(field.grid)

# Wait for outstanding requests
if !isempty(arch.mpi_requests)
cooperative_waitall!(arch.mpi_requests)

# Reset MPI tag
arch.mpi_tag[] -= arch.mpi_tag[]

# Reset MPI requests
empty!(arch.mpi_requests)
end

recv_from_buffers!(field.data, field.boundary_buffers, field.grid)

return nothing
end
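
# Sketch of intended use (assumes `u` is a Field built on a DistributedGrid,
# so halo exchange may post non-blocking MPI requests):
#
#   fill_halo_regions!(u)           # may return before receives complete
#   synchronize_communication!(u)   # wait on requests and unpack receive buffers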