Skip to content

Commit

Permalink
⌨️ type compatibility and more test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
asross committed Apr 26, 2023
1 parent 9d18168 commit 0d4daef
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 62 deletions.
44 changes: 26 additions & 18 deletions eqn_disco/hybrid_symbolic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Hybrid symbolic module."""
from typing import Optional, List, Dict, Any, Tuple
from typing import Optional, List, Dict, Any, Tuple, Union

import gplearn.genetic
from gplearn.functions import _Function as Function
Expand Down Expand Up @@ -94,7 +94,13 @@ def run_gplearn_iteration(
"""
base_features = ["q", "u", "v"] if base_features is None else base_features
base_functions = ["add", "mul"] if base_functions is None else base_functions
spatial_functions = (
["ddx", "ddy", "laplacian", "advected"]
if spatial_functions is None
else spatial_functions
)
spatial_functions = make_custom_gplearn_functions(data_set, spatial_functions)
function_set = base_functions + spatial_functions # type: ignore

# Flatten the input and target data
inputs = np.array(
Expand Down Expand Up @@ -122,7 +128,7 @@ def run_gplearn_iteration(
# and for relatively few generations (again for performance)
regressor = gplearn.genetic.SymbolicRegressor(
feature_names=base_features,
function_set=base_functions + spatial_functions, # use our custom ops
function_set=function_set,
**gplearn_kwargs,
)

Expand Down Expand Up @@ -201,18 +207,18 @@ def predict(self, model_or_dataset: ModelLike) -> Dict[str, ArrayLike]:
# to data that may or may not have extra batch dimensions
for idx, lr_model in enumerate(self.models):
data_indices = [slice(None) for _ in inputs.shape]
data_indices[-3] = idx
data_indices[-3] = idx # type: ignore
layer_inputs = inputs[tuple(data_indices)]
coef_indices = [np.newaxis for _ in layer_inputs.shape]
coef_indices[0] = slice(None)
coef_indices[0] = slice(None) # type: ignore
layer_coefs = lr_model.coef_[tuple(coef_indices)]
layer_preds = (layer_inputs * layer_coefs).sum(axis=0)
preds.append(layer_preds)

preds = np.stack(preds, axis=-3)
res = {}
res[self.target] = preds
return res
return res # type: ignore

@classmethod
def fit(
Expand Down Expand Up @@ -253,22 +259,24 @@ def fit(
return cls(models, inputs, target)


def _each_layer(data_set: xr.Dataset) -> List[xr.Dataset]:
"""Return a list of datasets for each vertical layer in `data_set`."""
if "lev" in data_set:
return [data_set.isel(lev=z) for z in range(len(data_set.lev))]
def _each_layer(
data: Union[xr.Dataset, xr.DataArray]
) -> List[Union[xr.Dataset, xr.DataArray]]:
"""Return a list representing the `data` broken out by vertical layer."""
if "lev" in data.dims:
return [data.isel(lev=z) for z in range(len(data.lev))]

return [data_set]
return [data]


def corr(spatial_data_a: xr.DataArray, spatial_data_b: xr.DataArray) -> float:
def _corr(spatial_data_a: ArrayLike, spatial_data_b: ArrayLike) -> float:
"""Return the Pearson correlation between two spatial data arrays.
Parameters
----------
a : xarray.DataArray
a : Union[xarray.DataArray, numpy.ndarray]
First spatial data array
b : xarray.DataArray
b : Union[xarray.DataArray, numpy.ndarray]
Second spatial data array
Returns
Expand All @@ -278,8 +286,8 @@ def corr(spatial_data_a: xr.DataArray, spatial_data_b: xr.DataArray) -> float:
"""
return pearsonr(
np.array(spatial_data_a.data).ravel(),
np.array(spatial_data_b.data).ravel(),
ensure_numpy(spatial_data_a).ravel(),
ensure_numpy(spatial_data_b).ravel(),
)[0]


Expand Down Expand Up @@ -326,17 +334,17 @@ def hybrid_symbolic_regression( # pylint: disable=too-many-locals
try:
for i in range(max_iters):
for data_set_layer, residual_layer in zip(
_each_layer(data_set), _each_layer(residual)
_each_layer(data_set), _each_layer(residual) # type: ignore
):
symbolic_regressor = run_gplearn_iteration(
data_set_layer, target=residual_layer, **kw
data_set_layer, target=residual_layer.data, **kw # type: ignore
)
new_term = str(
symbolic_regressor._program # pylint: disable=protected-access
)
new_vals = extract(new_term)
# Prevent spurious duplicates, e.g. ddx(q) and ddx(add(1,q))
if not any(corr(new_vals, v) > 0.99 for v in vals):
if not any(_corr(new_vals, v) > 0.99 for v in vals):
terms.append(new_term)
vals.append(new_vals)
hybrid_regressor = LinearSymbolicRegression.fit(data_set, terms, target)
Expand Down
89 changes: 67 additions & 22 deletions eqn_disco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
ArrayLike = Union[np.ndarray, xr.DataArray]
Numeric = Union[ArrayLike, int, float]
StringOrNumeric = Union[str, Numeric]
ParameterizationSuperclass = pyqg.Parameterization if IMPORTED_PYQG else object


def ensure_numpy(array: ArrayLike) -> np.ndarray:
Expand All @@ -38,12 +37,13 @@ def ensure_numpy(array: ArrayLike) -> np.ndarray:
return array


class Parameterization(ParameterizationSuperclass):
class Parameterization(pyqg.Parameterization if IMPORTED_PYQG else object): # type: ignore
"""Helper class for defining parameterizations.
This extends the normal pyqg parameterization framework to handle
This extends the normal pyqg.Parameterization framework to handle
prediction of either subgrid forcings or fluxes, as well as to apply to
either pyqg.Models orxarray.Datasets.
either pyqg.Models or xarray.Datasets. Can also be used without pyqg, though
in a more limited fashion.
"""

Expand All @@ -55,12 +55,12 @@ def targets(self) -> List[str]:
-------
List[str]
List of parameterization targets returned by this parameterization.
Valid options are "q_forcing_total", "q_subgrid_forcing",
"u_subgrid_forcing", "v_subgrid_forcing", "uq_subgrid_flux",
"vq_subgrid_flux", "uu_subgrid_flux", "vv_subgrid_flux", and
"uv_subgrid_flux". See the dataset description notebook or the
paper for more details on the meanings of these target fields and
how they're used.
If using within pyqg, valid options are "q_forcing_total",
"q_subgrid_forcing", "u_subgrid_forcing", "v_subgrid_forcing",
"uq_subgrid_flux", "vq_subgrid_flux", "uu_subgrid_flux",
"vv_subgrid_flux", and "uv_subgrid_flux". See the dataset
description notebook or the paper for more details on the meanings
of these target fields and how they're used.
"""
raise NotImplementedError
Expand Down Expand Up @@ -99,14 +99,14 @@ def parameterization_type(self) -> str:
Indication of whether the parameterization targets PV or velocity.
"""
assert IMPORTED_PYQG, "pyqg must be installed to use this method"

if any(q in self.targets[0] for q in ["q_forcing", "q_subgrid"]):
return "q_parameterization"

return "uv_parameterization"

def __call__(
self, model: ModelLike
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
def __call__(self, model: ModelLike) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
"""Invoke the parameterization in the format required by pyqg.
Parameters
Expand All @@ -124,6 +124,7 @@ def __call__(
type as the model's PV variable.
"""
assert IMPORTED_PYQG, "pyqg must be installed to use this method"

def _ensure_array(array: ArrayLike) -> np.ndarray:
"""Convert an array-like to numpy with model-compatible dtype."""
Expand Down Expand Up @@ -289,7 +290,9 @@ class FeatureExtractor:
"""

def __call__(self, feature_or_features: Union[str, List[str]], flat: bool = False):
def __call__(
self, feature_or_features: Union[str, List[str]], flat: bool = False
) -> np.ndarray:
"""Extract the given feature/features from underlying dataset/ model.
Parameters
Expand All @@ -307,13 +310,13 @@ def __call__(self, feature_or_features: Union[str, List[str]], flat: bool = Fals
"""
if isinstance(feature_or_features, str):
res = ensure_numpy(self.extract_feature(feature_or_features))
res = ensure_numpy(self.extract_feature(feature_or_features)) # type: ignore
if flat:
res = res.reshape(-1)

else:
res = np.array(
[ensure_numpy(self.extract_feature(f)) for f in feature_or_features]
[ensure_numpy(self.extract_feature(f)) for f in feature_or_features] # type: ignore
)
if flat:
res = res.reshape(len(feature_or_features), -1).T
Expand Down Expand Up @@ -402,9 +405,9 @@ def spatial_dims(self) -> Tuple[str, ...]:
return self.example_realspace_input.dims

@property
def spectral_dims(self) -> Tuple[str, ...]:
def spectral_dims(self) -> List[str]:
"""Names of spatial dimensions in spectral space."""
return [dict(y="l", x="k").get(d, d) for d in self.spatial_dims]
return [{"y": "l", "x": "k"}.get(d, d) for d in self.spatial_dims]

def ifft(self, spectral_array: ArrayLike) -> ArrayLike:
"""Compute the inverse FFT of ``x``.
Expand Down Expand Up @@ -435,17 +438,13 @@ def _is_real(self, arr: ArrayLike) -> bool:
def _real(self, feature: StringOrNumeric) -> ArrayLike:
"""Load and convert a feature to real space, if necessary."""
arr = self[feature]
if isinstance(arr, float):
return arr
if self._is_real(arr):
return arr
return self.ifft(arr)

def _compl(self, feature: StringOrNumeric) -> ArrayLike:
"""Load and convert a feature to spectral space, if necessary."""
arr = self[feature]
if isinstance(arr, float):
return arr
if self._is_real(arr):
return self.fft(arr)
return arr
Expand Down Expand Up @@ -684,3 +683,49 @@ def energy_budget_figure(models, skip=0):
axis.set_ylim(-vmax, vmax)
plt.tight_layout()
return fig


def example_non_pyqg_data_set(
grid_length: int = 8, num_samples: int = 20
) -> xr.Dataset:
"""Create a simple xarray dataset for testing the library without `pyqg`.
This dataset has a single variable called `inputs` with `x`, `y`, and
`batch` coordinates. It also has spectral coordinates `k` and `l` defined.
It can be used in various methods of the library without needing to invoke
`pyqg`.
Parameters
----------
grid_length : int
The length of the grid in each dimension.
num_samples : int
The number of samples in the dataset.
Returns
-------
xr.Dataset
The dataset.
"""
grid = np.linspace(0, 1, grid_length)
inputs = np.random.normal(size=(num_samples, grid_length, grid_length))
vertical_wavenumbers = (
2
* np.pi
* np.append(np.arange(0.0, grid_length / 2), np.arange(-grid_length / 2, 0.0))
)
horizontal_wavenumbers = 2 * np.pi * np.arange(0.0, grid_length / 2 + 1)

return xr.Dataset(
data_vars={
"inputs": (("batch", "y", "x"), inputs),
},
coords={
"x": grid,
"y": grid,
"l": vertical_wavenumbers,
"k": horizontal_wavenumbers,
"batch": np.arange(num_samples),
},
)
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
disable_error_code = attr-defined, import, valid-type, var-annotated
56 changes: 34 additions & 22 deletions tests/test_hybrid_symbolic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pyqg
import numpy as np
import xarray as xr
from eqn_disco.hybrid_symbolic import LinearSymbolicRegression, run_gplearn_iteration
from eqn_disco.utils import FeatureExtractor
from eqn_disco.hybrid_symbolic import LinearSymbolicRegression, run_gplearn_iteration, hybrid_symbolic_regression
from eqn_disco.utils import FeatureExtractor, example_non_pyqg_data_set

def test_linear_symbolic():
model = pyqg.QGModel()
Expand All @@ -28,27 +28,9 @@ def test_linear_symbolic():
model2 = pyqg.QGModel(parameterization=parameterization)
model2._step_forward()


def test_run_gplearn_iteration():
grid_length = 8
num_samples = 20
grid = np.linspace(0, 1, grid_length)
x, y = np.meshgrid(grid, grid)
inputs = np.random.normal(size=(num_samples, grid_length, grid_length))
l = 2 * np.pi * np.append(np.arange(0., grid_length/2), np.arange(-grid_length/2, 0.))
k = 2 * np.pi * np.arange(0., grid_length/2 + 1)

data_set = xr.Dataset(
data_vars=dict(
inputs=(('batch', 'y', 'x'), inputs),
),
coords=dict(
x=grid,
y=grid,
l=l,
k=k,
batch=np.arange(num_samples)
)
)
data_set = example_non_pyqg_data_set()

extractor = FeatureExtractor(data_set, example_realspace_input="inputs")

Expand All @@ -69,3 +51,33 @@ def test_run_gplearn_iteration():
result = str(regressor._program)

assert result == 'ddx(inputs)'


def test_hybrid_symbolic_regression():
data_set = example_non_pyqg_data_set()

extractor = FeatureExtractor(data_set, example_realspace_input="inputs")

data_set['target'] = extractor.extract_feature('ddx(inputs)')

terms, hybrid_regressors = hybrid_symbolic_regression(
data_set,
target='target',
max_iters=2,
verbose=False,
base_features=['inputs'],
base_functions=[],
spatial_functions=['ddx'],
population_size=100,
generations=10,
metric='mse',
random_state=42
)

assert terms == ['ddx(inputs)']

regressor = hybrid_regressors[-1]

assert len(regressor.models) == 1

np.testing.assert_allclose(regressor.models[0].coef_[0], 1.0)
Loading

0 comments on commit 0d4daef

Please sign in to comment.