Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the possibility of prescribing weights before redistributing #70

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/OctreeDistributedDiscreteModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1599,9 +1599,23 @@ end
# Assumptions. Either:
# A) model.parts MPI tasks are included in parts_redistributed_model MPI tasks; or
# B) model.parts MPI tasks include parts_redistributed_model MPI tasks
const WeightsArrayType=Union{Nothing,MPIArray{<:Vector{<:Integer}}}
function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc,Dp},
parts_redistributed_model=model.parts) where {Dc,Dp}
parts_redistributed_model=model.parts;
weights::WeightsArrayType=nothing) where {Dc,Dp}
parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model
_weights=nothing
if (weights !== nothing)
Gridap.Helpers.@notimplementedif parts!==model.parts
_weights=map(model.dmodel.models,weights) do lmodel,weights
# The length of the local weights array has to match the number of
# cells in the model. This includes both owned and ghost cells.
# Only the flags for owned cells are actually taken into account.
@assert num_cells(lmodel)==length(weights)
convert(Vector{Cint},weights)
end
end

comm = parts.comm
if (GridapDistributed.i_am_in(model.parts.comm) || GridapDistributed.i_am_in(parts.comm))
if (parts_redistributed_model !== model.parts)
Expand All @@ -1610,7 +1624,7 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
@assert A || B
end
if (parts_redistributed_model===model.parts || A)
_redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model)
_redistribute_parts_subseteq_parts_redistributed(model,parts_redistributed_model,_weights)
else
_redistribute_parts_supset_parts_redistributed(model, parts_redistributed_model)
end
Expand All @@ -1619,7 +1633,9 @@ function GridapDistributed.redistribute(model::OctreeDistributedDiscreteModel{Dc
end
end

function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp}, parts_redistributed_model) where {Dc,Dp}
function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistributedDiscreteModel{Dc,Dp},
parts_redistributed_model,
_weights::WeightsArrayType) where {Dc,Dp}
parts = (parts_redistributed_model === model.parts) ? model.parts : parts_redistributed_model
if (parts_redistributed_model === model.parts)
ptr_pXest_old = model.ptr_pXest
Expand All @@ -1631,7 +1647,15 @@ function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistribut
parts.comm)
end
ptr_pXest_new = pXest_copy(model.pXest_type, ptr_pXest_old)
pXest_partition!(model.pXest_type, ptr_pXest_new)
if (_weights !== nothing)
init_fn_callback_c = pXest_reset_callbacks(model.pXest_type)
map(_weights) do _weights
pXest_reset_data!(model.pXest_type, ptr_pXest_new, Cint(sizeof(Cint)), init_fn_callback_c, pointer(_weights))
end
pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=true)
else
pXest_partition!(model.pXest_type, ptr_pXest_new; weights_set=false)
end

# Compute RedistributeGlue
parts_snd, lids_snd, old2new = pXest_compute_migration_control_data(model.pXest_type,ptr_pXest_old,ptr_pXest_new)
Expand Down
51 changes: 45 additions & 6 deletions src/PXestTypeMethods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,31 @@ function pXest_balance!(::P8estType, ptr_pXest; k_2_1_balance=0)
end
end

function pXest_partition!(::P4estType, ptr_pXest)
p4est_partition(ptr_pXest, 0, C_NULL)
function pXest_partition!(pXest_type::P4estType, ptr_pXest; weights_set=false)
if (!weights_set)
p4est_partition(ptr_pXest, 0, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p4est_partition(ptr_pXest, 0, wcallback)
end
end

function pXest_partition!(::P6estType, ptr_pXest)
p6est_partition(ptr_pXest, C_NULL)
function pXest_partition!(pXest_type::P6estType, ptr_pXest; weights_set=false)
if (!weights_set)
p6est_partition(ptr_pXest, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p6est_partition(ptr_pXest, wcallback)
end
end

function pXest_partition!(::P8estType, ptr_pXest)
p8est_partition(ptr_pXest, 0, C_NULL)
function pXest_partition!(pXest_type::P8estType, ptr_pXest; weights_set=false)
if (!weights_set)
p8est_partition(ptr_pXest, 0, C_NULL)
else
wcallback=pXest_weight_callback(pXest_type)
p8est_partition(ptr_pXest, 0, wcallback)
end
end


Expand Down Expand Up @@ -805,6 +820,30 @@ function pXest_refine_callbacks(::P8estType)
refine_callback_c, refine_replace_callback_c
end

function pXest_weight_callback(::P4estType)
function weight_callback(::Ptr{p4est_t},
which_tree::p4est_topidx_t,
quadrant_ptr::Ptr{p4est_quadrant_t})
quadrant = quadrant_ptr[]
return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[]
end
@cfunction($weight_callback, Cint, (Ptr{p4est_t}, p4est_topidx_t, Ptr{p4est_quadrant_t}))
end

function pXest_weight_callback(::P6estType)
Gridap.Helpers.@notimplemented
end

function pXest_weight_callback(::P8estType)
function weight_callback(::Ptr{p8est_t},
which_tree::p4est_topidx_t,
quadrant_ptr::Ptr{p8est_quadrant_t})
quadrant = quadrant_ptr[]
return unsafe_wrap(Array, Ptr{Cint}(quadrant.p.user_data), 1)[]
end
@cfunction($weight_callback, Cint, (Ptr{p8est_t}, p4est_topidx_t, Ptr{p8est_quadrant_t}))
end

function _unwrap_ghost_quadrants(::P4estType, pXest_ghost)
Ptr{p4est_quadrant_t}(pXest_ghost.ghosts.array)
end
Expand Down
9 changes: 8 additions & 1 deletion test/PoissonNonConformingOctreeModelsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,14 @@ module PoissonNonConformingOctreeModelsTests
e = uH - uhH
el2 = sqrt(sum( ∫( e⋅e )*dΩH ))

fmodel_red, red_glue=GridapDistributed.redistribute(fmodel);
weights=map(ranks,fmodel.dmodel.models) do rank,lmodel
if (rank%2==0)
zeros(Cint,num_cells(lmodel))
else
ones(Cint,num_cells(lmodel))
end
end
fmodel_red, red_glue=GridapDistributed.redistribute(fmodel,weights=weights);
Vhred=FESpace(fmodel_red,reffe,conformity=:H1;dirichlet_tags="boundary")
Uhred=TrialFESpace(Vhred,u)

Expand Down
Loading