Update catchstats.py to divide the process into several smaller functions, so that the function "catchment_statistics" can be imported in a script and used as convenient
casadoj committed Apr 5, 2024
1 parent e09b1f8 commit 1697c44
Showing 1 changed file with 173 additions and 106 deletions.
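With the process split into read_inputmaps, read_masks, read_pixarea and catchment_statistics, the computation can now be driven from another script, as the commit message intends. A minimal sketch of that use (the import path assumes the src/ layout maps to the installed lisfloodutilities package; the data paths and the choice of statistics are hypothetical):

from lisfloodutilities.catchstats.catchstats import (
    read_inputmaps, read_masks, read_pixarea, catchment_statistics
)

maps = read_inputmaps('data/inputmaps')    # hypothetical directory of NetCDF input maps
masks = read_masks('data/masks')           # hypothetical directory of catchment masks
weight = read_pixarea('data/pixarea.nc')   # hypothetical pixel-area map

# with output=None, the statistics are returned as an xr.Dataset indexed by catchment 'id'
stats = catchment_statistics(maps, masks, ['mean', 'max'], weight=weight, output=None)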
279 changes: 173 additions & 106 deletions src/lisfloodutilities/catchstats/catchstats.py
@@ -17,32 +17,140 @@
import sys
import time
import xarray as xr
from typing import List, Union, Optional, Literal
from typing import Dict, List, Union, Optional, Literal
from tqdm.auto import tqdm


def catchment_statistics(inputmaps: Union[str, Path],
mask: Union[str, Path],
statistic: List[Literal['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count']],
output: Union[str, Path],
pixarea: Optional[str] = None,
overwrite: bool = False
):
"""
Given a set of input maps and catchment masks, it computes catchment statistics.
def read_inputmaps(inputmaps: Union[str, Path]) -> xr.Dataset:
"""It reads the input maps in NetCDF format from the input directory
Parameters:
-----------
inputmaps: str or pathlib.Path
directory that contains the input NetCDF files whose statistics will be computed. These files can be static (without time dimension) or dynamic (with time dimension)
Returns:
--------
ds: xr.Dataset
"""

inputmaps = Path(inputmaps)
if not inputmaps.is_dir():
print(f'ERROR: {inputmaps} is missing or not a directory!')
sys.exit(1)

filepaths = list(inputmaps.glob('*.nc'))
if not filepaths:
print(f'ERROR: No NetCDF files found in "{inputmaps}"')
sys.exit(2)

print(f'{len(filepaths)} input NetCDF files found in "{inputmaps}"')

try:
# for dynamic maps
ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True)
# chunks is set to auto for general purpose processing
# it could be optimized depending on input NetCDF
except Exception:
# for static maps
ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(file)['Band1'] for file in filepaths})
if 'wgs_1984' in ds:
ds = ds.drop_vars('wgs_1984')

return ds

def read_masks(mask: Union[str, Path]) -> Dict[int, xr.DataArray]:
"""It loads the catchment masks in NetCDF formal from the input directory
Parameters:
-----------
mask: str or pathlib.Path
directory that contains the NetCDF files that define the catchment boundaries. These files can be the output of the `cutmaps` tool
statistic: list of strings
Returns:
--------
masks: dictionary of xr.DataArray
the keys are the catchment IDs and the values are the catchment masks (1 inside the catchment, NaN outside)
"""

# check masks
mask = Path(mask)
if not mask.is_dir():
print(f'ERROR: {mask} is not a directory!')
sys.exit(1)

maskpaths = list(mask.glob('*.nc'))
if not maskpaths:
print(f'ERROR: No NetCDF files found in "{mask}"')
sys.exit(2)

print(f'{len(maskpaths)} mask NetCDF files found in "{mask}"')

# load masks
masks = {}
for maskpath in maskpaths:
ID = int(maskpath.stem)
try:
try:
aoi = xr.open_dataset(maskpath)['Band1']
except Exception:
aoi = xr.open_dataarray(maskpath)
aoi = xr.where(aoi.notnull(), 1, aoi)
masks[ID] = aoi
except Exception as e:
print(f'ERROR: The mask {maskpath} could not be read: {e}')
continue

return masks

def read_pixarea(pixarea: Union[str, Path]) -> xr.DataArray:
"""It reads the LISFLOOD pixel area static map
Parameters:
-----------
pixarea: string or Path
a NetCDF file with pixel area used to compute weighted statistics. It is specifically meant for geographic projection systems where the area of a pixel varies with latitude
Returns:
--------
weight: xr.DataArray
"""

pixarea = Path(pixarea)
if not pixarea.is_file():
print(f'ERROR: {pixarea} is not a file!')
sys.exit(1)

try:
weight = xr.open_dataset(pixarea)['Band1']
except Exception as e:
print(f'ERROR: The weighting map "{pixarea}" could not be loaded: {e}')
sys.exit(2)

return weight

def catchment_statistics(maps: Union[xr.DataArray, xr.Dataset],
masks: Dict[int, xr.DataArray],
statistic: Union[Literal['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count'], List[Literal['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count']]],
weight: Optional[xr.DataArray] = None,
output: Optional[Union[str, Path]] = None,
overwrite: bool = False
) -> Optional[xr.Dataset]:
"""
Given a set of input maps and catchment masks, it computes catchment statistics.
Parameters:
-----------
maps: xarray.DataArray or xarray.Dataset
map or set of maps from which catchment statistics will be computed. The rioxarray library must have been used to define the coordinate reference system and the dimensions
masks: dictionary of xr.DataArray
a set of catchment masks. These can be produced, for instance, with the `cutmaps` tool in the `lisflood-utilities` library
statistic: string or list of strings
statistics to be computed. Only some statistics are available: 'mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count'
output: str or pathlib.Path
directory where the resulting NetCDF files will be saved.
pixarea: optional or str
if provided, a NetCDF file with pixel area used to compute weighted statistics. It is specifically meant for geographic projection systems where the area of a pixel varies with latitude
weight: optional or xr.DataArray
map used to weight each pixel in "maps" before computing the statistics. It is meant to account for the varying pixel area in geographic projections
output: optional, str or pathlib.Path
directory where the resulting NetCDF files will be saved. If not provided, the results are returned as a single xr.Dataset with a catchment dimension 'id'
overwrite: boolean
whether to overwrite or skip catchments whose output NetCDF file already exists. By default it is False, so existing catchments are skipped
Expand All @@ -52,126 +160,81 @@ def catchment_statistics(inputmaps: Union[str, Path],
"""

start_time = time.perf_counter()

# output directory
output = Path(output)
output.mkdir(parents=True, exist_ok=True)


if isinstance(maps, xr.DataArray):
maps = xr.Dataset({maps.name: maps})

# check statistic
if isinstance(statistic, str):
statistic = [statistic]
possible_stats = ['mean', 'sum', 'std', 'var', 'min', 'max', 'median', 'count']
assert all(stat in possible_stats for stat in statistic), "All values in 'statistic' should be one of these: {0}".format(', '.join(possible_stats))
stats_dict = {var: statistic for var in maps}

# input maps
if not os.path.isdir(inputmaps):
print(f'ERROR: {inputmaps} is missing or not a directory!')
sys.exit(0)
else:
inputmaps = Path(inputmaps)
filepaths = list(inputmaps.glob('*.nc'))
if not filepaths:
print(f'ERROR: No NetCDF files found in "{inputmaps}"')
sys.exit(0)
else:
print(f'{len(filepaths)} input NetCDF files found in "{inputmaps}"')
try:
# chunks is set to auto for general purpose processing
# it could be optimized depending on input NetCDF
ds = xr.open_mfdataset(filepaths, chunks='auto', parallel=True)
except:
# for static maps
ds = xr.Dataset({file.stem.split('_')[0]: xr.open_dataset(file)['Band1'] for file in filepaths})
if 'wgs_1984' in ds:
ds = ds.drop_vars('wgs_1984')

# catchment masks
if not os.path.isdir(mask):
print(f'ERROR: {mask} is missing or not a directory!')
sys.exit(0)
# output directory
if output is None:
results = []
else:
mask = Path(mask)
maskpaths = list(mask.glob('*.nc'))
if not maskpaths:
print(f'ERROR: No NetCDF files found in "{mask}"')
sys.exit(0)
else:
maskpaths = {int(file.stem): file for file in maskpaths}
print(f'{len(maskpaths)} mask NetCDF files found in "{mask}"')
output = Path(output)
output.mkdir(parents=True, exist_ok=True)

# weighing map
if pixarea is not None:
if not os.path.isfile(pixarea):
print(f'ERROR: {pixarea} is missing!')
sys.exit(0)
else:
try:
weight = xr.open_dataset(pixarea)['Band1']
except:
print(f'ERROR: The weighing map "{pixarea}" could not be loaded')
sys.exit(0)

# define coordinates and variables of the resulting Dataset
dims = dict(ds.dims)
dims = dict(maps.dims)
dimnames = [dim.lower() for dim in dims]
if 'lat' in dimnames and 'lon' in dimnames:
x_dim, y_dim = 'lon', 'lat'
else:
x_dim, y_dim = 'x', 'y'
del dims[x_dim]
del dims[y_dim]
coords = {dim: ds[dim] for dim in dims}
stats_dict = {var: statistic for var in ds}
coords = {dim: maps[dim] for dim in dims}
variables = [f'{var}_{stat}' for var, stats in stats_dict.items() for stat in stats]

# compute statistics for each catchment
for ID in tqdm(maskpaths.keys(), desc='processing catchments'):

fileout = output / f'{ID:04}.nc'
if fileout.exists() & ~overwrite:
print(f'Output file {fileout} already exists. Moving forward to the next catchment')
continue
for ID in tqdm(masks.keys(), desc='processing catchments'):

if output is not None:
fileout = output / f'{ID:04}.nc'
if fileout.exists() & ~overwrite:
print(f'Output file {fileout} already exists. Moving forward to the next catchment')
continue

# create empty Dataset
coords.update({'id': [ID]})
ds_aoi = xr.Dataset({var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})

# read mask map
try:
maskpath = maskpaths[ID]
aoi = xr.open_dataset(maskpath)['Band1']
aoi = xr.where(aoi.notnull(), 1, aoi)
except:
print(f'ERROR: The mask {maskpath} could not be read')
continue
maps_aoi = xr.Dataset({var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})

# apply mask to the dataset
masked_ds = ds.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
masked_ds = masked_ds.compute()
aoi = masks[ID]
masked_maps = maps.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
masked_maps = masked_maps.compute()

# apply weighting by pixel area
if pixarea is not None:
# apply weighting
if weight is not None:
masked_weight = weight.sel({x_dim: aoi[x_dim], y_dim: aoi[y_dim]}).where(aoi == 1)
weighted_ds = masked_ds.weighted(masked_weight.fillna(0))
weighted_maps = masked_maps.weighted(masked_weight.fillna(0))

# compute statistics
for var, stats in stats_dict.items():
for stat in stats:
if stat in ['mean', 'sum', 'std', 'var']:
if pixarea is not None:
x = getattr(weighted_ds, stat)(dim=[x_dim, y_dim])[var]
else:
x = getattr(masked_ds, stat)(dim=[x_dim, y_dim])[var]
elif stat in ['min', 'max', 'median', 'count']:
x = getattr(masked_ds, stat)(dim=[x_dim, y_dim])[var]
ds_aoi[f'{var}_{stat}'].loc[{'id': ID}] = x

# export
ds_aoi.to_netcdf(fileout)
if (stat in ['mean', 'sum', 'std', 'var']) & (weight is not None):
x = getattr(weighted_maps, stat)(dim=[x_dim, y_dim])[var]
else:
x = getattr(masked_maps, stat)(dim=[x_dim, y_dim])[var]
maps_aoi[f'{var}_{stat}'].loc[{'id': ID}] = x

# save results
if output is None:
results.append(maps_aoi)
else:
maps_aoi.to_netcdf(fileout)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Time elapsed: {elapsed_time:0.2f} seconds")


if output is None:
results = xr.concat(results, dim='id')
return results

def main(argv=sys.argv):
prog = os.path.basename(argv[0])
Expand All @@ -192,10 +255,14 @@ def main(argv=sys.argv):
parser.add_argument("-W", "--overwrite", action="store_true", help="Overwrite existing output files")

args = parser.parse_args()

catchment_statistics(args.input, args.mask, args.statistic, args.output, args.area, args.overwrite)


maps = read_inputmaps(args.input)
masks = read_masks(args.mask)
if args.area is not None:
weight = read_pixarea(args.area)
else:
weight = None
catchment_statistics(maps, masks, args.statistic, weight=weight, output=args.output, overwrite=args.overwrite)

def main_script():
sys.exit(main())
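As a side note, the weighted branch of catchment_statistics relies on xarray's built-in weighted reductions (Dataset.weighted / DataArray.weighted). A self-contained toy sketch of that mechanism, with invented values:

import xarray as xr

# toy 2x2 field and a pixel-area map that shrinks towards the pole (invented values)
da = xr.DataArray([[1.0, 2.0], [3.0, 4.0]], dims=('lat', 'lon'))
area = xr.DataArray([[1.0, 1.0], [0.5, 0.5]], dims=('lat', 'lon'))

plain = float(da.mean(dim=['lat', 'lon']))                    # 2.5
weighted = float(da.weighted(area).mean(dim=['lat', 'lon']))  # 6.5 / 3 ≈ 2.17
print(plain, weighted)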
