-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fast graph generation in fv3fit (#1991)
The current graph generation code in fv3fit is very slow, and requires an external dependency. This PR uses indexing-based logic on a cubed sphere instead to greatly reduce the time needed to generate the graph, and remove the external dependency. Based on and may be merged into #1986 . Significant internal changes: - Greatly reduced build_graph execution time and removed dependency on geopy Requirement changes: - Removed dependency of fv3fit on geopy
- Loading branch information
Showing
5 changed files
with
67 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,57 @@ | ||
import torch | ||
from typing import Tuple | ||
import numpy as np | ||
import geopy.distance | ||
from vcm.catalog import catalog | ||
import dataclasses | ||
from fv3fit.keras._models.shared.halos import append_halos | ||
import xarray as xr | ||
|
||
|
||
@dataclasses.dataclass | ||
class GraphConfig: | ||
""" | ||
Attributes: | ||
neighbor: number of nearest neighbor grids | ||
coarsen: 1 if the full model resolution is used, othewise data will be coarsen | ||
resolution: Model resolution to load the corresponding lat and lon. | ||
nx_tile: number of horizontal grid points on each tile of the cubed sphere | ||
""" | ||
|
||
neighbor: int = 10 | ||
coarsen: int = 8 | ||
resolution: str = "grid/c48" | ||
# TODO: this should not be configurable, it should be determined | ||
# by the shape of the input data | ||
nx_tile: int = 48 | ||
|
||
|
||
def build_graph(config: GraphConfig) -> tuple: | ||
nodes = [] | ||
edges = [] | ||
|
||
grid = catalog[config.resolution].read() | ||
lat = grid.lat.load() | ||
lon = grid.lon.load() | ||
|
||
lat = lat[:, :: config.coarsen, :: config.coarsen].values.flatten() | ||
lon = lon[:, :: config.coarsen, :: config.coarsen].values.flatten() | ||
def build_graph(nx_tile: int) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Returns two 1D arrays containing the start and end points of edges of the graph. | ||
for i in range(0, len(lat)): | ||
distance = np.zeros([len(lat)]) | ||
coords_1 = (lat[i], lon[i]) | ||
for j in range(0, len(lat)): | ||
coords_2 = (lat[j], lon[j]) | ||
distance[j] = geopy.distance.geodesic(coords_1, coords_2).km | ||
destination_grids = sorted(range(len(distance)), key=lambda i: distance[i])[ | ||
: config.neighbor | ||
] | ||
nodes.append(np.repeat(i, config.neighbor, axis=0)) | ||
edges.append(destination_grids) | ||
nodes = torch.tensor(nodes).flatten() | ||
edges = torch.tensor(edges).flatten() | ||
return (nodes, edges) | ||
Args: | ||
nx_tile: number of horizontal grid points on each tile of the cubed sphere | ||
""" | ||
n_tile, nx, ny = 6, nx_tile, nx_tile | ||
n_points = n_tile * nx * ny | ||
ds = xr.Dataset( | ||
data_vars={ | ||
"index": xr.DataArray( | ||
np.arange(n_points).reshape((n_tile, nx, ny, 1)), | ||
dims=["tile", "x", "y", "z"], | ||
) | ||
} | ||
) | ||
ds = append_halos(ds, n_halo=1).squeeze("z") | ||
index = ds["index"].values | ||
total_edges = n_points * 5 | ||
out = np.empty((total_edges, 2), dtype=int) | ||
# left connections | ||
out[:n_points, 0] = index[:, 1:-1, 1:-1].flatten() | ||
out[:n_points, 1] = index[:, :-2, 1:-1].flatten() | ||
# right connections | ||
out[n_points : 2 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() | ||
out[n_points : 2 * n_points, 1] = index[:, 2:, 1:-1].flatten() | ||
# up connections | ||
out[2 * n_points : 3 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() | ||
out[2 * n_points : 3 * n_points, 1] = index[:, 1:-1, 2:].flatten() | ||
# down connections | ||
out[3 * n_points : 4 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() | ||
out[3 * n_points : 4 * n_points, 1] = index[:, 1:-1, :-2].flatten() | ||
# self-connections | ||
out[4 * n_points : 5 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() | ||
out[4 * n_points : 5 * n_points, 1] = index[:, 1:-1, 1:-1].flatten() | ||
|
||
return out[:, 0], out[:, 1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule fv3gfs-fortran
updated
24 files