Skip to content

Commit

Permalink
towards adding a tiling feature
Browse files Browse the repository at this point in the history
  • Loading branch information
brey committed Apr 6, 2024
1 parent b70acbf commit 98120ed
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 28 deletions.
11 changes: 9 additions & 2 deletions pyposeidon/dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Dem module
"""

# Copyright 2018 European Union
# This file is part of pyposeidon.
# Licensed under the EUPL, Version 1.2 or – as soon they will be approved by the European Commission - subsequent versions of the EUPL (the "Licence").
Expand All @@ -20,7 +21,7 @@
import pyresample
import xarray as xr

from pyposeidon.utils.fix import fix, resample
from pyposeidon.utils.fix import fix_dem, fix, resample
from pyposeidon import tools

NCORES = max(1, multiprocessing.cpu_count() - 1)
Expand Down Expand Up @@ -87,7 +88,13 @@ def __init__(self, dem_source: str, **kwargs):
self.adjust(coastline, **kwargs)

def adjust(self, coastline, **kwargs):
self.Dataset, check, flag = fix(self.Dataset, coastline, **kwargs)

tiles = kwargs.get("tiles", False)

if tiles:
self.Dataset, check = fix_dem(self.Dataset, coastline, **kwargs)
else:
self.Dataset, check, flag = fix(self.Dataset, coastline, **kwargs)

if not check:
logger.warning("Adjusting dem failed, keeping original values\n")
Expand Down
112 changes: 98 additions & 14 deletions pyposeidon/utils/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Mesh adjustment functions
"""

# Copyright 2018 European Union
# This file is part of pyposeidon.
# Licensed under the EUPL, Version 1.2 or – as soon they will be approved by the European Commission - subsequent versions of the EUPL (the "Licence").
Expand All @@ -16,6 +17,9 @@
import xarray as xr
import sys
import os
import shutil
from glob import glob
from tqdm.auto import tqdm
from pyposeidon.utils.coastfix import simplify

# logging setup
Expand All @@ -24,20 +28,93 @@
logger = logging.getLogger(__name__)


def get_tiles(data, chunks=None):

# chuck
if chunks:
ilats = data.elevation.chunk({"longitude": chunks[0], "latitude": chunks[1]}).chunks[0]
ilons = data.elevation.chunk({"longitude": chunks[0], "latitude": chunks[1]}).chunks[1]
else:
ilats = data.elevation.chunk("auto").chunks[0]
ilons = data.elevation.chunk("auto").chunks[1]

if len(ilons) == 1:
ilons = (int(ilons[0] / 2), int(ilons[0] / 2))

idx = [sum(ilons[:i]) for i in range(len(ilons) + 1)]
jdx = [sum(ilats[:i]) for i in range(len(ilats) + 1)]

blon = list(zip(idx[:-1], idx[1:]))
blat = list(zip(jdx[:-1], jdx[1:]))

perms = [(x, y) for x in blon for y in blat]

return perms


def fix_dem(dem, coastline, buffer=0.0, **kwargs):

perms = get_tiles(dem)

i = 0
check = True

if not os.path.exists("./fixtmp/"):
os.makedirs("./fixtmp/")

for (i1, i2), (j1, j2) in tqdm(perms, total=len(perms)):

lon1 = dem.longitude.data[i1:i2][0]
lon2 = dem.longitude.data[i1:i2][-1]
lat1 = dem.latitude.data[j1:j2][0]
lat2 = dem.latitude.data[j1:j2][-1]

# buffer lat/lon
blon1 = lon1 - buffer
blon2 = lon2 + buffer
blat1 = lat1 - buffer
blat2 = lat2 + buffer

# de = dem.sel(lon=slice(blon1,blon2)).sel(lat=slice(blat1,blat2))
de = dem_range(dem, blon1, blon2, blat1, blat2)

de_, check_, flag = fix(de, coastline, **kwargs)

ide = de_.sel(latitude=slice(lat1, lat2)).sel(longitude=slice(lon1, lon2))

ide.to_netcdf("./fixtmp/ide{:03d}.nc".format(i))
i += 1

check = check and check_

ifiles = glob("./fixtmp/ide*")

fdem = xr.open_mfdataset(ifiles)

##cleanup
shutil.rmtree("./fixtmp")

return fdem, check


def fix(dem, coastline, **kwargs):
# ---------------------------------------------------------------------
logger.info("adjust dem\n")
# ---------------------------------------------------------------------

ifunction = kwargs.get("resample_function", "nearest")
reset_flag = kwargs.get("reset_flag", False)

# define coastline
try:
shp = gp.GeoDataFrame.from_file(coastline)
except:
shp = gp.GeoDataFrame(coastline)

shp = simplify(shp)
sc = kwargs.get("simplify_coastlines", False)

if sc:
shp = simplify(shp)

if "ival" in dem.data_vars:
xp = dem.ilons.values
Expand Down Expand Up @@ -130,7 +207,7 @@ def fix(dem, coastline, **kwargs):
else:
dem = dem.assign(adjusted=dem.elevation)

return dem, True
return dem, True, flag

if "ival" in dem.data_vars:
df = pd.DataFrame(
Expand All @@ -154,7 +231,7 @@ def fix(dem, coastline, **kwargs):
# Add land boundaries to a shapely object
try:
lbs = []
for l in range(len(land.boundary.geoms)):
for l in tqdm(range(len(land.boundary.geoms))):
z = shapely.linearrings(land.boundary.geoms[l].coords[:])
lbs.append(z)
except:
Expand All @@ -172,7 +249,7 @@ def fix(dem, coastline, **kwargs):

try:
wl = []
for l in range(len(land.boundary.geoms)):
for l in tqdm(range(len(land.boundary.geoms))):
wl.append(tree.query(bp[l], predicate="contains").tolist())
ns = [j for i in wl for j in i]
except:
Expand All @@ -199,7 +276,7 @@ def fix(dem, coastline, **kwargs):
xw = pw.longitude.values
yw = pw.latitude.values

bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag, function=ifunction)
bw = resample(dem, xw, yw, var="elevation", wet=True, flag=flag, reset_flag=reset_flag, function=ifunction)

df.loc[pw.index, "elevation"] = bw # replace in original dataset

Expand All @@ -218,7 +295,7 @@ def fix(dem, coastline, **kwargs):
xl = pl.longitude.values
yl = pl.latitude.values

bd = resample(dem, xl, yl, var="elevation", wet=False, flag=flag, function=ifunction)
bd = resample(dem, xl, yl, var="elevation", wet=False, flag=flag, reset_flag=reset_flag, function=ifunction)

df.loc[pl.index, "elevation"] = bd # replace in original dataset

Expand Down Expand Up @@ -285,7 +362,7 @@ def fix(dem, coastline, **kwargs):
bmask[i, j] = True
cdem.adjusted.values[bmask] = 0.0 # set value

logger.info("setting land points with nan values to zero")
logger.info(f"setting {bmask.size} land points with nan values to zero")
cdem["adjusted"] = cdem.adjusted.fillna(0.0) # for land points if any (lakes, etc.)

else:
Expand Down Expand Up @@ -423,10 +500,15 @@ def check2(dataset, coastline):
return bps


def resample(dem, xw, yw, var=None, wet=True, flag=None, function="nearest"):
def resample(dem, xw, yw, var=None, wet=True, flag=None, reset_flag=False, function="nearest"):
# Define points with positive bathymetry
x, y = np.meshgrid(dem.longitude, dem.latitude)

print(f"reset_flag={reset_flag}")

if reset_flag:
flag = 0

if flag == 1:
gx = xw - 180.0
xx = x - 180.0
Expand Down Expand Up @@ -457,6 +539,8 @@ def resample(dem, xw, yw, var=None, wet=True, flag=None, function="nearest"):
orig = pyresample.geometry.SwathDefinition(lons=mx, lats=my) # original bathymetry points
targ = pyresample.geometry.SwathDefinition(lons=gx, lats=yw) # wet points

mdem = mdem.astype(float)

if function == "nearest":
bw = pyresample.kd_tree.resample_nearest(orig, mdem, targ, radius_of_influence=100000, fill_value=np.nan)

Expand Down Expand Up @@ -485,11 +569,11 @@ def dem_range(data, lon_min, lon_max, lat_min, lat_max):
lon1 = lon_max

if (lon_min < data.longitude.min()) or (lon_max > data.longitude.max()):
print("Lon must be within {} and {}".format(data.longitude.min().values, data.longitude.max().values))
print("compensating if global dataset available")
logger.info("Lon must be within {} and {}".format(data.longitude.min().values, data.longitude.max().values))
logger.info("compensating if global dataset available")

if (lat_min < data.latitude.min()) or (lat_max > data.latitude.max()):
print("Lat is within {} and {}".format(data.latitude.min().values, data.latitude.max().values))
logger.info("Lat is within {} and {}".format(data.latitude.min().values, data.latitude.max().values))

# get idx
if lon_max - lon_min == dlon1 - dlon0:
Expand All @@ -516,16 +600,16 @@ def dem_range(data, lon_min, lon_max, lat_min, lat_max):
lat_1 = min(data.latitude.size, j1 + 3)

if i0 > i1:
p1 = data.elevation.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1))
p1 = data.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1))

p1 = p1.assign_coords({"longitude": p1.longitude.values - 360.0})

p2 = data.elevation.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1))
p2 = data.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1))

dem = xr.concat([p1, p2], dim="longitude")

else:
dem = data.elevation.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1))
dem = data.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1))

if np.abs(np.mean(dem.longitude) - np.mean([lon_min, lon_max])) > 170.0:
c = np.sign(np.mean([lon_min, lon_max]))
Expand Down
92 changes: 80 additions & 12 deletions pyposeidon/utils/global_bgmesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pyresample
import xarray as xr
import os
from tqdm.auto import tqdm

import pyposeidon.boundary as pb
from pyposeidon.utils.stereo import to_3d, to_lat_lon, stereo_to_3d
Expand All @@ -17,6 +18,7 @@
)
import pyposeidon.dem as pdem
import pyposeidon.mesh as pmesh
from pyposeidon.utils.fix import dem_range, resample
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,26 +95,37 @@ def make_bgmesh_global(dfb, fpos, dem, **kwargs):
y0 = mesh.Dataset.SCHISM_hgrid_node_y.values
trii0 = mesh.Dataset.SCHISM_hgrid_face_nodes.values[:, :3]

m = mesh.Dataset
m = m.assign({"x": m["SCHISM_hgrid_node_x"], "y": m["SCHISM_hgrid_node_y"]})
# Stereo -> lat/lon
clon, clat = to_lat_lon(x0, y0)
m["x"].data = clon
m["y"].data = clat

# Select DEM
try:
dm = dem.adjusted.to_dataframe()
except:
dm = dem.elevation.to_dataframe()
# try:
# dm = dem.adjusted.to_dataframe()
# except:
# dm = dem.elevation.to_dataframe()

lon = dem.longitude.values
lat = dem.latitude.values
# lon = dem.longitude.values
# lat = dem.latitude.values

X, Y = np.meshgrid(lon, lat)
# X, Y = np.meshgrid(lon, lat)
# Stereo -> lat/lon
clon, clat = to_lat_lon(x0, y0)
# clon, clat = to_lat_lon(x0, y0)

# resample bathymetry
gdem = dm.values.flatten()
# gdem = dm.values.flatten()

orig = pyresample.geometry.SwathDefinition(lons=X.flatten(), lats=Y.flatten()) # original bathymetry points
targ = pyresample.geometry.SwathDefinition(lons=clon, lats=clat) # wet points
# orig = pyresample.geometry.SwathDefinition(lons=X.flatten(), lats=Y.flatten()) # original bathymetry points
# targ = pyresample.geometry.SwathDefinition(lons=clon, lats=clat) # wet points

bw = pyresample.kd_tree.resample_nearest(orig, gdem, targ, radius_of_influence=50000, fill_value=0)
# bw = pyresample.kd_tree.resample_nearest(orig, gdem, targ, radius_of_influence=50000, fill_value=0)

dem_on_mesh(m, dem)

bw = m.depth.data

bz = pd.DataFrame({"z": bw.flatten()})

Expand All @@ -130,3 +143,58 @@ def make_bgmesh_global(dfb, fpos, dem, **kwargs):
dfb = bk

return nodes, elems


def fillv(dem, perms, m, buffer=0.0):

for (i1, i2), (j1, j2) in tqdm(perms, total=len(perms)):

lon1 = dem.longitude.data[i1:i2][0]
lon2 = dem.longitude.data[i1:i2][-1]
lat1 = dem.latitude.data[j1:j2][0]
lat2 = dem.latitude.data[j1:j2][-1]

# buffer lat/lon
blon1 = lon1 - buffer
blon2 = lon2 + buffer
blat1 = lat1 - buffer
blat2 = lat2 + buffer

# de = dem.sel(lon=slice(blon1,blon2)).sel(lat=slice(blat1,blat2))
de = dem_range(dem, blon1, blon2, blat1, blat2)

# subset mesh
indices_of_nodes_in_bbox = np.where(
(m.y >= lat1 - buffer / 2)
& (m.y <= lat2 + buffer / 2)
& (m.x >= lon1 - buffer / 2)
& (m.x <= lon2 + buffer / 2)
)[0]

bm = m.isel(nSCHISM_hgrid_node=indices_of_nodes_in_bbox)
ids = np.argwhere(np.isnan(bm.depth.values)).flatten()

grid_x, grid_y = bm.x.data, bm.y.data

bd = resample(de, grid_x, grid_y, var="adjusted", wet=True, flag=0, function="gauss")

m["depth"].loc[dict(nSCHISM_hgrid_node=indices_of_nodes_in_bbox)] = -bd


def dem_on_mesh(mesh, dem):

ilats = dem.elevation.chunk("auto").chunks[0]
ilons = dem.elevation.chunk("auto").chunks[1]

if len(ilons) == 1:
ilons = (int(ilons[0] / 2), int(ilons[0] / 2))

idx = [sum(ilons[:i]) for i in range(len(ilons) + 1)]
jdx = [sum(ilats[:i]) for i in range(len(ilats) + 1)]

blon = list(zip(idx[:-1], idx[1:]))
blat = list(zip(jdx[:-1], jdx[1:]))

perms = [(x, y) for x in blon for y in blat]

fillv(dem, perms, mesh, buffer=5)

0 comments on commit 98120ed

Please sign in to comment.