From 71eabe8a9ec21e5d9b50ecbf4a24faca443a89b3 Mon Sep 17 00:00:00 2001 From: Carolyn Begeman Date: Thu, 10 Oct 2024 21:18:28 -0500 Subject: [PATCH] Update sphere_transport for convergence --- .../ocean/tasks/sphere_transport/__init__.py | 54 ++++++++++--------- .../sphere_transport/filament_analysis.py | 35 ++++++------ .../ocean/tasks/sphere_transport/forward.yaml | 2 + .../tasks/sphere_transport/mixing_analysis.py | 35 ++++++------ 4 files changed, 70 insertions(+), 56 deletions(-) diff --git a/polaris/ocean/tasks/sphere_transport/__init__.py b/polaris/ocean/tasks/sphere_transport/__init__.py index 3596c5e95..e1aabb1e1 100644 --- a/polaris/ocean/tasks/sphere_transport/__init__.py +++ b/polaris/ocean/tasks/sphere_transport/__init__.py @@ -171,28 +171,28 @@ def _setup_steps(self, refinement): for idx, refinement_factor in enumerate(refinement_factors): resolution = get_resolution_for_task( config, refinement_factor, refinement=refinement) - if resolution not in self.resolutions: + base_mesh_step, mesh_name = add_spherical_base_mesh_step( + component, resolution, icosahedral) + analysis_dependencies['mesh'][refinement_factor] = base_mesh_step + + name = f'{prefix}_init_{mesh_name}' + subdir = f'{sph_trans_dir}/init/{mesh_name}' + if self.include_viz: + symlink = f'init/{mesh_name}' + else: + symlink = None + if subdir in component.steps: + init_step = component.steps[subdir] + else: + init_step = Init(component=component, name=name, + subdir=subdir, base_mesh=base_mesh_step, + case_name=case_name) + init_step.set_shared_config(config, link=config_filename) + analysis_dependencies['init'][refinement_factor] = init_step + self.add_step(base_mesh_step, symlink=f'base_mesh/{mesh_name}') + self.add_step(init_step, symlink=symlink) + if resolution not in resolutions: resolutions.append(resolution) - base_mesh_step, mesh_name = add_spherical_base_mesh_step( - component, resolution, icosahedral) - self.add_step(base_mesh_step, symlink=f'base_mesh/{mesh_name}') - analysis_dependencies['mesh'][resolution] = base_mesh_step - - name = f'{prefix}_init_{mesh_name}' - subdir = f'{sph_trans_dir}/init/{mesh_name}' - if self.include_viz: - symlink = f'init/{mesh_name}' - else: - symlink = None - if subdir in component.steps: - init_step = component.steps[subdir] - else: - init_step = Init(component=component, name=name, - subdir=subdir, base_mesh=base_mesh_step, - case_name=case_name) - init_step.set_shared_config(config, link=config_filename) - self.add_step(init_step, symlink=symlink) - analysis_dependencies['init'][resolution] = init_step timestep, _ = get_timestep_for_task( config, refinement_factor, refinement=refinement) @@ -217,7 +217,7 @@ def _setup_steps(self, refinement): refinement=refinement) forward_step.set_shared_config(config, link=config_filename) self.add_step(forward_step, symlink=symlink) - analysis_dependencies['forward'][resolution] = forward_step + analysis_dependencies['forward'][refinement_factor] = forward_step if self.include_viz: with_viz_dir = f'{sph_trans_dir}/with_viz' @@ -257,10 +257,11 @@ def _setup_steps(self, refinement): step = component.steps[subdir] else: step = MixingAnalysis(component=component, - resolutions=resolutions, icosahedral=icosahedral, subdir=subdir, + refinement_factors=refinement_factors, case_name=case_name, - dependencies=analysis_dependencies) + dependencies=analysis_dependencies, + refinement=refinement) step.set_shared_config(config, link=config_filename) self.add_step(step, symlink=symlink) @@ -274,9 +275,10 @@ def _setup_steps(self, refinement): step = component.steps[subdir] else: step = FilamentAnalysis(component=component, - resolutions=resolutions, + refinement_factors=refinement_factors, icosahedral=icosahedral, subdir=subdir, case_name=case_name, - dependencies=analysis_dependencies) + dependencies=analysis_dependencies, + refinement=refinement) step.set_shared_config(config, link=config_filename) self.add_step(step, symlink=symlink) diff --git a/polaris/ocean/tasks/sphere_transport/filament_analysis.py b/polaris/ocean/tasks/sphere_transport/filament_analysis.py index a3f196439..f5a3d6895 100644 --- a/polaris/ocean/tasks/sphere_transport/filament_analysis.py +++ b/polaris/ocean/tasks/sphere_transport/filament_analysis.py @@ -5,6 +5,7 @@ from polaris import Step from polaris.mpas import time_index_from_xtime +from polaris.ocean.convergence import get_resolution_for_task from polaris.ocean.resolution import resolution_to_subdir from polaris.viz import use_mplstyle @@ -25,8 +26,8 @@ class FilamentAnalysis(Step): case_name : str The name of the test case """ - def __init__(self, component, resolutions, icosahedral, subdir, - case_name, dependencies): + def __init__(self, component, refinement_factors, icosahedral, subdir, + case_name, dependencies, refinement='both'): """ Create the step @@ -53,22 +54,22 @@ def __init__(self, component, resolutions, icosahedral, subdir, """ super().__init__(component=component, name='filament_analysis', subdir=subdir) - self.resolutions = resolutions + self.refinement_factors = refinement_factors + self.refinement = refinement self.case_name = case_name - for resolution in resolutions: - mesh_name = resolution_to_subdir(resolution) - base_mesh = dependencies['mesh'][resolution] - init = dependencies['init'][resolution] - forward = dependencies['forward'][resolution] + for refinement_factor in refinement_factors: + base_mesh = dependencies['mesh'][refinement_factor] + init = dependencies['init'][refinement_factor] + forward = dependencies['forward'][refinement_factor] self.add_input_file( - filename=f'{mesh_name}_mesh.nc', + filename=f'mesh_r{refinement_factor:02g}.nc', work_dir_target=f'{base_mesh.path}/base_mesh.nc') self.add_input_file( - filename=f'{mesh_name}_init.nc', + filename=f'init_r{refinement_factor:02g}.nc', work_dir_target=f'{init.path}/initial_state.nc') self.add_input_file( - filename=f'{mesh_name}_output.nc', + filename=f'output_r{refinement_factor:02g}.nc', work_dir_target=f'{forward.path}/output.nc') self.add_output_file('filament.png') @@ -77,7 +78,11 @@ def run(self): Run this step of the test case """ plt.switch_backend('Agg') - resolutions = self.resolutions + resolutions = list() + for refinement_factor in self.refinement_factors: + resolution = get_resolution_for_task( + self.config, refinement_factor, self.refinement) + resolutions.append(resolution) config = self.config section = config[self.case_name] eval_time = section.getfloat('filament_evaluation_time') @@ -89,9 +94,9 @@ def run(self): filament_norm = np.zeros((len(resolutions), num_tau)) use_mplstyle() fig, ax = plt.subplots() - for i, resolution in enumerate(resolutions): - mesh_name = resolution_to_subdir(resolution) - ds = xr.open_dataset(f'{mesh_name}_output.nc') + for i, refinement_factor in enumerate(self.refinement_factors): + mesh_name = resolution_to_subdir(resolutions[i]) + ds = xr.open_dataset(f'output_r{refinement_factor:02g}.nc') tidx = time_index_from_xtime(ds.xtime.values, eval_time * s_per_day) tracer = ds[variable_name] diff --git a/polaris/ocean/tasks/sphere_transport/forward.yaml b/polaris/ocean/tasks/sphere_transport/forward.yaml index e8811b654..be8a53ec0 100644 --- a/polaris/ocean/tasks/sphere_transport/forward.yaml +++ b/polaris/ocean/tasks/sphere_transport/forward.yaml @@ -49,6 +49,8 @@ mpas-ocean: - tracers - mesh - xtime + - velocityZonal + - velocityMeridional - normalVelocity - layerThickness - refZMid diff --git a/polaris/ocean/tasks/sphere_transport/mixing_analysis.py b/polaris/ocean/tasks/sphere_transport/mixing_analysis.py index 750518a7f..88ee06702 100644 --- a/polaris/ocean/tasks/sphere_transport/mixing_analysis.py +++ b/polaris/ocean/tasks/sphere_transport/mixing_analysis.py @@ -7,6 +7,7 @@ from polaris import Step from polaris.mpas import time_index_from_xtime +from polaris.ocean.convergence import get_resolution_for_task from polaris.ocean.resolution import resolution_to_subdir from polaris.viz import use_mplstyle @@ -27,8 +28,8 @@ class MixingAnalysis(Step): case_name : str The name of the test case """ - def __init__(self, component, resolutions, icosahedral, subdir, - case_name, dependencies): + def __init__(self, component, refinement_factors, icosahedral, subdir, + case_name, dependencies, refinement='both'): """ Create the step @@ -55,22 +56,22 @@ def __init__(self, component, resolutions, icosahedral, subdir, """ super().__init__(component=component, name='mixing_analysis', subdir=subdir) - self.resolutions = resolutions + self.refinement_factors = refinement_factors + self.refinement = refinement self.case_name = case_name - for resolution in resolutions: - mesh_name = resolution_to_subdir(resolution) - base_mesh = dependencies['mesh'][resolution] - init = dependencies['init'][resolution] - forward = dependencies['forward'][resolution] + for refinement_factor in refinement_factors: + base_mesh = dependencies['mesh'][refinement_factor] + init = dependencies['init'][refinement_factor] + forward = dependencies['forward'][refinement_factor] self.add_input_file( - filename=f'{mesh_name}_mesh.nc', + filename=f'mesh_r{refinement_factor:02g}.nc', work_dir_target=f'{base_mesh.path}/base_mesh.nc') self.add_input_file( - filename=f'{mesh_name}_init.nc', + filename=f'init_r{refinement_factor:02g}.nc', work_dir_target=f'{init.path}/initial_state.nc') self.add_input_file( - filename=f'{mesh_name}_output.nc', + filename=f'output_r{refinement_factor:02g}.nc', work_dir_target=f'{forward.path}/output.nc') self.add_output_file('triplots.png') @@ -79,7 +80,11 @@ def run(self): Run this step of the test case """ plt.switch_backend('Agg') - resolutions = self.resolutions + resolutions = list() + for refinement_factor in self.refinement_factors: + resolution = get_resolution_for_task( + self.config, refinement_factor, self.refinement) + resolutions.append(resolution) config = self.config section = config[self.case_name] eval_time = section.getfloat('mixing_evaluation_time') @@ -89,12 +94,12 @@ def run(self): use_mplstyle() fig, axes = plt.subplots(nrows=nrows, ncols=2, sharex=True, sharey=True, figsize=(5.5, 7)) - for i, resolution in enumerate(resolutions): + for i, refinement_factor in enumerate(self.refinement_factors): ax = axes[int(i / 2), i % 2] _init_triplot_axes(ax) - mesh_name = resolution_to_subdir(resolution) + mesh_name = resolution_to_subdir(resolutions[i]) ax.set(title=mesh_name) - ds = xr.open_dataset(f'{mesh_name}_output.nc') + ds = xr.open_dataset(f'output_r{refinement_factor:02g}.nc') if i % 2 == 0: ax.set_ylabel("tracer3") if int(i / 2) == nrows - 1: