Skip to content

Commit

Permalink
Fast graph generation in fv3fit (#1991)
Browse files Browse the repository at this point in the history
The current graph generation code in fv3fit is very slow, and requires an external dependency. This PR uses indexing-based logic on a cubed sphere instead to greatly reduce the time needed to generate the graph, and remove the external dependency.

Based on and may be merged into #1986 .

Significant internal changes:
- Greatly reduced build_graph execution time and removed dependency on geopy

Requirement changes:
- Removed dependency of fv3fit on geopy
  • Loading branch information
mcgibbon authored Aug 10, 2022
1 parent 5385445 commit 713ed38
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 46 deletions.
11 changes: 1 addition & 10 deletions external/fv3fit/fv3fit/pytorch/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,10 @@ def train_graph_model(
train_batches: tf.data.Dataset,
validation_batches: Optional[tf.data.Dataset],
):

"""
Train a graph network.
Args:
train_batches: training data, as a dataset of Mapping[str, tf.Tensor]
validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor]
build_model: the function which produces the pytorch model
from input and output samples. The models returned must take a list of
tensors as input and return a list of tensors as output.
input_variables: names of inputs for the pytorch model
output_variables: names of outputs for the pytorch model
n_epoch: number of epochs
hyperparameters: configuration for training
train_batches: training data, as a dataset of Mapping[str, tf.Tensor]
validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor]
Expand Down Expand Up @@ -156,7 +147,7 @@ def build_model(graph, graph_network):
graph: configuration for building graph
graph_network: configuration of the graph network
"""
graph_data = build_graph(graph)
graph_data = build_graph(graph.nx_tile)
g = dgl.graph(graph_data)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_model = GraphNetwork(
Expand Down
77 changes: 44 additions & 33 deletions external/fv3fit/fv3fit/pytorch/graph/graph_builder.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,57 @@
import torch
from typing import Tuple
import numpy as np
import geopy.distance
from vcm.catalog import catalog
import dataclasses
from fv3fit.keras._models.shared.halos import append_halos
import xarray as xr


@dataclasses.dataclass
class GraphConfig:
"""
Attributes:
neighbor: number of nearest neighbor grids
coarsen: 1 if the full model resolution is used, othewise data will be coarsen
resolution: Model resolution to load the corresponding lat and lon.
nx_tile: number of horizontal grid points on each tile of the cubed sphere
"""

neighbor: int = 10
coarsen: int = 8
resolution: str = "grid/c48"
# TODO: this should not be configurable, it should be determined
# by the shape of the input data
nx_tile: int = 48


def build_graph(config: GraphConfig) -> tuple:
nodes = []
edges = []

grid = catalog[config.resolution].read()
lat = grid.lat.load()
lon = grid.lon.load()

lat = lat[:, :: config.coarsen, :: config.coarsen].values.flatten()
lon = lon[:, :: config.coarsen, :: config.coarsen].values.flatten()
def build_graph(nx_tile: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Returns two 1D arrays containing the start and end points of edges of the graph.
for i in range(0, len(lat)):
distance = np.zeros([len(lat)])
coords_1 = (lat[i], lon[i])
for j in range(0, len(lat)):
coords_2 = (lat[j], lon[j])
distance[j] = geopy.distance.geodesic(coords_1, coords_2).km
destination_grids = sorted(range(len(distance)), key=lambda i: distance[i])[
: config.neighbor
]
nodes.append(np.repeat(i, config.neighbor, axis=0))
edges.append(destination_grids)
nodes = torch.tensor(nodes).flatten()
edges = torch.tensor(edges).flatten()
return (nodes, edges)
Args:
nx_tile: number of horizontal grid points on each tile of the cubed sphere
"""
n_tile, nx, ny = 6, nx_tile, nx_tile
n_points = n_tile * nx * ny
ds = xr.Dataset(
data_vars={
"index": xr.DataArray(
np.arange(n_points).reshape((n_tile, nx, ny, 1)),
dims=["tile", "x", "y", "z"],
)
}
)
ds = append_halos(ds, n_halo=1).squeeze("z")
index = ds["index"].values
total_edges = n_points * 5
out = np.empty((total_edges, 2), dtype=int)
# left connections
out[:n_points, 0] = index[:, 1:-1, 1:-1].flatten()
out[:n_points, 1] = index[:, :-2, 1:-1].flatten()
# right connections
out[n_points : 2 * n_points, 0] = index[:, 1:-1, 1:-1].flatten()
out[n_points : 2 * n_points, 1] = index[:, 2:, 1:-1].flatten()
# up connections
out[2 * n_points : 3 * n_points, 0] = index[:, 1:-1, 1:-1].flatten()
out[2 * n_points : 3 * n_points, 1] = index[:, 1:-1, 2:].flatten()
# down connections
out[3 * n_points : 4 * n_points, 0] = index[:, 1:-1, 1:-1].flatten()
out[3 * n_points : 4 * n_points, 1] = index[:, 1:-1, :-2].flatten()
# self-connections
out[4 * n_points : 5 * n_points, 0] = index[:, 1:-1, 1:-1].flatten()
out[4 * n_points : 5 * n_points, 1] = index[:, 1:-1, 1:-1].flatten()

return out[:, 0], out[:, 1]
2 changes: 1 addition & 1 deletion external/fv3fit/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
"torch",
"torchvision",
"dgl",
"geopy",
"tensorflow-datasets",
]

setup_requirements = []


test_requirements = ["pytest"]

setup(
Expand Down
21 changes: 20 additions & 1 deletion external/fv3fit/tests/training/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fv3fit.pytorch.predict import PytorchModel
import pytest
from fv3fit.pytorch.graph.graph import GraphHyperparameters
from fv3fit.pytorch.graph.graph_builder import build_graph, GraphConfig

GENERAL_TRAINING_TYPES = [
"graph",
Expand Down Expand Up @@ -38,7 +39,9 @@ def train_identity_model(hyperparameters=None):
np.random.seed(2)
sample_test = get_uniform_sample_func(size=(grid, nz), low=low, high=high)
test_dataset = xr.Dataset({"a": sample_test()})
hyperparameters = GraphHyperparameters(input_variable, output_variables)
hyperparameters = GraphHyperparameters(
input_variable, output_variables, graph=GraphConfig(nx_tile=6)
)
train = fv3fit.get_training_function("graph")
model = train(hyperparameters, train_dataset, val_tfdataset)
return TrainingResult(model, output_variables, test_dataset, hyperparameters)
Expand Down Expand Up @@ -115,3 +118,19 @@ def sample_func():
)

return sample_func


def test_graph_builder():
graph = build_graph(2)
edges = list(zip(graph[0], graph[1]))
assert (0, 1) in edges # right edge
assert (0, 2) in edges # top edge
assert (0, 0) in edges # self edge
# Note due to the [tile, x, y] direction convention for index
# ordering, the node index orders are transposed compared to the
# MPI rank ordering used for cubed sphere decomposition in fv3gfs.
# If the order were [tile, y, x] then the node indices would be
# transposed in (x, y) on each tile.
assert (0, 21) in edges # down edge
assert (0, 19) in edges # left edge
assert len(edges) == 5 * 6 * 4

0 comments on commit 713ed38

Please sign in to comment.