Skip to content

Commit

Permalink
Update fit_gev dask options and edit test_eva docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
stellema committed Nov 16, 2023
1 parent 75ca7cf commit 59f7371
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
10 changes: 5 additions & 5 deletions unseen/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ def check_gev_fit(data, theta, time_dim="time"):
"""

def _goodness_of_fit(data, theta):
assert len(theta) in [3, 5]
if len(theta) == 3:
# Stationary parameters
shape, loc, scale = theta
test_data = data
else:
elif len(theta) == 5:
# Non-stationary parameters (test middle 20% of data).
# N.B. the non stationary parameters vary with the data distribution as a
# function of the covariate, so we test a subset of data at a specific point.
Expand All @@ -193,10 +194,9 @@ def _goodness_of_fit(data, theta):
data,
theta,
input_core_dims=[[time_dim], ["theta"]],
output_core_dims=[[]],
vectorize=True,
dask="allowed",
dask_gufunc_kwargs=dict(output_dtypes="float64", meta=(float)),
dask="parallelized",
dask_gufunc_kwargs=dict(meta=(np.ndarray(1, float),)),
)
return pvalue

Expand Down Expand Up @@ -432,7 +432,7 @@ def fit(data, **kwargs):
input_core_dims=[[time_dim]],
output_core_dims=[["theta"]],
vectorize=True,
dask="allowed",
dask="parallelized",
kwargs=kwargs,
output_dtypes=["float64"],
dask_gufunc_kwargs=dict(output_sizes={"theta": n}),
Expand Down
19 changes: 8 additions & 11 deletions unseen/tests/test_eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

rtol = 0.3 # relative tolerance
alpha = 0.05
np.random.seed(1)


def example_da_gev_1d():
"""An example 1D GEV DataArray and its distribution parameters."""
time = np.arange(
"2000-01-01", "2005-01-01", np.timedelta64(1, "D"), dtype="datetime64[ns]"
"2000-01-01", "2003-01-01", np.timedelta64(1, "D"), dtype="datetime64[ns]"
)

# Generate shape, location and scale parameters.
np.random.seed(0)
shape = np.random.uniform()
loc = np.random.uniform(-10, 10)
scale = np.random.uniform(0.1, 10)
Expand All @@ -38,12 +38,13 @@ def example_da_gev_1d_dask():
def example_da_gev_3d():
"""An example 3D GEV DataArray and its distribution parameters."""
time = np.arange(
"2000-01-01", "2005-01-01", np.timedelta64(1, "D"), dtype="datetime64[ns]"
"2000-01-01", "2003-01-01", np.timedelta64(1, "D"), dtype="datetime64[ns]"
)
lats = np.arange(0, 2)
lons = np.arange(0, 2)
shape = (len(lats), len(lons))

np.random.seed(0)
c = np.random.rand(*shape)
loc = np.random.rand(*shape) + np.random.randint(-10, 10, shape)
scale = np.random.rand(*shape)
Expand Down Expand Up @@ -73,10 +74,6 @@ def add_example_gev_trend(data):
return data + trend


data, theta_i = example_da_gev_3d()
theta = fit_gev(data, stationary=True, check_fit=False)


def test_fit_gev_1d():
"""Run stationary fit using 1D array & check results."""
data, theta_i = example_da_gev_1d()
Expand Down Expand Up @@ -110,7 +107,7 @@ def test_fit_gev_3d_dask():


def test_fit_ns_gev_1d():
"""Run stationary fit using 1D array & check results."""
"""Run non-stationary fit using 1D array & check results."""
data, _ = example_da_gev_1d()
data = add_example_gev_trend(data)

Expand All @@ -124,7 +121,7 @@ def test_fit_ns_gev_1d():


def test_fit_ns_gev_1d_dask():
"""Run stationary fit using 1D dask array & check results."""
"""Run non-stationary fit using 1D dask array & check results."""
data, _ = example_da_gev_1d_dask()

# Add a positive linear trend.
Expand All @@ -138,7 +135,7 @@ def test_fit_ns_gev_1d_dask():


def test_fit_ns_gev_3d():
"""Run stationary fit using 3D array & check results."""
"""Run non-stationary fit using 3D array & check results."""
data, _ = example_da_gev_3d()

# Add a positive linear trend.
Expand All @@ -152,7 +149,7 @@ def test_fit_ns_gev_3d():


def test_fit_ns_gev_3d_dask():
"""Run stationary fit using 3D dask array & check results."""
"""Run non-stationary fit using 3D dask array & check results."""
data, _ = example_da_gev_3d_dask()

# Add a positive linear trend.
Expand Down

0 comments on commit 59f7371

Please sign in to comment.