Skip to content

Commit

Permalink
Update sphere_transport for convergence
Browse files Browse the repository at this point in the history
  • Loading branch information
cbegeman committed Oct 16, 2024
1 parent 3e85ac4 commit 71eabe8
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 56 deletions.
54 changes: 28 additions & 26 deletions polaris/ocean/tasks/sphere_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,28 +171,28 @@ def _setup_steps(self, refinement):
for idx, refinement_factor in enumerate(refinement_factors):
resolution = get_resolution_for_task(
config, refinement_factor, refinement=refinement)
if resolution not in self.resolutions:
base_mesh_step, mesh_name = add_spherical_base_mesh_step(
component, resolution, icosahedral)
analysis_dependencies['mesh'][refinement_factor] = base_mesh_step

name = f'{prefix}_init_{mesh_name}'
subdir = f'{sph_trans_dir}/init/{mesh_name}'
if self.include_viz:
symlink = f'init/{mesh_name}'
else:
symlink = None
if subdir in component.steps:
init_step = component.steps[subdir]
else:
init_step = Init(component=component, name=name,
subdir=subdir, base_mesh=base_mesh_step,
case_name=case_name)
init_step.set_shared_config(config, link=config_filename)
analysis_dependencies['init'][refinement_factor] = init_step
self.add_step(base_mesh_step, symlink=f'base_mesh/{mesh_name}')
self.add_step(init_step, symlink=symlink)
if resolution not in resolutions:
resolutions.append(resolution)
base_mesh_step, mesh_name = add_spherical_base_mesh_step(
component, resolution, icosahedral)
self.add_step(base_mesh_step, symlink=f'base_mesh/{mesh_name}')
analysis_dependencies['mesh'][resolution] = base_mesh_step

name = f'{prefix}_init_{mesh_name}'
subdir = f'{sph_trans_dir}/init/{mesh_name}'
if self.include_viz:
symlink = f'init/{mesh_name}'
else:
symlink = None
if subdir in component.steps:
init_step = component.steps[subdir]
else:
init_step = Init(component=component, name=name,
subdir=subdir, base_mesh=base_mesh_step,
case_name=case_name)
init_step.set_shared_config(config, link=config_filename)
self.add_step(init_step, symlink=symlink)
analysis_dependencies['init'][resolution] = init_step

timestep, _ = get_timestep_for_task(
config, refinement_factor, refinement=refinement)
Expand All @@ -217,7 +217,7 @@ def _setup_steps(self, refinement):
refinement=refinement)
forward_step.set_shared_config(config, link=config_filename)
self.add_step(forward_step, symlink=symlink)
analysis_dependencies['forward'][resolution] = forward_step
analysis_dependencies['forward'][refinement_factor] = forward_step

if self.include_viz:
with_viz_dir = f'{sph_trans_dir}/with_viz'
Expand Down Expand Up @@ -257,10 +257,11 @@ def _setup_steps(self, refinement):
step = component.steps[subdir]
else:
step = MixingAnalysis(component=component,
resolutions=resolutions,
icosahedral=icosahedral, subdir=subdir,
refinement_factors=refinement_factors,
case_name=case_name,
dependencies=analysis_dependencies)
dependencies=analysis_dependencies,
refinement=refinement)
step.set_shared_config(config, link=config_filename)
self.add_step(step, symlink=symlink)

Expand All @@ -274,9 +275,10 @@ def _setup_steps(self, refinement):
step = component.steps[subdir]
else:
step = FilamentAnalysis(component=component,
resolutions=resolutions,
refinement_factors=refinement_factors,
icosahedral=icosahedral, subdir=subdir,
case_name=case_name,
dependencies=analysis_dependencies)
dependencies=analysis_dependencies,
refinement=refinement)
step.set_shared_config(config, link=config_filename)
self.add_step(step, symlink=symlink)
35 changes: 20 additions & 15 deletions polaris/ocean/tasks/sphere_transport/filament_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from polaris import Step
from polaris.mpas import time_index_from_xtime
from polaris.ocean.convergence import get_resolution_for_task
from polaris.ocean.resolution import resolution_to_subdir
from polaris.viz import use_mplstyle

Expand All @@ -25,8 +26,8 @@ class FilamentAnalysis(Step):
case_name : str
The name of the test case
"""
def __init__(self, component, resolutions, icosahedral, subdir,
case_name, dependencies):
def __init__(self, component, refinement_factors, icosahedral, subdir,
case_name, dependencies, refinement='both'):
"""
Create the step
Expand All @@ -53,22 +54,22 @@ def __init__(self, component, resolutions, icosahedral, subdir,
"""
super().__init__(component=component, name='filament_analysis',
subdir=subdir)
self.resolutions = resolutions
self.refinement_factors = refinement_factors
self.refinement = refinement
self.case_name = case_name

for resolution in resolutions:
mesh_name = resolution_to_subdir(resolution)
base_mesh = dependencies['mesh'][resolution]
init = dependencies['init'][resolution]
forward = dependencies['forward'][resolution]
for refinement_factor in refinement_factors:
base_mesh = dependencies['mesh'][refinement_factor]
init = dependencies['init'][refinement_factor]
forward = dependencies['forward'][refinement_factor]
self.add_input_file(
filename=f'{mesh_name}_mesh.nc',
filename=f'mesh_r{refinement_factor:02g}.nc',
work_dir_target=f'{base_mesh.path}/base_mesh.nc')
self.add_input_file(
filename=f'{mesh_name}_init.nc',
filename=f'init_r{refinement_factor:02g}.nc',
work_dir_target=f'{init.path}/initial_state.nc')
self.add_input_file(
filename=f'{mesh_name}_output.nc',
filename=f'output_r{refinement_factor:02g}.nc',
work_dir_target=f'{forward.path}/output.nc')
self.add_output_file('filament.png')

Expand All @@ -77,7 +78,11 @@ def run(self):
Run this step of the test case
"""
plt.switch_backend('Agg')
resolutions = self.resolutions
resolutions = list()
for refinement_factor in self.refinement_factors:
resolution = get_resolution_for_task(
self.config, refinement_factor, self.refinement)
resolutions.append(resolution)
config = self.config
section = config[self.case_name]
eval_time = section.getfloat('filament_evaluation_time')
Expand All @@ -89,9 +94,9 @@ def run(self):
filament_norm = np.zeros((len(resolutions), num_tau))
use_mplstyle()
fig, ax = plt.subplots()
for i, resolution in enumerate(resolutions):
mesh_name = resolution_to_subdir(resolution)
ds = xr.open_dataset(f'{mesh_name}_output.nc')
for i, refinement_factor in enumerate(self.refinement_factors):
mesh_name = resolution_to_subdir(resolutions[i])
ds = xr.open_dataset(f'output_r{refinement_factor:02g}.nc')
tidx = time_index_from_xtime(ds.xtime.values,
eval_time * s_per_day)
tracer = ds[variable_name]
Expand Down
2 changes: 2 additions & 0 deletions polaris/ocean/tasks/sphere_transport/forward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ mpas-ocean:
- tracers
- mesh
- xtime
- velocityZonal
- velocityMeridional
- normalVelocity
- layerThickness
- refZMid
Expand Down
35 changes: 20 additions & 15 deletions polaris/ocean/tasks/sphere_transport/mixing_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from polaris import Step
from polaris.mpas import time_index_from_xtime
from polaris.ocean.convergence import get_resolution_for_task
from polaris.ocean.resolution import resolution_to_subdir
from polaris.viz import use_mplstyle

Expand All @@ -27,8 +28,8 @@ class MixingAnalysis(Step):
case_name : str
The name of the test case
"""
def __init__(self, component, resolutions, icosahedral, subdir,
case_name, dependencies):
def __init__(self, component, refinement_factors, icosahedral, subdir,
case_name, dependencies, refinement='both'):
"""
Create the step
Expand All @@ -55,22 +56,22 @@ def __init__(self, component, resolutions, icosahedral, subdir,
"""
super().__init__(component=component, name='mixing_analysis',
subdir=subdir)
self.resolutions = resolutions
self.refinement_factors = refinement_factors
self.refinement = refinement
self.case_name = case_name

for resolution in resolutions:
mesh_name = resolution_to_subdir(resolution)
base_mesh = dependencies['mesh'][resolution]
init = dependencies['init'][resolution]
forward = dependencies['forward'][resolution]
for refinement_factor in refinement_factors:
base_mesh = dependencies['mesh'][refinement_factor]
init = dependencies['init'][refinement_factor]
forward = dependencies['forward'][refinement_factor]
self.add_input_file(
filename=f'{mesh_name}_mesh.nc',
filename=f'mesh_r{refinement_factor:02g}.nc',
work_dir_target=f'{base_mesh.path}/base_mesh.nc')
self.add_input_file(
filename=f'{mesh_name}_init.nc',
filename=f'init_r{refinement_factor:02g}.nc',
work_dir_target=f'{init.path}/initial_state.nc')
self.add_input_file(
filename=f'{mesh_name}_output.nc',
filename=f'output_r{refinement_factor:02g}.nc',
work_dir_target=f'{forward.path}/output.nc')
self.add_output_file('triplots.png')

Expand All @@ -79,7 +80,11 @@ def run(self):
Run this step of the test case
"""
plt.switch_backend('Agg')
resolutions = self.resolutions
resolutions = list()
for refinement_factor in self.refinement_factors:
resolution = get_resolution_for_task(
self.config, refinement_factor, self.refinement)
resolutions.append(resolution)
config = self.config
section = config[self.case_name]
eval_time = section.getfloat('mixing_evaluation_time')
Expand All @@ -89,12 +94,12 @@ def run(self):
use_mplstyle()
fig, axes = plt.subplots(nrows=nrows, ncols=2, sharex=True,
sharey=True, figsize=(5.5, 7))
for i, resolution in enumerate(resolutions):
for i, refinement_factor in enumerate(self.refinement_factors):
ax = axes[int(i / 2), i % 2]
_init_triplot_axes(ax)
mesh_name = resolution_to_subdir(resolution)
mesh_name = resolution_to_subdir(resolutions[i])
ax.set(title=mesh_name)
ds = xr.open_dataset(f'{mesh_name}_output.nc')
ds = xr.open_dataset(f'output_r{refinement_factor:02g}.nc')
if i % 2 == 0:
ax.set_ylabel("tracer3")
if int(i / 2) == nrows - 1:
Expand Down

0 comments on commit 71eabe8

Please sign in to comment.