Skip to content

Commit

Permalink
Merge pull request #61 from kthyng/improved_index_handling
Browse files Browse the repository at this point in the history
Improved index handling
  • Loading branch information
kthyng authored Sep 15, 2023
2 parents 93f6b2a + 6955f44 commit ef41d1e
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 20 deletions.
1 change: 1 addition & 0 deletions ci/environment-py3.10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- intake-axds
- intake-erddap
- intake
- nested_lookup
- tqdm
- codecov
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions ci/environment-py3.8.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- intake-axds
- intake-erddap
- intake
- nested_lookup
- tqdm
- codecov
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions ci/environment-py3.9.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- intake-axds
- intake-erddap
- intake
- nested_lookup
- tqdm
- codecov
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies:
- docrep<=0.2.7
- furo
- nbsphinx
- nested_lookup
- jupyter_client
- myst-nb
- sphinx_pangeo_theme
Expand Down
6 changes: 5 additions & 1 deletion docs/whats_new.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# What's New

## unreleased version
## v0.9.0 (September 15, 2023)
* improved index handling

## v0.8.0 (September 11, 2023)

* `omsa.run` now saves the polygon found for the input model into the project directory.
* various other changes and fixes
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- intake-axds
- intake-erddap
# - git+https://github.com/intake/intake
- nested_lookup
- tqdm
# # use these from github to include recent changes while packages are changing a lot
# # - git+git://github.com/axiom-data-science/extract_model#egg=extract_model
Expand Down
4 changes: 0 additions & 4 deletions ocean_model_skill_assessor/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def compute_stats(self):

def plot(self, featuretype=None, key_variable=None, **kwargs):
"""Plot."""
# import pdb; pdb.set_trace()
import xcmocean

# cmap and cmapdiff
Expand All @@ -65,7 +64,6 @@ def plot(self, featuretype=None, key_variable=None, **kwargs):
)
xlabel, ylabel = "", key_variable
# xname, yname, zname = self.dd.cf["T"].name, ["obs","model"], None
# import pdb; pdb.set_trace()
line.plot(
self.dd.reset_index(),
xname,
Expand Down Expand Up @@ -100,13 +98,11 @@ def plot(self, featuretype=None, key_variable=None, **kwargs):
# surface.plot(xname, yname, self.dd["obs"].squeeze(), self.dd["model"].squeeze(), **kwargs)
elif featuretype == "profile":
# use transpose so that index depth is plotted on y axis instead of x axis
# import pdb; pdb.set_trace()
xname, yname, zname = (
["obs", "model"],
self.dd.index.name or "index",
None,
)
# import pdb; pdb.set_trace()
xlabel, ylabel = key_variable, yname
# xname, yname, zname = ["obs","model"], self.dd.cf["Z"].name, None
line.plot(
Expand Down
20 changes: 8 additions & 12 deletions ocean_model_skill_assessor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,6 @@ def run(
# sort out depths between model and data
# 1 location: interpolate or nearest neighbor horizontally
# have it figure out depth
# import pdb; pdb.set_trace()
if ("Z" not in dfd.cf.axes) or no_Z:
Z = None
vertical_interp = False
Expand Down Expand Up @@ -755,7 +754,6 @@ def run(
dfd.cf["Z"].encoding = encoding

elif isinstance(dfd, (pd.DataFrame, pd.Series)):
# import pdb; pdb.set_trace()
if dsm[zkeym].attrs["positive"] == "up":
ilev = dfd.index.names.index(dfd.cf["Z"].name)
dfd.index = dfd.index.set_levels(
Expand Down Expand Up @@ -888,7 +886,11 @@ def run(
source_name,
)
if isinstance(dfd, pd.DataFrame):
dd = pd.read_csv(fname_aligned, parse_dates=True)
dd = pd.read_csv(fname_aligned) # , parse_dates=True)

if "T" in dd.cf:
dd[dd.cf["T"].name] = pd.to_datetime(dd.cf["T"])

# assume all columns except last two are index columns
# last two should be obs and model
dd = dd.set_index(list(dd.columns[:-2]))
Expand All @@ -900,7 +902,6 @@ def run(

# # # Combine and align the two time series of variable
# # with cfp_set_options(custom_criteria=vocab.vocab):
# import pdb; pdb.set_trace()
if isinstance(dfd, DataFrame) and key_variable_data not in dfd.cf:
msg = f"Key variable {key_variable_data} cannot be identified in dataset {source_name}. Skipping dataset.\n"
logger.warning(msg)
Expand Down Expand Up @@ -1318,8 +1319,8 @@ def run(
longitude=lons,
latitude=lats,
# T=slice(user_min_time, user_max_time),
# T=dfd.cf["T"].values,
T=None, # changed this because wasn't working with CTD profiles. Time interpolation happens during _align.
T=dfd.cf["T"].values,
# T=None, # changed this because wasn't working with CTD profiles. Time interpolation happens during _align.
make_time_series=True, # advanced index to make result time series instead of array
Z=Z,
vertical_interp=vertical_interp,
Expand All @@ -1328,6 +1329,7 @@ def run(
extrap=extrap,
extrap_val=None,
locstream=True,
# locstream_dim="z_rho",
weights=None,
mask=mask,
use_xoak=False,
Expand All @@ -1337,7 +1339,6 @@ def run(
xgcm_grid=grid,
return_info=True,
)
# import pdb; pdb.set_trace()
# save pickle of triangulation to project dir
if (
interpolate_horizontal
Expand Down Expand Up @@ -1399,7 +1400,6 @@ def run(
logger.info("Trying to drop vertical coordinates time series")
model_var = model_var.drop_vars(model_var.cf["vertical"].name)

# import pdb; pdb.set_trace()
# try rechunking to avoid killing kernel
if model_var.dims == (model_var.cf["T"].name,):
# for simple case of only time, just rechunk into pieces if no chunks
Expand All @@ -1409,7 +1409,6 @@ def run(

logger.info(f"Loading model output...")
model_var = model_var.compute()
# import pdb; pdb.set_trace()
# depths shouldn't need to be saved if interpolated since then will be a dimension
if Z is not None and not vertical_interp:
# find Z index
Expand All @@ -1430,7 +1429,6 @@ def run(
model_var["distance"].attrs["units"] = "km"
# model_var.attrs["distance_from_location_km"] = float(distance)
else:
# import pdb; pdb.set_trace()
# when lons/lats are function of time, add them back in
if dam.cf["longitude"].name not in model_var.coords:
# if model_var.ndim == 1 and len(model_var[model_var.dims[0]]) == lons.size:
Expand Down Expand Up @@ -1543,7 +1541,6 @@ def run(
# varnames += [dfd.cf.coordinates[col][0] for col in cols if col in dfd.cf.coordinates]
# varnames += [dfd.cf[key_variable_data].name]
# dd = _align(dfd[varnames], model_var, key_variable=key_variable_data)
# import pdb; pdb.set_trace()
dd = _align(dfd.cf[key_variable_data], model_var)
# read in from newly made file to make sure output is loaded
if isinstance(dd, pd.DataFrame):
Expand Down Expand Up @@ -1614,7 +1611,6 @@ def run(
title = f"{count}: {source_name}"
else:
title = f"{source_name}"
# import pdb; pdb.set_trace()
dd.omsa.plot(
title=title,
key_variable=key_variable,
Expand Down
1 change: 0 additions & 1 deletion ocean_model_skill_assessor/plot/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def plot(
ax.set_xlim(df[xname].min(), df[xname].max())
# df[xname].plot(ax=ax, label="observation", fontsize=fs, lw=lw, color=col_obs)
# df[yname].plot(ax=ax, label="model", fontsize=fs, lw=lw, color=col_model)

if stats is not None:
stat_sum = ""
types = ["bias", "corr", "ioa", "mse", "ss", "rmse"]
Expand Down
26 changes: 24 additions & 2 deletions ocean_model_skill_assessor/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def _align(
obs: Union[DataFrame, xr.DataArray],
model: Union[DataFrame, xr.DataArray],
already_aligned: Optional[bool] = False,
already_aligned: Optional[bool] = None,
) -> DataFrame:
"""Aligns obs and model signals in time and returns a combined DataFrame
Expand All @@ -38,6 +38,14 @@ def _align(
-----
Takes the obs times as the correct times to interpolate model to.
"""

# if not specified, guess whether obs and model are already aligned
if already_aligned is None:
if len(obs) == len(model):
already_aligned = True
else:
already_aligned = False

if already_aligned:
if isinstance(obs, (Series, DataFrame)):
obs.name = "obs"
Expand All @@ -53,11 +61,25 @@ def _align(
for index in ["T", "Z", "latitude", "longitude"]:
# if index in obs, have as index for model too
if index in obs.cf.keys():
indices.append(model.cf[index].name)
# if this index level has only 1 unique value, drop that level here
# and do not include it in the model indices

if (
len(
obs.index.get_level_values(
obs.cf[index].name
).unique()
)
> 1
):
indices.append(model.cf[index].name)
else:
obs.index = obs.index.droplevel(obs.cf[index].name)
# Indices have to match exactly to concat correctly
# so if lon/lat are in indices, need to have interpolated to those values
# instead of finding nearest neighbors
model = model.to_pandas().reset_index().set_index(indices)[var_name]

else:
model = model.squeeze().to_pandas()
model.name = "model"
Expand Down

0 comments on commit ef41d1e

Please sign in to comment.