Skip to content

Commit

Permalink
Clean-up viz step
Browse files Browse the repository at this point in the history
  • Loading branch information
cbegeman committed Oct 16, 2023
1 parent 9521393 commit 82758b7
Showing 1 changed file with 13 additions and 101 deletions.
114 changes: 13 additions & 101 deletions compass/ocean/tests/drying_slope/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray
import xarray as xr

from compass.step import Step

Expand Down Expand Up @@ -67,7 +67,7 @@ def run(self):
section = self.config['drying_slope_viz']
generate_movie = section.getboolean('generate_movie')

self._plot_ssh_validation()
self._plot_ssh_validation(times=self.times)
self._plot_ssh_time_series()
if generate_movie:
frames_per_second = section.getint('frames_per_second')
Expand All @@ -78,7 +78,9 @@ def run(self):
os.makedirs(os.path.join(os.getcwd(), outFolder))
except OSError:
pass
self._plot_ssh_validation_for_movie(outFolder=outFolder)
for tidx, itime in enumerate(np.linspace(0, 0.5, 5 * 12 + 1)):
self._plot_ssh_validation(times=[itime], tidx=tidx,
outFolder=outFolder)
self._images_to_movies(framesPerSecond=frames_per_second,
outFolder=outFolder, extension=movie_format)

Expand Down Expand Up @@ -111,7 +113,7 @@ def _plot_ssh_time_series(self, outFolder='.'):

for i in range(naxes):
ax = plt.subplot(naxes, 1, i + 1)
ds = xarray.open_dataset(ncFilename[i])
ds = xr.open_dataset(ncFilename[i])
ympas = ds.ssh.where(ds.tidalInputMask).mean('nCells').values
xmpas = np.linspace(0, 1.0, len(ds.xtime)) * 12.0
ax.plot(xmpas, ympas, marker='o', label='MPAS-O forward',
Expand All @@ -129,7 +131,7 @@ def _plot_ssh_time_series(self, outFolder='.'):

plt.close(fig)

def _plot_ssh_validation(self, outFolder='.'):
def _plot_ssh_validation(self, times, tidx=None, outFolder='.'):
"""
Plot ssh as a function of along-channel distance for all times for
which there is validation data
Expand All @@ -140,7 +142,6 @@ def _plot_ssh_validation(self, outFolder='.'):
locs = 0.92 - np.divide(locs, 11.)

damping_coeffs = self.damping_coeffs
times = self.times
datatypes = self.datatypes

if damping_coeffs is None:
Expand All @@ -160,7 +161,7 @@ def _plot_ssh_validation(self, outFolder='.'):

for i in range(naxes):
ax = plt.subplot(naxes, 1, i + 1)
ds = xarray.open_dataset(ncFilename[i])
ds = xr.open_dataset(ncFilename[i])
ds = ds.drop_vars(np.setdiff1d([j for j in ds.variables],
['yCell', 'ssh']))

Expand All @@ -179,7 +180,7 @@ def _plot_ssh_validation(self, outFolder='.'):
# to get right time slices
plottime = int((float(atime) / 0.2 + 1e-16) * 24.0)
ymean = ds.isel(Time=plottime).groupby('yCell').mean(
dim=xarray.ALL_DIMS)
dim=xr.ALL_DIMS)
x = ymean.yCell.values / 1000.0
y = ymean.ssh.values

Expand Down Expand Up @@ -207,101 +208,12 @@ def _plot_ssh_validation(self, outFolder='.'):
rotation='vertical')
fig.text(0.5, 0.02, 'Along channel distance (km)', ha='center')

fig.savefig(f'{outFolder}/ssh_depth_section.png', dpi=200)
filename = f'{outFolder}/ssh_depth_section'
if tidx is not None:
filename = f'{filename}_t{tidx}'
fig.savefig(filename, dpi=200, format='png')
plt.close(fig)

def _plot_ssh_validation_for_movie(self, outFolder='.'):
"""
Compare ssh along the channel at different time slices with the
analytical solution and ROMS results.
Parameters
----------
"""
colors = {'MPAS-O': 'k', 'analytical': 'b', 'ROMS': 'g'}

locs = [7.2, 2.2, 0.2, 1.2, 4.2, 9.3]
locs = 0.92 - np.divide(locs, 11.)

damping_coeffs = self.damping_coeffs
if damping_coeffs is None:
naxes = 1
nhandles = 1
ncFilename = ['output.nc']
else:
naxes = len(damping_coeffs)
nhandles = naxes + 2
ncFilename = [f'output_{damping_coeff}.nc'
for damping_coeff in damping_coeffs]

times = self.times
datatypes = self.datatypes

xBed = np.linspace(0, 25, 100)
yBed = 10.0 / 25.0 * xBed

ii = 0
# Plot profiles over the 12h simulation duration
for itime in np.linspace(0, 0.5, 5 * 12 + 1):

plottime = int((float(itime) / 0.2 + 1e-16) * 24.0)

fig, _ = plt.subplots(nrows=naxes, ncols=1, sharex=True)

for i in range(naxes):
ax = plt.subplot(naxes, 1, i + 1)
ds = xarray.open_dataset(ncFilename[i])
ds = ds.drop_vars(np.setdiff1d([j for j in ds.variables],
['yCell', 'ssh']))

# Plot MPAS-O snapshots
# factor of 1e- needed to account for annoying round-off issue
# to get right time slices
ymean = ds.isel(Time=plottime).groupby('yCell').mean(
dim=xarray.ALL_DIMS)
x = ymean.yCell.values / 1000.0
y = ymean.ssh.values
ax.plot(xBed, yBed, '-k', lw=3)
ax.plot(x, -y, label='MPAS-O', color=colors['MPAS-O'])

ax.set_ylim(-1, 11)
ax.set_xlim(0, 25)
ax.invert_yaxis()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(frameon=False, loc='lower left')
ax.set_title(f't = {itime:.3f} days')
if damping_coeffs is not None:
ax.text(0.5, 5, 'r = ' + str(damping_coeffs[i]))
# Plot comparison data
for atime, ay in zip(times, locs):
ax.text(1, ay, f'{atime} days', size=8,
transform=ax.transAxes)

for datatype in datatypes:
datafile = f'./r{damping_coeffs[i]}d{atime}-'\
f'{datatype.lower()}.csv'
data = pd.read_csv(datafile, header=None)
ax.scatter(data[0], data[1], marker='.',
color=colors[datatype], label=datatype)

h, l0 = ax.get_legend_handles_labels()
ax.legend(h[0:nhandles], l0[0:nhandles], frameon=False,
loc='lower left')
ax.set_title(f't = {itime:.3f} days')

ds.close()

fig.text(0.04, 0.5, 'Channel depth (m)', va='center',
rotation='vertical')
fig.text(0.5, 0.02, 'Along channel distance (km)', ha='center')

fig.savefig(f'{outFolder}/ssh_depth_section_{ii:03d}.png', dpi=200)

plt.close(fig)
ii += 1

def _images_to_movies(self, outFolder='.', framesPerSecond=30,
extension='mp4', overwrite=True):
"""
Expand Down

0 comments on commit 82758b7

Please sign in to comment.