Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

calc_geodist_exact: allow passing DataArray #299

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ New Features

By `Mathias Hauser`_.

- Allow passing `xr.DataArray` to ``gaspari_cohn`` (`#298 <https://github.com/MESMER-group/mesmer/issues/298>`__).
- Allow passing `xr.DataArray` to ``gaspari_cohn`` (`#298 <https://github.com/MESMER-group/mesmer/pull/298>`__).
By `Mathias Hauser`_.

- Allow passing `xr.DataArray` to ``calc_geodist_exact`` (`#299 <https://github.com/MESMER-group/mesmer/pull/299>`__).
By `Zeb Nicholls`_ and `Mathias Hauser`_.


Breaking changes
Expand Down
54 changes: 43 additions & 11 deletions mesmer/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
import pyproj
import xarray as xr

from .utils import create_equal_dim_names


def gaspari_cohn(r):
"""smooth, exponentially decaying Gaspari-Cohn correlation function

Parameters
----------
r : xr.DataArray, np.array
r : xr.DataArray, np.ndarray
Values for which to calculate the value of the Gaspari-Cohn correlation function
(e.g. normalised geographical distances)

Returns
-------
out : xr.DataArray, , np.array
out : xr.DataArray, , np.ndarray
Gaspari-Cohn correlation function

Notes
Expand Down Expand Up @@ -89,39 +91,69 @@ def _gaspari_cohn_np(r):
return out


def calc_geodist_exact(lon, lat):
def calc_geodist_exact(lon, lat, equal_dim_suffixes=("_i", "_j")):
"""exact great circle distance based on WSG 84

Parameters
----------
lon : array-like
lon : xr.DataArray, np.ndarray
1D array of longitudes
lat : array-like
lat : xr.DataArray, np.ndarray
1D array of latitudes
equal_dim_suffixes : tuple of str, default: ("_i", "_j")
Suffixes to add to the the name of ``dim`` for the geodist array (xr.DataArray
cannot have two dimensions with the same name).

Returns
-------
geodist : np.array
geodist : xr.DataArray, np.ndarray
2D array of great circle distances.
"""

# TODO: allow Dataset (e.g. using cf_xarray)
if isinstance(lon, xr.Dataset) or isinstance(lat, xr.Dataset):
mathause marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError("Dataset is not supported, please pass a DataArray")

# handle numpy arrays
if not isinstance(lon, xr.DataArray) or not isinstance(lat, xr.DataArray):
return _calc_geodist_exact(np.asarray(lon), np.asarray(lat))

# TODO: allow differently named lon and lat dims?
if lon.dims != lat.dims:
raise AssertionError(
f"lon and lat have different dims: {lon.dims} vs. {lat.dims}. Expected "
"equally named dimensions from a stacked array"
)

geodist = _calc_geodist_exact(lon.values, lat.values)

(dim,) = lon.dims
dims = create_equal_dim_names(dim, equal_dim_suffixes)

# TODO: assign coords?
geodist = xr.DataArray(geodist, dims=dims)

return geodist


def _calc_geodist_exact(lon, lat):

# ensure correct shape
lon, lat = np.asarray(lon), np.asarray(lat)
if lon.shape != lat.shape or lon.ndim != 1:
raise ValueError("lon and lat need to be 1D arrays of the same shape")
raise ValueError("lon and lat must be 1D arrays of the same shape")

geod = pyproj.Geod(ellps="WGS84")

n_points = len(lon)
n_points = lon.size

geodist = np.zeros([n_points, n_points])

# calculate only the upper right half of the triangle
for i in range(n_points):

# need to duplicate gridpoint (required by geod.inv)
lt = np.tile(lat[i], n_points - (i + 1))
ln = np.tile(lon[i], n_points - (i + 1))
lt = np.repeat(lat[i : i + 1], n_points - (i + 1))
ln = np.repeat(lon[i : i + 1], n_points - (i + 1))
Comment on lines +155 to +156
Copy link
Member Author

@mathause mathause Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeat is faster than tile. Also using an array (lat[i : i + 1]) is faster than a scalar (lat[i]).


geodist[i, i + 1 :] = geod.inv(ln, lt, lon[i + 1 :], lat[i + 1 :])[2]

Expand Down
6 changes: 3 additions & 3 deletions mesmer/stats/localized_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def find_localized_empirical_covariance(
Dimension along which to calculate the covariance.
k_folds : int
Number of folds to use for cross validation.
equal_dim_suffixes : tuple of str
Suffixes to add to the the name of ``dim`` for the covariance array (xr.DataArray cannot have two
dimensions with the same name).
equal_dim_suffixes : tuple of str, default: ("_i", "_j")
Suffixes to add to the the name of ``dim`` for the covariance array
(xr.DataArray cannot have two dimensions with the same name).

Returns
-------
Expand Down
98 changes: 97 additions & 1 deletion tests/unit/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import xarray as xr

from mesmer.core.computation import gaspari_cohn
from mesmer.core.computation import calc_geodist_exact, gaspari_cohn


def test_gaspari_cohn_error():
Expand Down Expand Up @@ -51,3 +51,99 @@ def test_gaspari_cohn_np():
# make sure shape is conserved
values = np.arange(9).reshape(3, 3)
assert gaspari_cohn(values).shape == (3, 3)


def test_calc_geodist_dataset_error():

ds = xr.Dataset()
da = xr.DataArray()

with pytest.raises(TypeError, match="Dataset is not supported"):
calc_geodist_exact(ds, ds)

with pytest.raises(TypeError, match="Dataset is not supported"):
calc_geodist_exact(ds, da)

with pytest.raises(TypeError, match="Dataset is not supported"):
calc_geodist_exact(da, ds)


def test_calc_geodist_dataarray_equal_dims_required():

lon = xr.DataArray([0], dims="lon")
lat = xr.DataArray([0], dims="lat")

with pytest.raises(AssertionError, match="lon and lat have different dims"):
calc_geodist_exact(lon, lat)


@pytest.mark.parametrize("as_dataarray", [True, False])
def test_calc_geodist_not_same_shape_error(as_dataarray):

lon, lat = [0, 0], [0]

if as_dataarray:
lon, lat = xr.DataArray(lon), xr.DataArray(lat)

with pytest.raises(ValueError, match="lon and lat must be 1D arrays"):
calc_geodist_exact(lon, lat)


@pytest.mark.parametrize("as_dataarray", [True, False])
def test_calc_geodist_not_1D_error(as_dataarray):

lon = lat = [[0, 0]]

if as_dataarray:
lon, lat = xr.DataArray(lon), xr.DataArray(lat)

with pytest.raises(ValueError, match=".*of the same shape"):
calc_geodist_exact(lon, lat)


@pytest.mark.parametrize("lon", [[0, 0], [0, 360], [1, 361], [180, -180]])
@pytest.mark.parametrize("as_dataarray", [True, False])
def test_calc_geodist_exact_equal(lon, as_dataarray):
"""test points with distance 0"""

expected = np.array([[0, 0], [0, 0]])

lat = [0, 0]

if as_dataarray:
lon = xr.DataArray(lon)

result = calc_geodist_exact(lon, lat)
np.testing.assert_equal(result, expected)
# when passing only one DataArray it's also returned as np.array
assert isinstance(result, np.ndarray)


@pytest.mark.parametrize("as_dataarray", [True, False])
def test_calc_geodist_exact(as_dataarray):
"""test some random points"""

lon = [-180, 0, 3]
lat = [0, 0, 5]

if as_dataarray:
lon = xr.DataArray(lon, dims="gp", coords={"lon": ("gp", lon)})
lat = xr.DataArray(lat, dims="gp", coords={"lat": ("gp", lat)})

result = calc_geodist_exact(lon, lat)
expected = np.array(
[
[0.0, 20003.93145863, 19366.51816487],
[20003.93145863, 0.0, 645.70051988],
[19366.51816487, 645.70051988, 0.0],
]
)

if as_dataarray:

expected = xr.DataArray(expected, dims=("gp_i", "gp_j"))
xr.testing.assert_allclose(expected, result)

else:

np.testing.assert_allclose(result, expected)
45 changes: 0 additions & 45 deletions tests/unit/test_phi_gc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np
import pytest

from mesmer.core.computation import calc_geodist_exact
from mesmer.io import load_phi_gc, load_regs_ls_wgt_lon_lat


Expand Down Expand Up @@ -57,46 +55,3 @@ def test_phi_gc_end_to_end(tmp_path):
]
)
np.testing.assert_allclose(expected, actual[1000], rtol=1e-5)


def test_calc_geodist_exact_shape():

msg = "lon and lat need to be 1D arrays of the same shape"

# not the same shape
with pytest.raises(ValueError, match=msg):
calc_geodist_exact([0, 0], [0])

# not 1D
with pytest.raises(ValueError, match=msg):
calc_geodist_exact([[0, 0]], [[0, 0]])


def test_calc_geodist_exact_equal():
"""test points with distance 0"""

expected = np.array([[0, 0], [0, 0]])

lat = [0, 0]
lons = [[0, 0], [0, 360], [1, 361], [180, -180]]

for lon in lons:
result = calc_geodist_exact(lon, lat)
np.testing.assert_equal(result, expected)

result = calc_geodist_exact(lon, lat)
np.testing.assert_equal(result, expected)


def test_calc_geodist_exact():
"""test some random points"""
result = calc_geodist_exact([-180, 0, 3], [0, 0, 5])
expected = np.array(
[
[0.0, 20003.93145863, 19366.51816487],
[20003.93145863, 0.0, 645.70051988],
[19366.51816487, 645.70051988, 0.0],
]
)

np.testing.assert_allclose(result, expected)