Skip to content

Commit

Permalink
Add addtional subpools for LTS and FBLTS at init
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrus89 committed Apr 5, 2024
1 parent 12368b9 commit 2f49d49
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,12 @@ subroutine ocn_time_integrator_fblts(domain, dt)!{{{
call mpas_pool_get_array(LTSPool, 'nEdgesInLTSRegion', nEdgesInLTSRegion)

! Create and retrieve additional pools for LTS
call mpas_pool_create_pool(tendSum3rdPool)
call mpas_pool_clone_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_create_pool(tendSlowPool)
call mpas_pool_clone_pool(tendPool, tendSlowPool, 1)
call mpas_pool_get_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % structs, 'tend_slow', tendSlowPool)

call mpas_pool_copy_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_copy_pool(tendPool, tendSlowPool, 1)

call mpas_pool_add_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_add_subpool(block % structs, 'tend_slow', tendSlowPool)

call mpas_pool_get_array(tendSlowPool, 'normalVelocity', &
normalVelocityTendSlow)
Expand Down Expand Up @@ -259,47 +258,6 @@ subroutine ocn_time_integrator_fblts(domain, dt)!{{{
normalVelocityTendSum3rd(:,:) = 0.0_RKIND
layerThicknessTendSum3rd(:,:) = 0.0_RKIND

if (associated(block % prev)) then
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_slow', tendSlowPool)
else
nullify(prevTendSum3rdPool)
nullify(prevTendSlowPool)
end if

if (associated(block % next)) then
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_3rd', nextTendSum3rdPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_slow', nextTendSlowPool)
else
nullify(nextTendSum3rdPool)
nullify(nextTendSlowPool)
end if

call mpas_pool_get_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % structs, 'tend_slow', tendSlowPool)

if (associated(prevTendSum3rdPool) .and. associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool, nextTendSum3rdPool)
else if (associated(prevTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool)
else if (associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool,nextPool=nextTendSum3rdPool)
else
call mpas_pool_link_pools(tendSum3rdPool)
end if

if (associated(prevTendSlowPool) .and. associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool, nextTendSlowPool)
else if (associated(prevTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool)
else if (associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool,nextPool=nextTendSlowPool)
else
call mpas_pool_link_pools(tendSlowPool)
end if

call mpas_pool_link_parinfo(block, tendSum3rdPool)
call mpas_pool_link_parinfo(block, tendSlowPool)

call mpas_timer_stop("FB_LTS time-step prep")

Expand Down Expand Up @@ -1367,12 +1325,6 @@ subroutine ocn_time_integrator_fblts(domain, dt)!{{{
call ocn_diagnostic_solve(dt, statePool, forcingPool, meshPool, &
verticalMeshPool, scratchPool, tracersPool, 2)

call mpas_pool_destroy_pool(tendSum3rdPool)
call mpas_pool_destroy_pool(tendSlowPool)

call mpas_pool_remove_subpool(block % structs, 'tend_sum_3rd')
call mpas_pool_remove_subpool(block % structs, 'tend_slow')

call mpas_timer_stop("FB_LTS cleanup")

end subroutine ocn_time_integrator_fblts!}}}
Expand Down Expand Up @@ -1415,6 +1367,15 @@ subroutine ocn_time_integration_fblts_init(domain)!{{{

type (mpas_pool_type), pointer :: &
LTSPool

type (mpas_pool_type), pointer :: &
tendSlowPool, &
tendSum3rdPool, &
prevTendSlowPool, nextTendSlowPool, &
prevTendSum3rdPool, nextTendSum3rdPool

type (mpas_pool_type), pointer :: &
tendPool

integer, dimension(:), allocatable :: &
isLTSRegionEdgeAssigned
Expand Down Expand Up @@ -1445,6 +1406,58 @@ subroutine ocn_time_integration_fblts_init(domain)!{{{
minMaxLTSRegion(2) = 2

block => domain % blocklist
call mpas_pool_get_subpool(block%structs, 'tend', tendPool)

call mpas_pool_create_pool(tendSum3rdPool)
call mpas_pool_clone_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_create_pool(tendSlowPool)
call mpas_pool_clone_pool(tendPool, tendSlowPool, 1)

call mpas_pool_add_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_add_subpool(block % structs, 'tend_slow', tendSlowPool)

if (associated(block % prev)) then
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_slow', tendSlowPool)
else
nullify(prevTendSum3rdPool)
nullify(prevTendSlowPool)
end if

if (associated(block % next)) then
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_3rd', nextTendSum3rdPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_slow', nextTendSlowPool)
else
nullify(nextTendSum3rdPool)
nullify(nextTendSlowPool)
end if

call mpas_pool_get_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % structs, 'tend_slow', tendSlowPool)

if (associated(prevTendSum3rdPool) .and. associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool, nextTendSum3rdPool)
else if (associated(prevTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool)
else if (associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool,nextPool=nextTendSum3rdPool)
else
call mpas_pool_link_pools(tendSum3rdPool)
end if

if (associated(prevTendSlowPool) .and. associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool, nextTendSlowPool)
else if (associated(prevTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool)
else if (associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool,nextPool=nextTendSlowPool)
else
call mpas_pool_link_pools(tendSlowPool)
end if

call mpas_pool_link_parinfo(block, tendSum3rdPool)
call mpas_pool_link_parinfo(block, tendSlowPool)

call mpas_pool_get_subpool(block % structs, 'LTS', LTSPool)

call mpas_pool_get_array(LTSPool, 'LTSRegion', LTSRegion)
Expand Down
204 changes: 108 additions & 96 deletions components/mpas-ocean/src/mode_forward/mpas_ocn_time_integration_lts.F
Original file line number Diff line number Diff line change
Expand Up @@ -254,20 +254,18 @@ subroutine ocn_time_integrator_lts(domain,dt)!{{{
call mpas_pool_get_array(LTSPool, 'nEdgesInLTSRegion', &
nEdgesInLTSRegion)

!--- Create additional pools for LTS
call mpas_pool_create_pool(tendSum1stPool)
call mpas_pool_clone_pool(tendPool, tendSum1stPool, 1)
call mpas_pool_create_pool(tendSum2ndPool)
call mpas_pool_clone_pool(tendPool, tendSum2ndPool, 1)
call mpas_pool_create_pool(tendSum3rdPool)
call mpas_pool_clone_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_create_pool(tendSlowPool)
call mpas_pool_clone_pool(tendPool, tendSlowPool, 1)
!--- Update additional pools for LTS
call mpas_pool_get_subpool(block % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_get_subpool(block % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_get_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % structs, 'tend_slow', tendSlowPool)


call mpas_pool_copy_pool(tendPool, tendSum1stPool, 1)
call mpas_pool_copy_pool(tendPool, tendSum2ndPool, 1)
call mpas_pool_copy_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_copy_pool(tendPool, tendSlowPool, 1)

call mpas_pool_add_subpool(block % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_add_subpool(block % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_add_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_add_subpool(block % structs, 'tend_slow', tendSlowPool)

call mpas_pool_get_array(tendSlowPool, 'normalVelocity', normalVelocityTendSlow)

Expand Down Expand Up @@ -307,79 +305,6 @@ subroutine ocn_time_integrator_lts(domain,dt)!{{{
normalVelocityTendSum3rd(:,:) = 0.0_RKIND
layerThicknessTendSum3rd(:,:) = 0.0_RKIND

if (associated(block % prev)) then
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_slow', tendSlowPool)
else
nullify(prevTendSum1stPool)
nullify(prevTendSum2ndPool)
nullify(prevTendSum3rdPool)
nullify(prevTendSlowPool)
end if

if (associated(block % next)) then
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_1st', nextTendSum1stPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_2nd', nextTendSum2ndPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_3rd', nextTendSum3rdPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_slow', nextTendSlowPool)
else
nullify(nextTendSum1stPool)
nullify(nextTendSum2ndPool)
nullify(nextTendSum3rdPool)
nullify(nextTendSlowPool)
end if

call mpas_pool_get_subpool(block % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_get_subpool(block % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_get_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % structs, 'tend_slow', tendSlowPool)

if (associated(prevTendSum1stPool) .and. associated(nextTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool, prevTendSum1stPool, nextTendSum1stPool)
else if (associated(prevTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool, prevTendSum1stPool)
else if (associated(nextTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool,nextPool=nextTendSum1stPool)
else
call mpas_pool_link_pools(tendSum1stPool)
end if

if (associated(prevTendSum2ndPool) .and. associated(nextTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool, prevTendSum2ndPool, nextTendSum2ndPool)
else if (associated(prevTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool, prevTendSum2ndPool)
else if (associated(nextTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool,nextPool=nextTendSum2ndPool)
else
call mpas_pool_link_pools(tendSum2ndPool)
end if

if (associated(prevTendSum3rdPool) .and. associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool, nextTendSum3rdPool)
else if (associated(prevTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool)
else if (associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool,nextPool=nextTendSum3rdPool)
else
call mpas_pool_link_pools(tendSum3rdPool)
end if

if (associated(prevTendSlowPool) .and. associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool, nextTendSlowPool)
else if (associated(prevTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool)
else if (associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool,nextPool=nextTendSlowPool)
else
call mpas_pool_link_pools(tendSlowPool)
end if

call mpas_pool_link_parinfo(block, tendSum1stPool)
call mpas_pool_link_parinfo(block, tendSum2ndPool)
call mpas_pool_link_parinfo(block, tendSum3rdPool)
call mpas_pool_link_parinfo(block, tendSlowPool)

call mpas_timer_stop("lts time-step prep")

Expand Down Expand Up @@ -1084,16 +1009,6 @@ subroutine ocn_time_integrator_lts(domain,dt)!{{{
! DIAGNOSTICS UPDATE ---
call ocn_diagnostic_solve(dt, statePool, forcingPool, meshPool, verticalMeshPool, scratchPool, tracersPool, 2)

call mpas_pool_destroy_pool(tendSum1stPool)
call mpas_pool_destroy_pool(tendSum2ndPool)
call mpas_pool_destroy_pool(tendSum3rdPool)
call mpas_pool_destroy_pool(tendSlowPool)

call mpas_pool_remove_subpool(block % structs, 'tend_sum_1st')
call mpas_pool_remove_subpool(block % structs, 'tend_sum_2nd')
call mpas_pool_remove_subpool(block % structs, 'tend_sum_3rd')
call mpas_pool_remove_subpool(block % structs, 'tend_slow')

call mpas_timer_stop("lts cleanup phase")


Expand Down Expand Up @@ -1123,6 +1038,16 @@ subroutine ocn_time_integration_lts_init(domain)!{{{

type (block_type), pointer :: block
type (mpas_pool_type), pointer :: LTSPool
type (mpas_pool_type), pointer :: tendPool
type (mpas_pool_type), pointer :: &
tendSlowPool, &
tendSum1stPool, &
tendSum2ndPool, &
tendSum3rdPool, &
prevTendSlowPool, nextTendSlowPool, &
prevTendSum1stPool, nextTendSum1stPool, &
prevTendSum2ndPool, nextTendSum2ndPool, &
prevTendSum3rdPool, nextTendSum3rdPool
integer, dimension(:), allocatable :: isLTSRegionEdgeAssigned
integer :: i, iCell, iEdge, iRegion, coarseRegions, fineRegions, fineRegionsM1
integer, dimension(:), pointer :: LTSRegion
Expand All @@ -1134,6 +1059,93 @@ subroutine ocn_time_integration_lts_init(domain)!{{{
minMaxLTSRegion(2) = 2

block => domain % blocklist

! Create additional pools
call mpas_pool_get_subpool(block%structs, 'tend', tendPool)

call mpas_pool_create_pool(tendSum1stPool)
call mpas_pool_clone_pool(tendPool, tendSum1stPool, 1)
call mpas_pool_create_pool(tendSum2ndPool)
call mpas_pool_clone_pool(tendPool, tendSum2ndPool, 1)
call mpas_pool_create_pool(tendSum3rdPool)
call mpas_pool_clone_pool(tendPool, tendSum3rdPool, 1)
call mpas_pool_create_pool(tendSlowPool)
call mpas_pool_clone_pool(tendPool, tendSlowPool, 1)

call mpas_pool_add_subpool(block % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_add_subpool(block % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_add_subpool(block % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_add_subpool(block % structs, 'tend_slow', tendSlowPool)

if (associated(block % prev)) then
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_1st', tendSum1stPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_2nd', tendSum2ndPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_sum_3rd', tendSum3rdPool)
call mpas_pool_get_subpool(block % prev % structs, 'tend_slow', tendSlowPool)
else
nullify(prevTendSum1stPool)
nullify(prevTendSum2ndPool)
nullify(prevTendSum3rdPool)
nullify(prevTendSlowPool)
end if

if (associated(block % next)) then
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_1st', nextTendSum1stPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_2nd', nextTendSum2ndPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_sum_3rd', nextTendSum3rdPool)
call mpas_pool_get_subpool(block % next % structs, 'tend_slow', nextTendSlowPool)
else
nullify(nextTendSum1stPool)
nullify(nextTendSum2ndPool)
nullify(nextTendSum3rdPool)
nullify(nextTendSlowPool)
end if

if (associated(prevTendSum1stPool) .and. associated(nextTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool, prevTendSum1stPool, nextTendSum1stPool)
else if (associated(prevTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool, prevTendSum1stPool)
else if (associated(nextTendSum1stPool)) then
call mpas_pool_link_pools(tendSum1stPool,nextPool=nextTendSum1stPool)
else
call mpas_pool_link_pools(tendSum1stPool)
end if

if (associated(prevTendSum2ndPool) .and. associated(nextTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool, prevTendSum2ndPool, nextTendSum2ndPool)
else if (associated(prevTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool, prevTendSum2ndPool)
else if (associated(nextTendSum2ndPool)) then
call mpas_pool_link_pools(tendSum2ndPool,nextPool=nextTendSum2ndPool)
else
call mpas_pool_link_pools(tendSum2ndPool)
end if

if (associated(prevTendSum3rdPool) .and. associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool, nextTendSum3rdPool)
else if (associated(prevTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool, prevTendSum3rdPool)
else if (associated(nextTendSum3rdPool)) then
call mpas_pool_link_pools(tendSum3rdPool,nextPool=nextTendSum3rdPool)
else
call mpas_pool_link_pools(tendSum3rdPool)
end if

if (associated(prevTendSlowPool) .and. associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool, nextTendSlowPool)
else if (associated(prevTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool, prevTendSlowPool)
else if (associated(nextTendSlowPool)) then
call mpas_pool_link_pools(tendSlowPool,nextPool=nextTendSlowPool)
else
call mpas_pool_link_pools(tendSlowPool)
end if

call mpas_pool_link_parinfo(block, tendSum1stPool)
call mpas_pool_link_parinfo(block, tendSum2ndPool)
call mpas_pool_link_parinfo(block, tendSum3rdPool)
call mpas_pool_link_parinfo(block, tendSlowPool)

call mpas_pool_get_subpool(block % structs, 'LTS', LTSPool)

call mpas_pool_get_array(LTSPool, 'LTSRegion', LTSRegion)
Expand Down

0 comments on commit 2f49d49

Please sign in to comment.