Skip to content

Commit

Permalink
Add file caching to the roughness step.
Browse files Browse the repository at this point in the history
  • Loading branch information
rosepearson committed Nov 30, 2023
1 parent f9b940e commit 5cb0d69
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 107 deletions.
196 changes: 104 additions & 92 deletions src/geofabrics/dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,99 @@ def _add_lidar_no_chunking(
"_add_lidar_no_chunking must be instantiated in the " "child class"
)

def _load_dem(self, filename: pathlib.Path, chunk_size: int):
"""Load in and replace the DEM with a previously cached version."""
dem = rioxarray.rioxarray.open_rasterio(
filename,
masked=True,
parse_coordinates=True,
chunks={"x": chunk_size, "y": chunk_size},
)
dem = dem.squeeze("band", drop=True)
self._write_netcdf_conventions_in_place(dem, self.catchment_geometry.crs)
self._dem = dem

if "no_values_mask" in self._dem.keys():
self._dem["no_values_mask"] = self._dem.no_values_mask.astype(bool)
if "data_source" in self._dem.keys():
self._dem["data_source"] = self._dem.data_source.astype(
geometry.RASTER_TYPE
)
if "lidar_source" in self._dem.keys():
self._dem["lidar_source"] = self._dem.lidar_source.astype(
geometry.RASTER_TYPE
)
if "z" in self._dem.keys():
self._dem["z"] = self._dem.z.astype(geometry.RASTER_TYPE)

def save_dem(
self,
filename: pathlib.Path,
dem: xarray.Dataset,
):
"""Save the DEM to a netCDF file."""

# Save the file
try:
dem.to_netcdf(
filename,
format="NETCDF4",
engine="netcdf4",
)
# Close the DEM
dem.close()
except (Exception, KeyboardInterrupt) as caught_exception:
pathlib.Path(filename).unlink()
logging.info(
f"Caught error {caught_exception} and deleting"
"partially created netCDF output "
f"{filename} before re-raising error."
)
raise caught_exception

def _save_and_load_dem(
self,
filename: pathlib.Path,
chunk_size: int,
no_values_mask: bool,
buffer_cells: int = None,
):
"""Update the saved file cache for the DEM as a netCDF file. The bool
no_data_layer may optionally be included."""

# Get the DEM from the property call
dem = self._dem
# Create mask if specified
if no_values_mask:
if self.catchment_geometry.land_and_foreshore.area.sum() > 0:
no_value_mask = (
dem.z.rolling(
dim={"x": buffer_cells * 2 + 1, "y": buffer_cells * 2 + 1},
min_periods=1,
center=True,
)
.count()
.isnull()
)
no_value_mask &= (
xarray.ones_like(self._dem.z)
.rio.clip(
self.catchment_geometry.land_and_foreshore.geometry, drop=False
)
.notnull()
) # Awkward as clip of a bool xarray doesn't work as expected
else:
no_value_mask = xarray.zeros_like(self._dem.z)
dem["no_values_mask"] = no_value_mask
dem.no_values_mask.rio.write_crs(
self.catchment_geometry.crs["horizontal"], inplace=True
)

# Save the DEM with the no_values_layer
self.save_dem(filename=filename, dem=dem)
# Load in the temporarily saved DEM
self._load_dem(filename=filename, chunk_size=chunk_size)


class RawDem(LidarBase):
"""A class to manage the creation of a 'raw' DEM from LiDAR tiles, and/or a
Expand Down Expand Up @@ -1427,10 +1520,11 @@ def add_lidar(

# Save a cached copy of DEM to temporary memory cache
logging.info("In dem.add_lidar - write out temp raw DEM to netCDF")
self._save_and_load_dem_with_no_values_mask(
self._save_and_load_dem(
filename=self.temp_folder / "raw_lidar.nc",
buffer_cells=buffer_cells,
chunk_size=chunk_size,
no_values_mask=True,
)

def _add_tiled_lidar_chunked(
Expand Down Expand Up @@ -1915,11 +2009,12 @@ def dask_interpolation(y, x):
logging.info(
"In dem.add_coarse_dems - write out temp raw DEM to netCDF"
)
self._save_and_load_dem_with_no_values_mask(
self._save_and_load_dem(
filename=self.temp_folder
/ f"raw_dem_{coarse_dem_path.stem}.nc",
buffer_cells=buffer_cells,
chunk_size=chunk_size,
no_values_mask=True,
)
logging.info(
f"In dem.add_coarse_dems - remove previous cached file {previous_cached_file}"
Expand All @@ -1929,96 +2024,6 @@ def dask_interpolation(y, x):
self.temp_folder / f"raw_dem_{coarse_dem_path.stem}.nc"
)

def _load_dem(self, filename: pathlib.Path, chunk_size: int):
"""Load in and replace the DEM with a previously cached version."""
dem = rioxarray.rioxarray.open_rasterio(
filename,
masked=True,
parse_coordinates=True,
chunks={"x": chunk_size, "y": chunk_size},
)
dem = dem.squeeze("band", drop=True)
self._write_netcdf_conventions_in_place(dem, self.catchment_geometry.crs)
self._dem = dem

if "no_values_mask" in self._dem.keys():
self._dem["no_values_mask"] = self._dem.no_values_mask.astype(bool)
if "data_source" in self._dem.keys():
self._dem["data_source"] = self._dem.data_source.astype(
geometry.RASTER_TYPE
)
if "lidar_source" in self._dem.keys():
self._dem["lidar_source"] = self._dem.lidar_source.astype(
geometry.RASTER_TYPE
)
if "z" in self._dem.keys():
self._dem["z"] = self._dem.z.astype(geometry.RASTER_TYPE)

def save_dem(
self,
filename: pathlib.Path,
dem: xarray.Dataset,
):
"""Save the DEM to a netCDF file."""

# Save the file
try:
dem.to_netcdf(
filename,
format="NETCDF4",
engine="netcdf4",
)
# Close the DEM
dem.close()
except (Exception, KeyboardInterrupt) as caught_exception:
pathlib.Path(filename).unlink()
logging.info(
f"Caught error {caught_exception} and deleting"
"partially created netCDF output "
f"{filename} before re-raising error."
)
raise caught_exception

def _save_and_load_dem_with_no_values_mask(
self,
filename: pathlib.Path,
buffer_cells: int,
chunk_size: int,
):
"""Update the saved file cache for the DEM as a netCDF file. The no_data_layer of bol values may
optionally be included."""

# Get the DEM from the property call
dem = self._dem
if self.catchment_geometry.land_and_foreshore.area.sum() > 0:
no_value_mask = (
dem.z.rolling(
dim={"x": buffer_cells * 2 + 1, "y": buffer_cells * 2 + 1},
min_periods=1,
center=True,
)
.count()
.isnull()
)
no_value_mask &= (
xarray.ones_like(self._dem.z)
.rio.clip(
self.catchment_geometry.land_and_foreshore.geometry, drop=False
)
.notnull()
) # Awkward as clip of a bool xarray doesn't work as expected
else:
no_value_mask = xarray.zeros_like(self._dem.z)
dem["no_values_mask"] = no_value_mask
dem.no_values_mask.rio.write_crs(
self.catchment_geometry.crs["horizontal"], inplace=True
)

# Save the DEM with the no_values_layer
self.save_dem(filename=filename, dem=dem)
# Load in the temporarily saved DEM
self._load_dem(filename=filename, chunk_size=chunk_size)


class RoughnessDem(LidarBase):
"""A class to add a roughness (zo) layer to a hydrologically conditioned DEM.
Expand Down Expand Up @@ -2047,6 +2052,7 @@ def __init__(
self,
catchment_geometry: geometry.CatchmentGeometry,
hydrological_dem_path: typing.Union[str, pathlib.Path],
temp_folder: pathlib.Path,
interpolation_method: str,
default_values: dict,
drop_offshore_lidar: dict,
Expand Down Expand Up @@ -2085,6 +2091,7 @@ def __init__(
self.catchment_geometry.catchment.geometry, drop=True
)

self.temp_folder = temp_folder
self.interpolation_method = interpolation_method
self.default_values = default_values
self.drop_offshore_lidar = drop_offshore_lidar
Expand Down Expand Up @@ -2160,6 +2167,11 @@ def add_lidar(
chunk_size=chunk_size,
metadata=metadata,
)
self._save_and_load_dem(
filename=self.temp_folder / "raw_lidar_zo.nc",
chunk_size=chunk_size,
no_values_mask=False,
)
# Set roughness where water
self._dem["zo"] = self._dem.zo.where(
self._dem.data_source != self.SOURCE_CLASSIFICATION["ocean bathymetry"],
Expand Down
42 changes: 27 additions & 15 deletions src/geofabrics/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,20 @@ def run(self):
" Either take mean or drop points with height above limit."
)

# Create folder for caching raw DEM files during DEM generation
temp_folder = (
self.get_instruction_path("subfolder")
/ "temp"
/ f"{self.get_resolution()}m_results"
)
logging.info(
"In processor.DemGenerator - create folder for writing temporarily"
f" cached netCDF files in {temp_folder}"
)
if temp_folder.exists():
shutil.rmtree(temp_folder)
temp_folder.mkdir(parents=True, exist_ok=True)

# Setup Dask cluster and client
cluster_kwargs = {
"n_workers": self.get_processing_instructions("number_of_cores"),
Expand All @@ -1290,6 +1304,7 @@ def run(self):
self.roughness_dem = dem.RoughnessDem(
catchment_geometry=self.catchment_geometry,
hydrological_dem_path=self.get_instruction_path("result_dem"),
temp_folder=temp_folder,
elevation_range=self.get_instruction_general("elevation_range"),
interpolation_method=self.get_instruction_general(
key="interpolation", subkey="no_data"
Expand All @@ -1310,21 +1325,18 @@ def run(self):
) # Note must be called after all others if it is to be complete

# save results
try:
self.roughness_dem.dem.to_netcdf(
self.get_instruction_path("result_geofabric"),
format="NETCDF4",
engine="netcdf4",
)
except (Exception, KeyboardInterrupt) as caught_exception:
pathlib.Path(self.get_instruction_path("result_geofabric")).unlink()
logging.info(
f"Caught error {caught_exception} and deleting"
"partially created netCDF output "
f"{self.get_instruction_path('result_geofabric')}"
" before re-raising error."
)
raise caught_exception
logging.info("In processor.RoughnessLengthGenerator - write out "
"the raw DEM to netCDF")
self.roughness_dem.save_dem(
filename=self.get_instruction_path("result_geofabric"),
dem=self.roughness_dem.dem,
)
logging.info(
"In processor.RoughnessLengthGenerator - clean folder for "
f"writing temporarily cached netCDF files in {temp_folder}"
)
shutil.rmtree(temp_folder)

if self.debug:
# Record the parameter used during execution - append to existing
with open(
Expand Down

0 comments on commit 5cb0d69

Please sign in to comment.