Skip to content

Commit

Permalink
Update variable comparison to take datasets
Browse files Browse the repository at this point in the history
This is needed if we are translating dimension and variable names
between native model names and common Polaris names.
  • Loading branch information
xylar committed Dec 5, 2024
1 parent 224658a commit 73eeb9f
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions polaris/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@


def compare_variables(variables, filename1, filename2, logger, l1_norm=0.0,
l2_norm=0.0, linf_norm=0.0, quiet=True):
l2_norm=0.0, linf_norm=0.0, quiet=True, ds1=None,
ds2=None):
"""
compare variables in the two files
Expand Down Expand Up @@ -49,6 +50,16 @@ def compare_variables(variables, filename1, filename2, logger, l1_norm=0.0,
comparison is made. This is generally desirable when using nonzero
norm tolerance values.
ds1 : xarray.Dataset, optional
A dataset loaded from filename1. This may save time if the dataset is
already loaded and allows for calculations to be performed or variables
to be renamed if necessary.
ds2 : xarray.Dataset, optional
A dataset loaded from filename2. This may save time if the dataset is
already loaded and allows for calculations to be performed or variables
to be renamed if necessary.
Returns
-------
all_pass : bool
Expand All @@ -61,18 +72,16 @@ def compare_variables(variables, filename1, filename2, logger, l1_norm=0.0,
logger.error(f'File {filename} does not exist.')
return False

ds1 = xr.open_dataset(filename1)
ds2 = xr.open_dataset(filename2)
if ds1 is None:
ds1 = xr.open_dataset(filename1)

if ds2 is None:
ds2 = xr.open_dataset(filename2)

all_pass = True

for variable in variables:
all_found = True
for ds, filename in [(ds1, filename1), (ds2, filename2)]:
if variable not in ds:
logger.error(f'Variable {variable} not in {filename}.')
all_found = False
if not all_found:
if not _all_found(ds1, filename1, ds2, filename2, variable, logger):
all_pass = False
continue

Expand All @@ -85,13 +94,8 @@ def compare_variables(variables, filename1, filename2, logger, l1_norm=0.0,
all_pass = False
continue

all_match = True
for dim in da1.sizes:
if da1.sizes[dim] != da2.sizes[dim]:
logger.error(f"Field sizes for variable {variable} don't "
f"match files {filename1} and {filename2}.")
all_match = False
if not all_match:
if not _all_sizes_match(da1, filename1, da2, filename2, variable,
logger):
all_pass = False
continue

Expand Down Expand Up @@ -139,6 +143,27 @@ def compare_variables(variables, filename1, filename2, logger, l1_norm=0.0,
return all_pass


def _all_found(ds1, filename1, ds2, filename2, variable, logger):
    """
    Is the variable found in both datasets?

    Parameters
    ----------
    ds1, ds2 : container supporting ``in``
        The two datasets to search (membership is tested with
        ``variable not in ds``)

    filename1, filename2 : str
        The file names associated with ``ds1`` and ``ds2``, used only in
        the error messages

    variable : str
        The name of the variable to look for

    logger : logging.Logger
        The logger that receives one error message per dataset that lacks
        the variable

    Returns
    -------
    all_found : bool
        ``True`` if the variable is in both datasets, ``False`` otherwise
    """
    # Collect the file names whose dataset lacks the variable, keeping the
    # (ds1, ds2) order so errors are reported in that same order
    missing_from = [
        name
        for dataset, name in ((ds1, filename1), (ds2, filename2))
        if variable not in dataset
    ]
    for filename in missing_from:
        logger.error(f'Variable {variable} not in {filename}.')
    return not missing_from


def _all_sizes_match(da1, filename1, da2, filename2, variable, logger):
    """
    Do all dimension sizes match between the two variables?

    Every dimension listed in ``da1.sizes`` is compared against the entry
    of the same name in ``da2.sizes``.  A dimension that is absent from
    ``da2.sizes`` is reported as a mismatch rather than raising a
    ``KeyError``.  Dimensions that appear only in ``da2.sizes`` are not
    examined.

    Parameters
    ----------
    da1, da2 : object with a ``sizes`` mapping
        The two variables to compare; ``sizes`` maps dimension names to
        lengths (presumably ``xarray.DataArray`` objects -- confirm
        against the caller)

    filename1, filename2 : str
        The file names associated with ``da1`` and ``da2``, used only in
        the error message

    variable : str
        The name of the variable being compared, used only in the error
        message

    logger : logging.Logger
        The logger that receives one error message per mismatched
        dimension

    Returns
    -------
    all_match : bool
        ``True`` if every dimension of ``da1`` has the same size in
        ``da2``, ``False`` otherwise
    """
    all_match = True
    sizes2 = da2.sizes
    for dim, size1 in da1.sizes.items():
        # a dimension missing from da2 is a mismatch, not a crash
        if dim not in sizes2 or sizes2[dim] != size1:
            logger.error(f"Field sizes for variable {variable} don't "
                         f"match files {filename1} and {filename2}.")
            all_match = False
    return all_match


def _compute_norms(da1, da2, quiet, max_l1_norm, max_l2_norm, max_linf_norm,
time_index=None):
""" Compute norms between variables in two DataArrays """
Expand Down

0 comments on commit 73eeb9f

Please sign in to comment.