From aea2ec769c1c8077d8b93ae36721f7b02cdc58c2 Mon Sep 17 00:00:00 2001
From: Carolyn Begeman <cbegeman@lanl.gov>
Date: Tue, 30 Jan 2024 17:45:01 -0600
Subject: [PATCH] Apply suggestions from code review

---
 .../ocean/tests/isomip_plus/initial_state.py  | 36 ++++++++++---------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/compass/ocean/tests/isomip_plus/initial_state.py b/compass/ocean/tests/isomip_plus/initial_state.py
index 52183a715..19cd68482 100644
--- a/compass/ocean/tests/isomip_plus/initial_state.py
+++ b/compass/ocean/tests/isomip_plus/initial_state.py
@@ -248,16 +248,19 @@ def _plot(self, ds):
                                sectionY=section_y, dsMesh=ds, ds=ds,
                                showProgress=show_progress)
 
-        ds['oceanFracObserved'] = \
+        ssh = ds['ssh'].expand_dims(dim='Time', axis=0)
+        oceanFracObserved = \
             ds['oceanFracObserved'].expand_dims(dim='Time', axis=0)
-        ds['landIceThickness'] = \
+        oceanFracObserved = \
+            ds['oceanFracObserved'].expand_dims(dim='Time', axis=0)
+        landIcePressure = \
+            ds['landIcePressure'].expand_dims(dim='Time', axis=0)
+        landIceThickness = \
             ds['landIceThickness'].expand_dims(dim='Time', axis=0)
-        ds['landIceGroundedFraction'] = \
+        landIceGroundedFraction = \
             ds['landIceGroundedFraction'].expand_dims(dim='Time', axis=0)
-        ds['bottomDepth'] = ds['bottomDepth'].expand_dims(dim='Time', axis=0)
-        ds['totalColThickness'] = ds['ssh']
-        ds['totalColThickness'].values = \
-            ds['layerThickness'].sum(dim='nVertLevels')
+        bottomDepth = ds['bottomDepth'].expand_dims(dim='Time', axis=0)
+        totalColThickness = ds.layerThickness.sum(dim='nVertLevels')
         tol = 1e-10
         plotter.plot_horiz_series(ds.landIceMask,
                                   'landIceMask', 'landIceMask',
@@ -265,23 +268,22 @@ def _plot(self, ds):
         plotter.plot_horiz_series(ds.landIceFloatingMask,
                                   'landIceFloatingMask', 'landIceFloatingMask',
                                   True)
-        plotter.plot_horiz_series(ds.landIcePressure,
+        plotter.plot_horiz_series(landIcePressure,
                                   'landIcePressure', 'landIcePressure',
                                   True, vmin=1e5, vmax=1e7, cmap_scale='log')
-        plotter.plot_horiz_series(ds.landIceThickness,
+        plotter.plot_horiz_series(landIceThickness,
                                   'landIceThickness', 'landIceThickness',
                                   True, vmin=0, vmax=1e3)
-        plotter.plot_horiz_series(ds.ssh,
-                                  'ssh', 'ssh',
+        plotter.plot_horiz_series(ssh, 'ssh', 'ssh',
                                   True, vmin=-700, vmax=0)
-        plotter.plot_horiz_series(ds.bottomDepth,
+        plotter.plot_horiz_series(bottomDepth,
                                   'bottomDepth', 'bottomDepth',
                                   True, vmin=0, vmax=700)
-        plotter.plot_horiz_series(ds.ssh + ds.bottomDepth,
+        plotter.plot_horiz_series(ds.ssh + bottomDepth,
                                   'H', 'H', True,
                                   vmin=min_column_thickness + tol, vmax=700,
                                   cmap_set_under='r', cmap_scale='log')
-        plotter.plot_horiz_series(ds.totalColThickness,
+        plotter.plot_horiz_series(totalColThickness,
                                   'totalColThickness', 'totalColThickness',
                                   True, vmin=min_column_thickness + 1e-10,
                                   vmax=700, cmap_set_under='r')
@@ -296,13 +298,13 @@ def _plot(self, ds):
                                   True, vmin=0 + tol, vmax=1 - tol,
                                   cmap='cmo.balance',
                                   cmap_set_under='k', cmap_set_over='r')
-        plotter.plot_horiz_series(ds.landIceGroundedFraction,
+        plotter.plot_horiz_series(landIceGroundedFraction,
                                   'landIceGroundedFraction',
                                   'landIceGroundedFraction',
                                   True, vmin=0 + tol, vmax=1 - tol,
                                   cmap='cmo.balance',
                                   cmap_set_under='k', cmap_set_over='r')
-        plotter.plot_horiz_series(ds.oceanFracObserved,
+        plotter.plot_horiz_series(oceanFracObserved,
                                   'oceanFracObserved', 'oceanFracObserved',
                                   True, vmin=0 + tol, vmax=1 - tol,
                                   cmap='cmo.balance',
@@ -417,7 +419,7 @@ def _write_time_varying_forcing(self, ds_init, ice_density):
                 land_ice_pressure=land_ice_pressure,
                 modify_mask=ds_init.bottomDepth > 0.)
             land_ice_draft = np.maximum(land_ice_draft, -ds_init.bottomDepth)
-            land_ice_draft = land_ice_draft.transpose()
+            land_ice_draft = land_ice_draft.transpose('nCells', 'nVertLevels')
         else:
             land_ice_draft = ds_init.landIceDraft
             land_ice_pressure = ds_init.landIcePressure