Skip to content

Commit

Permalink
Update catchstats.py: hide reading functions
Browse files Browse the repository at this point in the history
  • Loading branch information
casadoj committed Apr 5, 2024
1 parent 1697c44 commit c82bf1b
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/lisfloodutilities/catchstats/catchstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tqdm.auto import tqdm


def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
def _read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
"""It reads the input maps in NetCDF format from the input directory
Parameters:
Expand All @@ -48,7 +48,7 @@ def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:

try:
# for dynamic maps
ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True)
ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True) as ds:
# chunks is set to auto for general purpose processing
# it could be optimized depending on input NetCDF
except:
Expand All @@ -59,7 +59,7 @@ def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:

return ds

def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
def _read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
"""It loads the catchment masks in NetCDF formal from the input directory
Parameters:
Expand Down Expand Up @@ -103,7 +103,7 @@ def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:

return masks

def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
def _read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
"""It reads the LISFLOOD pixel area static map
Parameters:
Expand Down Expand Up @@ -195,7 +195,7 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],

if output is not None:
fileout = output / f'{ID:04}.nc'
if fileout.exists() & ~overwrite:
if fileout.exists() and ~overwrite:
print(f'Output file {fileout} already exists. Moving forward to the next catchment')
continue

Expand All @@ -216,7 +216,7 @@ def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
# compute statistics
for var, stats in stats_dict.items():
for stat in stats:
if (stat in ['mean', 'sum', 'std', 'var']) & (weight is not None):
if (stat in ['mean', 'sum', 'std', 'var']) and (weight is not None):
x = getattr(weighted_maps, stat)(dim=[x_dim, y_dim])[var]
else:
x = getattr(masked_maps, stat)(dim=[x_dim, y_dim])[var]
Expand Down Expand Up @@ -256,13 +256,14 @@ def main(argv=sys.argv):

args = parser.parse_args()

maps = read_inputmaps(args.input)
masks = read_masks(args.mask)
if args.area is not None:
weight = read_pixarea(args.area)
else:
weight = None
catchment_statistics(maps, masks, args.statistic, weight=weight, output=args.output, overwrite=args.overwrite)
try:
maps = _read_inputmaps(args.input)
masks = _read_masks(args.mask)
weight = _read_pixarea(args.area) if args.area is not None else None
catchment_statistics(maps, masks, args.statistic, weight=weight, output=args.output, overwrite=args.overwrite)
except Exception as e:
print(f'ERROR: {e}')
sys.exit(1)

def main_script():
sys.exit(main())
Expand Down

0 comments on commit c82bf1b

Please sign in to comment.