diff --git a/src/OctreeDistributedDiscreteModels.jl b/src/OctreeDistributedDiscreteModels.jl index 7830182..5155daf 100644 --- a/src/OctreeDistributedDiscreteModels.jl +++ b/src/OctreeDistributedDiscreteModels.jl @@ -1180,7 +1180,8 @@ function Gridap.Adaptivity.refine(model::OctreeDistributedDiscreteModel{Dc,Dp}; if GridapDistributed.i_am_in(new_comm) if !isa(parts,Nothing) aux = ptr_new_pXest - ptr_new_pXest = _p4est_to_new_comm(ptr_new_pXest, + ptr_new_pXest = _pXest_to_new_comm(Val{Dc}, + ptr_new_pXest, model.ptr_pXest_connectivity, model.parts.comm, parts.comm) @@ -1232,10 +1233,9 @@ function Gridap.Adaptivity.refine(model::OctreeDistributedDiscreteModel{Dc,Dp}; end function Gridap.Adaptivity.adapt(model::OctreeDistributedDiscreteModel{Dc,Dp}, - refinement_and_coarsening_flags::MPIArray{<:Vector}; - parts=nothing) where {Dc,Dp} - - Gridap.Helpers.@notimplementedif parts!=nothing + refinement_and_coarsening_flags::MPIArray{<:Vector}; + parts=nothing) where {Dc,Dp} + Gridap.Helpers.@notimplementedif !isa(parts,Nothing) # Variables which are updated accross calls to init_fn_callback_2d current_quadrant_index_within_tree = Cint(0) @@ -1517,21 +1517,63 @@ function Gridap.Adaptivity.coarsen(model::OctreeDistributedDiscreteModel{Dc,Dp}) end end +function pXest_deflate_quadrants(::Type{Val{Dc}},ptr_pXest,data) where Dc + if Dc ==2 + P4est_wrapper.p4est_deflate_quadrants(ptr_pXest,data) + else + P4est_wrapper.p8est_deflate_quadrants(ptr_pXest,data) + end +end + +function pXest_comm_count_pertree(::Type{Val{Dc}},ptr_pXest,pertree) where Dc + if Dc == 2 + p4est_comm_count_pertree(ptr_pXest,pertree) + else + p8est_comm_count_pertree(ptr_pXest,pertree) + end +end + +function pXest_inflate(::Type{Val{Dc}}, + comm, + ptr_pXest_conn, + global_first_quadrant, + pertree, + quadrants, + data, + user_pointer) where Dc + if Dc == 2 + P4est_wrapper.p4est_inflate(comm, + ptr_pXest_conn, + global_first_quadrant, + pertree, + quadrants, + data, + user_pointer) + else + P4est_wrapper.p8est_inflate(comm, + ptr_pXest_conn, + global_first_quadrant, + pertree, + quadrants, + data, + user_pointer) + end +end # We have a p4est distributed among P processors. This function # instantiates the same among Q processors. -function _p4est_to_new_comm(ptr_pXest, ptr_pXest_conn, old_comm, new_comm) +function _pXest_to_new_comm(::Type{Val{Dc}},ptr_pXest, ptr_pXest_conn, old_comm, new_comm) where Dc A = is_included(old_comm,new_comm) # old \subset new (smaller to larger nparts) B = is_included(new_comm,old_comm) # old \supset new (larger to smaller nparts) @assert xor(A,B) if (A) - _p4est_to_new_comm_old_subset_new(ptr_pXest, ptr_pXest_conn, old_comm, new_comm) + _pXest_to_new_comm_old_subset_new(Val{Dc},ptr_pXest, ptr_pXest_conn, old_comm, new_comm) else - _p4est_to_new_comm_old_supset_new(ptr_pXest, ptr_pXest_conn, old_comm, new_comm) + _pXest_to_new_comm_old_supset_new(Val{Dc},ptr_pXest, ptr_pXest_conn, old_comm, new_comm) end end -function _p4est_to_new_comm_old_subset_new(ptr_pXest, ptr_pXest_conn, old_comm, new_comm) +function _pXest_to_new_comm_old_subset_new(::Type{Val{Dc}},ptr_pXest, ptr_pXest_conn, old_comm, new_comm) where Dc if (GridapDistributed.i_am_in(new_comm)) new_comm_num_parts = GridapDistributed.num_parts(new_comm) global_first_quadrant = Vector{P4est_wrapper.p4est_gloidx_t}(undef,new_comm_num_parts+1) @@ -1552,33 +1594,34 @@ function _p4est_to_new_comm_old_subset_new(ptr_pXest, ptr_pXest_conn, old_comm, global_first_quadrant[i] = old_global_first_quadrant[end] end MPI.Bcast!(global_first_quadrant,0,new_comm) - quadrants = P4est_wrapper.p4est_deflate_quadrants(ptr_pXest,C_NULL) - p4est_comm_count_pertree(ptr_pXest,pertree) + quadrants = pXest_deflate_quadrants(Val{Dc},ptr_pXest,C_NULL) + pXest_comm_count_pertree(Val{Dc},ptr_pXest,pertree) MPI.Bcast!(pertree,0,new_comm) else MPI.Bcast!(global_first_quadrant,0,new_comm) quadrants = sc_array_new_count(sizeof(p4est_quadrant_t), 0) MPI.Bcast!(pertree,0,new_comm) end - return P4est_wrapper.p4est_inflate(new_comm, - ptr_pXest_conn, - global_first_quadrant, - pertree, - quadrants, - C_NULL, - C_NULL) + return pXest_inflate(Val{Dc}, + new_comm, + ptr_pXest_conn, + global_first_quadrant, + pertree, + quadrants, + C_NULL, + C_NULL) else return nothing end end -function _p4est_to_new_comm_old_supset_new(ptr_pXest, ptr_pXest_conn, old_comm, new_comm) +function _pXest_to_new_comm_old_supset_new(::Type{Val{Dc}},ptr_pXest, ptr_pXest_conn, old_comm, new_comm) where Dc @assert GridapDistributed.i_am_in(old_comm) pXest = ptr_pXest[] pXest_conn = ptr_pXest_conn[] pertree = Vector{P4est_wrapper.p4est_gloidx_t}(undef,pXest_conn.num_trees+1) - p4est_comm_count_pertree(ptr_pXest,pertree) + pXest_comm_count_pertree(Val{Dc},ptr_pXest,pertree) if (GridapDistributed.i_am_in(new_comm)) new_comm_num_parts = GridapDistributed.num_parts(new_comm) @@ -1596,15 +1639,16 @@ function _p4est_to_new_comm_old_supset_new(ptr_pXest, ptr_pXest_conn, old_comm, for i = 1:length(new_global_first_quadrant) global_first_quadrant[i] = old_global_first_quadrant[i] end - quadrants = P4est_wrapper.p4est_deflate_quadrants(ptr_pXest,C_NULL) - - return P4est_wrapper.p4est_inflate(new_comm, - ptr_pXest_conn, - global_first_quadrant, - pertree, - quadrants, - C_NULL, - C_NULL) + quadrants = pXest_deflate_quadrants(Val{Dc},ptr_pXest,C_NULL) + + return pXest_inflate(Val{Dc}, + new_comm, + ptr_pXest_conn, + global_first_quadrant, + pertree, + quadrants, + C_NULL, + C_NULL) else return nothing end @@ -1779,7 +1823,8 @@ function _redistribute_parts_subseteq_parts_redistributed(model::OctreeDistribut if (parts_redistributed_model === model.parts) ptr_pXest_old = model.ptr_pXest else - ptr_pXest_old = _p4est_to_new_comm(model.ptr_pXest, + ptr_pXest_old = _pXest_to_new_comm(Val{Dc}, + model.ptr_pXest, model.ptr_pXest_connectivity, model.parts.comm, parts.comm) @@ -1861,7 +1906,8 @@ function _redistribute_parts_supset_parts_redistributed( # ptr_pXest_old is distributed over supset_comm # once created, ptr_pXest_new is distributed over subset_comm - ptr_pXest_new = _p4est_to_new_comm(ptr_pXest_old, + ptr_pXest_new = _pXest_to_new_comm(Val{Dc}, + ptr_pXest_old, model.ptr_pXest_connectivity, supset_comm, subset_comm)