From 82758b7571c9f1e8e253577918b9728c293116fa Mon Sep 17 00:00:00 2001 From: Carolyn Begeman Date: Mon, 16 Oct 2023 12:06:55 -0500 Subject: [PATCH] Clean-up viz step --- .../ocean/tests/drying_slope/viz/__init__.py | 114 ++---------------- 1 file changed, 13 insertions(+), 101 deletions(-) diff --git a/compass/ocean/tests/drying_slope/viz/__init__.py b/compass/ocean/tests/drying_slope/viz/__init__.py index 1c643a75b7..1b1daf7ab8 100644 --- a/compass/ocean/tests/drying_slope/viz/__init__.py +++ b/compass/ocean/tests/drying_slope/viz/__init__.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import xarray +import xarray as xr from compass.step import Step @@ -67,7 +67,7 @@ def run(self): section = self.config['drying_slope_viz'] generate_movie = section.getboolean('generate_movie') - self._plot_ssh_validation() + self._plot_ssh_validation(times=self.times) self._plot_ssh_time_series() if generate_movie: frames_per_second = section.getint('frames_per_second') @@ -78,7 +78,9 @@ def run(self): os.makedirs(os.path.join(os.getcwd(), outFolder)) except OSError: pass - self._plot_ssh_validation_for_movie(outFolder=outFolder) + for tidx, itime in enumerate(np.linspace(0, 0.5, 5 * 12 + 1)): + self._plot_ssh_validation(times=[itime], tidx=tidx, + outFolder=outFolder) self._images_to_movies(framesPerSecond=frames_per_second, outFolder=outFolder, extension=movie_format) @@ -111,7 +113,7 @@ def _plot_ssh_time_series(self, outFolder='.'): for i in range(naxes): ax = plt.subplot(naxes, 1, i + 1) - ds = xarray.open_dataset(ncFilename[i]) + ds = xr.open_dataset(ncFilename[i]) ympas = ds.ssh.where(ds.tidalInputMask).mean('nCells').values xmpas = np.linspace(0, 1.0, len(ds.xtime)) * 12.0 ax.plot(xmpas, ympas, marker='o', label='MPAS-O forward', @@ -129,7 +131,7 @@ def _plot_ssh_time_series(self, outFolder='.'): plt.close(fig) - def _plot_ssh_validation(self, outFolder='.'): + def _plot_ssh_validation(self, times, tidx=None, outFolder='.'): """ Plot ssh as a function of along-channel distance for all times for which there is validation data @@ -140,7 +142,6 @@ def _plot_ssh_validation(self, outFolder='.'): locs = 0.92 - np.divide(locs, 11.) damping_coeffs = self.damping_coeffs - times = self.times datatypes = self.datatypes if damping_coeffs is None: @@ -160,7 +161,7 @@ def _plot_ssh_validation(self, outFolder='.'): for i in range(naxes): ax = plt.subplot(naxes, 1, i + 1) - ds = xarray.open_dataset(ncFilename[i]) + ds = xr.open_dataset(ncFilename[i]) ds = ds.drop_vars(np.setdiff1d([j for j in ds.variables], ['yCell', 'ssh'])) @@ -179,7 +180,7 @@ def _plot_ssh_validation(self, outFolder='.'): # to get right time slices plottime = int((float(atime) / 0.2 + 1e-16) * 24.0) ymean = ds.isel(Time=plottime).groupby('yCell').mean( - dim=xarray.ALL_DIMS) + dim=xr.ALL_DIMS) x = ymean.yCell.values / 1000.0 y = ymean.ssh.values @@ -207,101 +208,12 @@ def _plot_ssh_validation(self, outFolder='.'): rotation='vertical') fig.text(0.5, 0.02, 'Along channel distance (km)', ha='center') - fig.savefig(f'{outFolder}/ssh_depth_section.png', dpi=200) + filename = f'{outFolder}/ssh_depth_section' + if tidx is not None: + filename = f'{filename}_t{tidx}' + fig.savefig(filename, dpi=200, format='png') plt.close(fig) - def _plot_ssh_validation_for_movie(self, outFolder='.'): - """ - Compare ssh along the channel at different time slices with the - analytical solution and ROMS results. - - Parameters - ---------- - - """ - colors = {'MPAS-O': 'k', 'analytical': 'b', 'ROMS': 'g'} - - locs = [7.2, 2.2, 0.2, 1.2, 4.2, 9.3] - locs = 0.92 - np.divide(locs, 11.) - - damping_coeffs = self.damping_coeffs - if damping_coeffs is None: - naxes = 1 - nhandles = 1 - ncFilename = ['output.nc'] - else: - naxes = len(damping_coeffs) - nhandles = naxes + 2 - ncFilename = [f'output_{damping_coeff}.nc' - for damping_coeff in damping_coeffs] - - times = self.times - datatypes = self.datatypes - - xBed = np.linspace(0, 25, 100) - yBed = 10.0 / 25.0 * xBed - - ii = 0 - # Plot profiles over the 12h simulation duration - for itime in np.linspace(0, 0.5, 5 * 12 + 1): - - plottime = int((float(itime) / 0.2 + 1e-16) * 24.0) - - fig, _ = plt.subplots(nrows=naxes, ncols=1, sharex=True) - - for i in range(naxes): - ax = plt.subplot(naxes, 1, i + 1) - ds = xarray.open_dataset(ncFilename[i]) - ds = ds.drop_vars(np.setdiff1d([j for j in ds.variables], - ['yCell', 'ssh'])) - - # Plot MPAS-O snapshots - # factor of 1e- needed to account for annoying round-off issue - # to get right time slices - ymean = ds.isel(Time=plottime).groupby('yCell').mean( - dim=xarray.ALL_DIMS) - x = ymean.yCell.values / 1000.0 - y = ymean.ssh.values - ax.plot(xBed, yBed, '-k', lw=3) - ax.plot(x, -y, label='MPAS-O', color=colors['MPAS-O']) - - ax.set_ylim(-1, 11) - ax.set_xlim(0, 25) - ax.invert_yaxis() - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.legend(frameon=False, loc='lower left') - ax.set_title(f't = {itime:.3f} days') - if damping_coeffs is not None: - ax.text(0.5, 5, 'r = ' + str(damping_coeffs[i])) - # Plot comparison data - for atime, ay in zip(times, locs): - ax.text(1, ay, f'{atime} days', size=8, - transform=ax.transAxes) - - for datatype in datatypes: - datafile = f'./r{damping_coeffs[i]}d{atime}-'\ - f'{datatype.lower()}.csv' - data = pd.read_csv(datafile, header=None) - ax.scatter(data[0], data[1], marker='.', - color=colors[datatype], label=datatype) - - h, l0 = ax.get_legend_handles_labels() - ax.legend(h[0:nhandles], l0[0:nhandles], frameon=False, - loc='lower left') - ax.set_title(f't = {itime:.3f} days') - - ds.close() - - fig.text(0.04, 0.5, 'Channel depth (m)', va='center', - rotation='vertical') - fig.text(0.5, 0.02, 'Along channel distance (km)', ha='center') - - fig.savefig(f'{outFolder}/ssh_depth_section_{ii:03d}.png', dpi=200) - - plt.close(fig) - ii += 1 - def _images_to_movies(self, outFolder='.', framesPerSecond=30, extension='mp4', overwrite=True): """