From 713ed38565088acf52ccdab405dfd2f23e567764 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 10 Aug 2022 11:33:27 -0700 Subject: [PATCH] Fast graph generation in fv3fit (#1991) The current graph generation code in fv3fit is very slow, and requires an external dependency. This PR uses indexing-based logic on a cubed sphere instead to greatly reduce the time needed to generate the graph, and remove the external dependency. Based on and may be merged into #1986 . Significant internal changes: - Greatly reduced build_graph execution time and removed dependency on geopy Requirement changes: - Removed dependency of fv3fit on geopy --- external/fv3fit/fv3fit/pytorch/graph/graph.py | 11 +-- .../fv3fit/pytorch/graph/graph_builder.py | 77 +++++++++++-------- external/fv3fit/setup.py | 2 +- external/fv3fit/tests/training/test_graph.py | 21 ++++- external/fv3gfs-fortran | 2 +- 5 files changed, 67 insertions(+), 46 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/graph/graph.py b/external/fv3fit/fv3fit/pytorch/graph/graph.py index ebecc2d180..82aa77366a 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/graph.py +++ b/external/fv3fit/fv3fit/pytorch/graph/graph.py @@ -83,19 +83,10 @@ def train_graph_model( train_batches: tf.data.Dataset, validation_batches: Optional[tf.data.Dataset], ): - """ Train a graph network. Args: - train_batches: training data, as a dataset of Mapping[str, tf.Tensor] - validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor] - build_model: the function which produces the pytorch model - from input and output samples. The models returned must take a list of - tensors as input and return a list of tensors as output. - input_variables: names of inputs for the pytorch model - output_variables: names of outputs for the pytorch model - n_epoch: number of epochs hyperparameters: configuration for training train_batches: training data, as a dataset of Mapping[str, tf.Tensor] validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor] @@ -156,7 +147,7 @@ def build_model(graph, graph_network): graph: configuration for building graph graph_network: configuration of the graph network """ - graph_data = build_graph(graph) + graph_data = build_graph(graph.nx_tile) g = dgl.graph(graph_data) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") train_model = GraphNetwork( diff --git a/external/fv3fit/fv3fit/pytorch/graph/graph_builder.py b/external/fv3fit/fv3fit/pytorch/graph/graph_builder.py index 3582654689..bae2504d83 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/graph_builder.py +++ b/external/fv3fit/fv3fit/pytorch/graph/graph_builder.py @@ -1,46 +1,57 @@ -import torch +from typing import Tuple import numpy as np -import geopy.distance -from vcm.catalog import catalog import dataclasses +from fv3fit.keras._models.shared.halos import append_halos +import xarray as xr @dataclasses.dataclass class GraphConfig: """ Attributes: - neighbor: number of nearest neighbor grids - coarsen: 1 if the full model resolution is used, othewise data will be coarsen - resolution: Model resolution to load the corresponding lat and lon. + nx_tile: number of horizontal grid points on each tile of the cubed sphere """ - neighbor: int = 10 - coarsen: int = 8 - resolution: str = "grid/c48" + # TODO: this should not be configurable, it should be determined + # by the shape of the input data + nx_tile: int = 48 -def build_graph(config: GraphConfig) -> tuple: - nodes = [] - edges = [] - - grid = catalog[config.resolution].read() - lat = grid.lat.load() - lon = grid.lon.load() - - lat = lat[:, :: config.coarsen, :: config.coarsen].values.flatten() - lon = lon[:, :: config.coarsen, :: config.coarsen].values.flatten() +def build_graph(nx_tile: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Returns two 1D arrays containing the start and end points of edges of the graph. - for i in range(0, len(lat)): - distance = np.zeros([len(lat)]) - coords_1 = (lat[i], lon[i]) - for j in range(0, len(lat)): - coords_2 = (lat[j], lon[j]) - distance[j] = geopy.distance.geodesic(coords_1, coords_2).km - destination_grids = sorted(range(len(distance)), key=lambda i: distance[i])[ - : config.neighbor - ] - nodes.append(np.repeat(i, config.neighbor, axis=0)) - edges.append(destination_grids) - nodes = torch.tensor(nodes).flatten() - edges = torch.tensor(edges).flatten() - return (nodes, edges) + Args: + nx_tile: number of horizontal grid points on each tile of the cubed sphere + """ + n_tile, nx, ny = 6, nx_tile, nx_tile + n_points = n_tile * nx * ny + ds = xr.Dataset( + data_vars={ + "index": xr.DataArray( + np.arange(n_points).reshape((n_tile, nx, ny, 1)), + dims=["tile", "x", "y", "z"], + ) + } + ) + ds = append_halos(ds, n_halo=1).squeeze("z") + index = ds["index"].values + total_edges = n_points * 5 + out = np.empty((total_edges, 2), dtype=int) + # left connections + out[:n_points, 0] = index[:, 1:-1, 1:-1].flatten() + out[:n_points, 1] = index[:, :-2, 1:-1].flatten() + # right connections + out[n_points : 2 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() + out[n_points : 2 * n_points, 1] = index[:, 2:, 1:-1].flatten() + # up connections + out[2 * n_points : 3 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() + out[2 * n_points : 3 * n_points, 1] = index[:, 1:-1, 2:].flatten() + # down connections + out[3 * n_points : 4 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() + out[3 * n_points : 4 * n_points, 1] = index[:, 1:-1, :-2].flatten() + # self-connections + out[4 * n_points : 5 * n_points, 0] = index[:, 1:-1, 1:-1].flatten() + out[4 * n_points : 5 * n_points, 1] = index[:, 1:-1, 1:-1].flatten() + + return out[:, 0], out[:, 1] diff --git a/external/fv3fit/setup.py b/external/fv3fit/setup.py index 7d327b5ca1..d3333e9d74 100644 --- a/external/fv3fit/setup.py +++ b/external/fv3fit/setup.py @@ -23,12 +23,12 @@ "torch", "torchvision", "dgl", - "geopy", "tensorflow-datasets", ] setup_requirements = [] + test_requirements = ["pytest"] setup( diff --git a/external/fv3fit/tests/training/test_graph.py b/external/fv3fit/tests/training/test_graph.py index d4ccdfc2e0..5bd85235bd 100644 --- a/external/fv3fit/tests/training/test_graph.py +++ b/external/fv3fit/tests/training/test_graph.py @@ -7,6 +7,7 @@ from fv3fit.pytorch.predict import PytorchModel import pytest from fv3fit.pytorch.graph.graph import GraphHyperparameters +from fv3fit.pytorch.graph.graph_builder import build_graph, GraphConfig GENERAL_TRAINING_TYPES = [ "graph", @@ -38,7 +39,9 @@ def train_identity_model(hyperparameters=None): np.random.seed(2) sample_test = get_uniform_sample_func(size=(grid, nz), low=low, high=high) test_dataset = xr.Dataset({"a": sample_test()}) - hyperparameters = GraphHyperparameters(input_variable, output_variables) + hyperparameters = GraphHyperparameters( + input_variable, output_variables, graph=GraphConfig(nx_tile=6) + ) train = fv3fit.get_training_function("graph") model = train(hyperparameters, train_dataset, val_tfdataset) return TrainingResult(model, output_variables, test_dataset, hyperparameters) @@ -115,3 +118,19 @@ def sample_func(): ) return sample_func + + +def test_graph_builder(): + graph = build_graph(2) + edges = list(zip(graph[0], graph[1])) + assert (0, 1) in edges # right edge + assert (0, 2) in edges # top edge + assert (0, 0) in edges # self edge + # Note due to the [tile, x, y] direction convention for index + # ordering, the node index orders are transposed compared to the + # MPI rank ordering used for cubed sphere decomposition in fv3gfs. + # If the order were [tile, y, x] then the node indices would be + # transposed in (x, y) on each tile. + assert (0, 21) in edges # down edge + assert (0, 19) in edges # left edge + assert len(edges) == 5 * 6 * 4 diff --git a/external/fv3gfs-fortran b/external/fv3gfs-fortran index cf1eb54ab4..6a0dfd1f02 160000 --- a/external/fv3gfs-fortran +++ b/external/fv3gfs-fortran @@ -1 +1 @@ -Subproject commit cf1eb54ab439059dc7796e3b0b16404a5dda4966 +Subproject commit 6a0dfd1f02ce4db3033d9e8b09f9d2911d9b45e5