From 98120eda8ef524f5a008929c979ec99f7dcc8f06 Mon Sep 17 00:00:00 2001 From: George Breyiannis Date: Sat, 6 Apr 2024 17:00:01 +0200 Subject: [PATCH] towards adding a tiling feature --- pyposeidon/dem.py | 11 ++- pyposeidon/utils/fix.py | 112 ++++++++++++++++++++++++++---- pyposeidon/utils/global_bgmesh.py | 92 ++++++++++++++++++++---- 3 files changed, 187 insertions(+), 28 deletions(-) diff --git a/pyposeidon/dem.py b/pyposeidon/dem.py index fcd47595..683febc1 100644 --- a/pyposeidon/dem.py +++ b/pyposeidon/dem.py @@ -2,6 +2,7 @@ Dem module """ + # Copyright 2018 European Union # This file is part of pyposeidon. # Licensed under the EUPL, Version 1.2 or – as soon they will be approved by the European Commission - subsequent versions of the EUPL (the "Licence"). @@ -20,7 +21,7 @@ import pyresample import xarray as xr -from pyposeidon.utils.fix import fix, resample +from pyposeidon.utils.fix import fix_dem, fix, resample from pyposeidon import tools NCORES = max(1, multiprocessing.cpu_count() - 1) @@ -87,7 +88,13 @@ def __init__(self, dem_source: str, **kwargs): self.adjust(coastline, **kwargs) def adjust(self, coastline, **kwargs): - self.Dataset, check, flag = fix(self.Dataset, coastline, **kwargs) + + tiles = kwargs.get("tiles", False) + + if tiles: + self.Dataset, check = fix_dem(self.Dataset, coastline, **kwargs) + else: + self.Dataset, check, flag = fix(self.Dataset, coastline, **kwargs) if not check: logger.warning("Adjusting dem failed, keeping original values\n") diff --git a/pyposeidon/utils/fix.py b/pyposeidon/utils/fix.py index 792f38b4..f781f2e9 100644 --- a/pyposeidon/utils/fix.py +++ b/pyposeidon/utils/fix.py @@ -2,6 +2,7 @@ Mesh adjustment functions """ + # Copyright 2018 European Union # This file is part of pyposeidon. # Licensed under the EUPL, Version 1.2 or – as soon they will be approved by the European Commission - subsequent versions of the EUPL (the "Licence"). @@ -16,6 +17,9 @@ import xarray as xr import sys import os +import shutil +from glob import glob +from tqdm.auto import tqdm from pyposeidon.utils.coastfix import simplify # logging setup @@ -24,12 +28,82 @@ logger = logging.getLogger(__name__) +def get_tiles(data, chunks=None): + + # chuck + if chunks: + ilats = data.elevation.chunk({"longitude": chunks[0], "latitude": chunks[1]}).chunks[0] + ilons = data.elevation.chunk({"longitude": chunks[0], "latitude": chunks[1]}).chunks[1] + else: + ilats = data.elevation.chunk("auto").chunks[0] + ilons = data.elevation.chunk("auto").chunks[1] + + if len(ilons) == 1: + ilons = (int(ilons[0] / 2), int(ilons[0] / 2)) + + idx = [sum(ilons[:i]) for i in range(len(ilons) + 1)] + jdx = [sum(ilats[:i]) for i in range(len(ilats) + 1)] + + blon = list(zip(idx[:-1], idx[1:])) + blat = list(zip(jdx[:-1], jdx[1:])) + + perms = [(x, y) for x in blon for y in blat] + + return perms + + +def fix_dem(dem, coastline, buffer=0.0, **kwargs): + + perms = get_tiles(dem) + + i = 0 + check = True + + if not os.path.exists("./fixtmp/"): + os.makedirs("./fixtmp/") + + for (i1, i2), (j1, j2) in tqdm(perms, total=len(perms)): + + lon1 = dem.longitude.data[i1:i2][0] + lon2 = dem.longitude.data[i1:i2][-1] + lat1 = dem.latitude.data[j1:j2][0] + lat2 = dem.latitude.data[j1:j2][-1] + + # buffer lat/lon + blon1 = lon1 - buffer + blon2 = lon2 + buffer + blat1 = lat1 - buffer + blat2 = lat2 + buffer + + # de = dem.sel(lon=slice(blon1,blon2)).sel(lat=slice(blat1,blat2)) + de = dem_range(dem, blon1, blon2, blat1, blat2) + + de_, check_, flag = fix(de, coastline, **kwargs) + + ide = de_.sel(latitude=slice(lat1, lat2)).sel(longitude=slice(lon1, lon2)) + + ide.to_netcdf("./fixtmp/ide{:03d}.nc".format(i)) + i += 1 + + check = check and check_ + + ifiles = glob("./fixtmp/ide*") + + fdem = xr.open_mfdataset(ifiles) + + ##cleanup + shutil.rmtree("./fixtmp") + + return fdem, check + + def fix(dem, coastline, **kwargs): # --------------------------------------------------------------------- logger.info("adjust dem\n") # --------------------------------------------------------------------- ifunction = kwargs.get("resample_function", "nearest") + reset_flag = kwargs.get("reset_flag", False) # define coastline try: @@ -37,7 +111,10 @@ def fix(dem, coastline, **kwargs): except: shp = gp.GeoDataFrame(coastline) - shp = simplify(shp) + sc = kwargs.get("simplify_coastlines", False) + + if sc: + shp = simplify(shp) if "ival" in dem.data_vars: xp = dem.ilons.values @@ -130,7 +207,7 @@ def fix(dem, coastline, **kwargs): else: dem = dem.assign(adjusted=dem.elevation) - return dem, True + return dem, True, flag if "ival" in dem.data_vars: df = pd.DataFrame( @@ -154,7 +231,7 @@ def fix(dem, coastline, **kwargs): # Add land boundaries to a shapely object try: lbs = [] - for l in range(len(land.boundary.geoms)): + for l in tqdm(range(len(land.boundary.geoms))): z = shapely.linearrings(land.boundary.geoms[l].coords[:]) lbs.append(z) except: @@ -172,7 +249,7 @@ def fix(dem, coastline, **kwargs): try: wl = [] - for l in range(len(land.boundary.geoms)): + for l in tqdm(range(len(land.boundary.geoms))): wl.append(tree.query(bp[l], predicate="contains").tolist()) ns = [j for i in wl for j in i] except: @@ -199,7 +276,7 @@ def fix(dem, coastline, **kwargs): xw = pw.longitude.values yw = pw.latitude.values - bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag, function=ifunction) + bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag, reset_flag=reset_flag, function=ifunction) df.loc[pw.index, "elevation"] = bw # replace in original dataset @@ -218,7 +295,7 @@ def fix(dem, coastline, **kwargs): xl = pl.longitude.values yl = pl.latitude.values - bd = resample(dem, xl, yl, var="elevation", wet=False, flag=flag, function=ifunction) + bd = resample(dem, xl, yl, var="elevation", wet=False, flag=flag, reset_flag=reset_flag, function=ifunction) df.loc[pl.index, "elevation"] = bd # replace in original dataset @@ -285,7 +362,7 @@ def fix(dem, coastline, **kwargs): bmask[i, j] = True cdem.adjusted.values[bmask] = 0.0 # set value - logger.info("setting land points with nan values to zero") + logger.info(f"setting {bmask.size} land points with nan values to zero") cdem["adjusted"] = cdem.adjusted.fillna(0.0) # for land points if any (lakes, etc.) else: @@ -423,10 +500,15 @@ def check2(dataset, coastline): return bps -def resample(dem, xw, yw, var=None, wet=True, flag=None, function="nearest"): +def resample(dem, xw, yw, var=None, wet=True, flag=None, reset_flag=False, function="nearest"): # Define points with positive bathymetry x, y = np.meshgrid(dem.longitude, dem.latitude) + print(f"reset_flag={reset_flag}") + + if reset_flag: + flag = 0 + if flag == 1: gx = xw - 180.0 xx = x - 180.0 @@ -457,6 +539,8 @@ def resample(dem, xw, yw, var=None, wet=True, flag=None, function="nearest"): orig = pyresample.geometry.SwathDefinition(lons=mx, lats=my) # original bathymetry points targ = pyresample.geometry.SwathDefinition(lons=gx, lats=yw) # wet points + mdem = mdem.astype(float) + if function == "nearest": bw = pyresample.kd_tree.resample_nearest(orig, mdem, targ, radius_of_influence=100000, fill_value=np.nan) @@ -485,11 +569,11 @@ def dem_range(data, lon_min, lon_max, lat_min, lat_max): lon1 = lon_max if (lon_min < data.longitude.min()) or (lon_max > data.longitude.max()): - print("Lon must be within {} and {}".format(data.longitude.min().values, data.longitude.max().values)) - print("compensating if global dataset available") + logger.info("Lon must be within {} and {}".format(data.longitude.min().values, data.longitude.max().values)) + logger.info("compensating if global dataset available") if (lat_min < data.latitude.min()) or (lat_max > data.latitude.max()): - print("Lat is within {} and {}".format(data.latitude.min().values, data.latitude.max().values)) + logger.info("Lat is within {} and {}".format(data.latitude.min().values, data.latitude.max().values)) # get idx if lon_max - lon_min == dlon1 - dlon0: @@ -516,16 +600,16 @@ def dem_range(data, lon_min, lon_max, lat_min, lat_max): lat_1 = min(data.latitude.size, j1 + 3) if i0 > i1: - p1 = data.elevation.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1)) + p1 = data.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1)) p1 = p1.assign_coords({"longitude": p1.longitude.values - 360.0}) - p2 = data.elevation.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1)) + p2 = data.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1)) dem = xr.concat([p1, p2], dim="longitude") else: - dem = data.elevation.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1)) + dem = data.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1)) if np.abs(np.mean(dem.longitude) - np.mean([lon_min, lon_max])) > 170.0: c = np.sign(np.mean([lon_min, lon_max])) diff --git a/pyposeidon/utils/global_bgmesh.py b/pyposeidon/utils/global_bgmesh.py index 155081bb..f5467df5 100644 --- a/pyposeidon/utils/global_bgmesh.py +++ b/pyposeidon/utils/global_bgmesh.py @@ -3,6 +3,7 @@ import pyresample import xarray as xr import os +from tqdm.auto import tqdm import pyposeidon.boundary as pb from pyposeidon.utils.stereo import to_3d, to_lat_lon, stereo_to_3d @@ -17,6 +18,7 @@ ) import pyposeidon.dem as pdem import pyposeidon.mesh as pmesh +from pyposeidon.utils.fix import dem_range, resample import logging logger = logging.getLogger(__name__) @@ -93,26 +95,37 @@ def make_bgmesh_global(dfb, fpos, dem, **kwargs): y0 = mesh.Dataset.SCHISM_hgrid_node_y.values trii0 = mesh.Dataset.SCHISM_hgrid_face_nodes.values[:, :3] + m = mesh.Dataset + m = m.assign({"x": m["SCHISM_hgrid_node_x"], "y": m["SCHISM_hgrid_node_y"]}) + # Stereo -> lat/lon + clon, clat = to_lat_lon(x0, y0) + m["x"].data = clon + m["y"].data = clat + # Select DEM - try: - dm = dem.adjusted.to_dataframe() - except: - dm = dem.elevation.to_dataframe() + # try: + # dm = dem.adjusted.to_dataframe() + # except: + # dm = dem.elevation.to_dataframe() - lon = dem.longitude.values - lat = dem.latitude.values + # lon = dem.longitude.values + # lat = dem.latitude.values - X, Y = np.meshgrid(lon, lat) + # X, Y = np.meshgrid(lon, lat) # Stereo -> lat/lon - clon, clat = to_lat_lon(x0, y0) + # clon, clat = to_lat_lon(x0, y0) # resample bathymetry - gdem = dm.values.flatten() + # gdem = dm.values.flatten() - orig = pyresample.geometry.SwathDefinition(lons=X.flatten(), lats=Y.flatten()) # original bathymetry points - targ = pyresample.geometry.SwathDefinition(lons=clon, lats=clat) # wet points + # orig = pyresample.geometry.SwathDefinition(lons=X.flatten(), lats=Y.flatten()) # original bathymetry points + # targ = pyresample.geometry.SwathDefinition(lons=clon, lats=clat) # wet points - bw = pyresample.kd_tree.resample_nearest(orig, gdem, targ, radius_of_influence=50000, fill_value=0) + # bw = pyresample.kd_tree.resample_nearest(orig, gdem, targ, radius_of_influence=50000, fill_value=0) + + dem_on_mesh(m, dem) + + bw = m.depth.data bz = pd.DataFrame({"z": bw.flatten()}) @@ -130,3 +143,58 @@ def make_bgmesh_global(dfb, fpos, dem, **kwargs): dfb = bk return nodes, elems + + +def fillv(dem, perms, m, buffer=0.0): + + for (i1, i2), (j1, j2) in tqdm(perms, total=len(perms)): + + lon1 = dem.longitude.data[i1:i2][0] + lon2 = dem.longitude.data[i1:i2][-1] + lat1 = dem.latitude.data[j1:j2][0] + lat2 = dem.latitude.data[j1:j2][-1] + + # buffer lat/lon + blon1 = lon1 - buffer + blon2 = lon2 + buffer + blat1 = lat1 - buffer + blat2 = lat2 + buffer + + # de = dem.sel(lon=slice(blon1,blon2)).sel(lat=slice(blat1,blat2)) + de = dem_range(dem, blon1, blon2, blat1, blat2) + + # subset mesh + indices_of_nodes_in_bbox = np.where( + (m.y >= lat1 - buffer / 2) + & (m.y <= lat2 + buffer / 2) + & (m.x >= lon1 - buffer / 2) + & (m.x <= lon2 + buffer / 2) + )[0] + + bm = m.isel(nSCHISM_hgrid_node=indices_of_nodes_in_bbox) + ids = np.argwhere(np.isnan(bm.depth.values)).flatten() + + grid_x, grid_y = bm.x.data, bm.y.data + + bd = resample(de, grid_x, grid_y, var="adjusted", wet=True, flag=0, function="gauss") + + m["depth"].loc[dict(nSCHISM_hgrid_node=indices_of_nodes_in_bbox)] = -bd + + +def dem_on_mesh(mesh, dem): + + ilats = dem.elevation.chunk("auto").chunks[0] + ilons = dem.elevation.chunk("auto").chunks[1] + + if len(ilons) == 1: + ilons = (int(ilons[0] / 2), int(ilons[0] / 2)) + + idx = [sum(ilons[:i]) for i in range(len(ilons) + 1)] + jdx = [sum(ilats[:i]) for i in range(len(ilats) + 1)] + + blon = list(zip(idx[:-1], idx[1:])) + blat = list(zip(jdx[:-1], jdx[1:])) + + perms = [(x, y) for x in blon for y in blat] + + fillv(dem, perms, mesh, buffer=5)