From eccb24e5bfcbcef8bca207d2087706a68b9312f0 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Fri, 25 Oct 2024 11:40:48 +0200 Subject: [PATCH 01/15] Fixing bug when warning is raised if release_times are NaN --- parcels/particleset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index e6c819c69..c4e10d4e5 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1078,7 +1078,7 @@ def execute( outputdt = output_file.outputdt if output_file else np.inf if isinstance(outputdt, timedelta): outputdt = outputdt.total_seconds() - if outputdt is not None: + if np.isfinite(outputdt): _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"]) if isinstance(callbackdt, timedelta): @@ -1235,7 +1235,7 @@ def execute( def _warn_outputdt_release_desync(outputdt: float, release_times: Iterable[float]): """Gives the user a warning if the release time isn't a multiple of outputdt.""" - if any(t % outputdt != 0 for t in release_times): + if any((np.isfinite(t) and t % outputdt != 0) for t in release_times): warnings.warn( "Some of the particles have a start time that is not a multiple of outputdt. 
" "This could cause the first output to be at a different time than expected.", From c2e082e7bbb4badf32eff29cdc01fa598236da54 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:06:14 +0200 Subject: [PATCH 02/15] Patch raises --- parcels/particleset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index e6c819c69..84ac08b3b 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -191,9 +191,9 @@ def ArrayClass_init(self, *args, **kwargs): self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt if self.repeatdt: if self.repeatdt <= 0: - raise "Repeatdt should be > 0" + raise ValueError("Repeatdt should be > 0") if time[0] and not np.allclose(time, time[0]): - raise "All Particle.time should be the same when repeatdt is not None" + raise ValueError("All Particle.time should be the same when repeatdt is not None") self._repeatpclass = pclass self._repeatkwargs = kwargs self._repeatkwargs.pop("partition_function", None) From 10585455c4b4c7376d7fbb6a9c7722530344c665 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:16:02 +0200 Subject: [PATCH 03/15] Update var names Align with convention, avoid shadowing builtins --- parcels/fieldfilebuffer.py | 6 +++--- parcels/fieldset.py | 4 ++-- parcels/particle.py | 6 +++--- parcels/particledata.py | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py index d1a8d199a..0a42f6f00 100644 --- a/parcels/fieldfilebuffer.py +++ b/parcels/fieldfilebuffer.py @@ -388,7 +388,7 @@ def close(self): self.chunk_mapping = None @classmethod - def add_to_dimension_name_map_global(self, name_map): + def add_to_dimension_name_map_global(cls, name_map): """ [externally callable] This function adds entries to the name map from parcels_dim -> 
netcdf_dim. This is required if you want to @@ -406,9 +406,9 @@ def add_to_dimension_name_map_global(self, name_map): for pcls_dim_name in name_map.keys(): if isinstance(name_map[pcls_dim_name], list): for nc_dim_name in name_map[pcls_dim_name]: - self._static_name_maps[pcls_dim_name].append(nc_dim_name) + cls._static_name_maps[pcls_dim_name].append(nc_dim_name) elif isinstance(name_map[pcls_dim_name], str): - self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name]) + cls._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name]) def add_to_dimension_name_map(self, name_map): """ diff --git a/parcels/fieldset.py b/parcels/fieldset.py index e62ef45a8..aad0635fa 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -344,8 +344,8 @@ def check_velocityfields(U, V, W): @classmethod @deprecated_made_private # TODO: Remove 6 months after v3.1.0 - def parse_wildcards(self, *args, **kwargs): - return self._parse_wildcards(*args, **kwargs) + def parse_wildcards(cls, *args, **kwargs): + return cls._parse_wildcards(*args, **kwargs) @classmethod def _parse_wildcards(cls, paths, filenames, var): diff --git a/parcels/particle.py b/parcels/particle.py index a8f5ead1f..c7f1681e0 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -201,13 +201,13 @@ def __del__(self): def __repr__(self): time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}" - str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth) + p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth) for var in vars(type(self)): if var in ["lon_nextloop", "lat_nextloop", "depth_nextloop", "time_nextloop"]: continue if type(getattr(type(self), var)) is Variable and getattr(type(self), var).to_write is True: - str += f"{var}={getattr(self, var):f}, " - return str + f"time={time_string})" + p_string += f"{var}={getattr(self, var):f}, " + return p_string + f"time={time_string})" @classmethod def 
add_variable(cls, var, *args, **kwargs): diff --git a/parcels/particledata.py b/parcels/particledata.py index 6ccd55061..3eb725d03 100644 --- a/parcels/particledata.py +++ b/parcels/particledata.py @@ -460,7 +460,7 @@ def getPType(self): def __repr__(self): time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}" - str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth) + p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth) for var in self._pcoll.ptype.variables: if var.name in [ "lon_nextloop", @@ -470,8 +470,8 @@ def __repr__(self): ]: # TODO check if time_nextloop is needed (or can work with time-dt?) continue if var.to_write is not False and var.name not in ["id", "lon", "lat", "depth", "time"]: - str += f"{var.name}={getattr(self, var.name):f}, " - return str + f"time={time_string})" + p_string += f"{var.name}={getattr(self, var.name):f}, " + return p_string + f"time={time_string})" def delete(self): """Signal the particle for deletion.""" From 15cb6d2c009775d1742992a38ea395524502acd5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:50:33 +0200 Subject: [PATCH 04/15] add timedelta_to_float helper patch time=None particlefile.write --- parcels/field.py | 7 +++---- parcels/particlefile.py | 7 +++---- parcels/particleset.py | 35 +++++++++++++++++++---------------- parcels/tools/_helpers.py | 12 ++++++++++++ tests/tools/test_helpers.py | 22 +++++++++++++++++++++- 5 files changed, 58 insertions(+), 25 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index ed907747a..9bc155787 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1,5 +1,4 @@ import collections -import datetime import math import warnings from collections.abc import Iterable @@ -21,7 +20,7 @@ assert_valid_gridindexingtype, assert_valid_interp_method, ) -from parcels.tools._helpers import deprecated_made_private +from 
parcels.tools._helpers import deprecated_made_private, timedelta_to_float from parcels.tools.converters import ( Geographic, GeographicPolar, @@ -247,8 +246,8 @@ def __init__( "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object." ) if self.time_periodic is not False: - if isinstance(self.time_periodic, datetime.timedelta): - self.time_periodic = self.time_periodic.total_seconds() + self.time_periodic = timedelta_to_float(self.time_periodic) + if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic): if self.grid.time[-1] - self.grid.time[0] > self.time_periodic: raise ValueError("Time series provided is longer than the time_periodic parameter") diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 31011aadf..3f2aac36a 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -2,7 +2,6 @@ import os import warnings -from datetime import timedelta import numpy as np import xarray as xr @@ -10,7 +9,7 @@ import parcels from parcels._compat import MPI -from parcels.tools._helpers import deprecated, deprecated_made_private +from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float from parcels.tools.warnings import FileWarning __all__ = ["ParticleFile"] @@ -48,7 +47,7 @@ class ParticleFile: """ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True): - self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt + self._outputdt = timedelta_to_float(outputdt) self._chunks = chunks self._particleset = particleset self._parcels_mesh = "spherical" @@ -264,7 +263,7 @@ def write(self, pset, time, indices=None): time : Time at which to write ParticleSet """ - time = time.total_seconds() if isinstance(time, timedelta) else time + time = timedelta_to_float(time) if time is not None else None if pset.particledata._ncount == 0: 
warnings.warn( diff --git a/parcels/particleset.py b/parcels/particleset.py index 84ac08b3b..7b2193c12 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -27,7 +27,7 @@ from parcels.particle import JITParticle, Variable from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile -from parcels.tools._helpers import deprecated, deprecated_made_private +from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array from parcels.tools.global_statics import get_package_dir from parcels.tools.loggers import logger @@ -188,7 +188,8 @@ def ArrayClass_init(self, *args, **kwargs): lon.size == kwargs[kwvar].size ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths." - self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt + self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None + if self.repeatdt: if self.repeatdt <= 0: raise ValueError("Repeatdt should be > 0") @@ -981,13 +982,13 @@ def execute( pyfunc=AdvectionRK4, pyfunc_inter=None, endtime=None, - runtime=None, - dt=1.0, + runtime: float | timedelta | np.timedelta64 | None = None, + dt: float | timedelta | np.timedelta64 = 1.0, output_file=None, verbose_progress=True, postIterationCallbacks=None, - callbackdt=None, - delete_cfiles=True, + callbackdt: float | timedelta | np.timedelta64 | None = None, + delete_cfiles: bool = True, ): """Execute a given kernel function over the particle set for multiple timesteps. 
@@ -1067,22 +1068,24 @@ def execute( if self.time_origin.calendar is None: raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double") endtime = self.time_origin.reltime(endtime) - if isinstance(runtime, timedelta): - runtime = runtime.total_seconds() - if isinstance(dt, timedelta): - dt = dt.total_seconds() + + if runtime is not None: + runtime = timedelta_to_float(runtime) + + dt = timedelta_to_float(dt) + if abs(dt) <= 1e-6: raise ValueError("Time step dt is too small") if (dt * 1e6) % 1 != 0: raise ValueError("Output interval should not have finer precision than 1e-6 s") - outputdt = output_file.outputdt if output_file else np.inf - if isinstance(outputdt, timedelta): - outputdt = outputdt.total_seconds() + + outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf + if outputdt is not None: _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"]) - - if isinstance(callbackdt, timedelta): - callbackdt = callbackdt.total_seconds() + + if callbackdt is not None: + callbackdt = timedelta_to_float(callbackdt) assert runtime is None or runtime >= 0, "runtime must be positive" assert outputdt is None or outputdt >= 0, "outputdt must be positive" diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py index 381338f0e..24c8ba9c4 100644 --- a/parcels/tools/_helpers.py +++ b/parcels/tools/_helpers.py @@ -3,6 +3,9 @@ import functools import warnings from collections.abc import Callable +from datetime import timedelta + +import numpy as np PACKAGE = "Parcels" @@ -56,3 +59,12 @@ def deprecated_made_private(func: Callable) -> Callable: def patch_docstring(obj: Callable, extra: str) -> None: obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip() + + +def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: + """Convert a timedelta to a float in seconds.""" + if isinstance(dt, timedelta): + return dt.total_seconds() + if isinstance(dt, np.timedelta64): + return 
float(dt / np.timedelta64(1, "s")) + return float(dt) diff --git a/tests/tools/test_helpers.py b/tests/tools/test_helpers.py index c7510c21b..3e6f5a992 100644 --- a/tests/tools/test_helpers.py +++ b/tests/tools/test_helpers.py @@ -1,6 +1,9 @@ +from datetime import timedelta + +import numpy as np import pytest -from parcels.tools._helpers import deprecated, deprecated_made_private +from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float def test_deprecated(): @@ -53,3 +56,20 @@ def some_function(x, y): some_function(1, 2) assert "deprecated::" in some_function.__doc__ + + +@pytest.mark.parametrize( + "input, expected", + [ + (timedelta(days=1), 24 * 60 * 60), + (np.timedelta64(1, "D"), 24 * 60 * 60), + (3600.0, 3600.0), + ], +) +def test_timedelta_to_float(input, expected): + assert timedelta_to_float(input) == expected + + +def test_timedelta_to_float_exceptions(): + with pytest.raises((ValueError, TypeError)): + timedelta_to_float("invalid_type") From 11885acaa05de46f7041ccb06430ad41d6552c7b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:07:31 +0200 Subject: [PATCH 05/15] type annotations --- parcels/field.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 9bc155787..086881173 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from ctypes import POINTER, Structure, c_float, c_int, pointer from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import dask.array as da import numpy as np @@ -149,6 +149,8 @@ class Field: * `Nested Fields <../examples/tutorial_NestedFields.ipynb>`__ """ + _cast_data_dtype: type[np.float32] | type[np.float64] + def __init__( self, name: str | tuple[str, str], @@ -161,16 +163,16 @@ def __init__( mesh: Mesh = "flat", timestamps=None, 
fieldtype=None, - transpose=False, - vmin=None, - vmax=None, - cast_data_dtype="float32", - time_origin=None, + transpose: bool = False, + vmin: float | None = None, + vmax: float | None = None, + cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32", + time_origin: TimeConverter | None = None, interp_method: InterpMethod = "linear", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, gridindexingtype: GridIndexingType = "nemo", - to_write=False, + to_write: bool = False, **kwargs, ): if kwargs.get("netcdf_decodewarning") is not None: @@ -257,11 +259,19 @@ def __init__( self.vmin = vmin self.vmax = vmax - self._cast_data_dtype = cast_data_dtype - if self.cast_data_dtype == "float32": - self._cast_data_dtype = np.float32 - elif self.cast_data_dtype == "float64": - self._cast_data_dtype = np.float64 + + match cast_data_dtype: + case "float32": + self._cast_data_dtype = np.float32 + case "float64": + self._cast_data_dtype = np.float64 + case _: + self._cast_data_dtype = cast_data_dtype + + if self.cast_data_dtype not in [np.float32, np.float64]: + raise ValueError( + f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. 
Choose either: 'float32' or 'float64'" + ) if not self.grid.defer_load: self.data = self._reshape(self.data, transpose) From 65472a1182a7073e856adaf4e40180d99fff7de5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:26:13 +0200 Subject: [PATCH 06/15] TimeConverter testing --- parcels/particlefile.py | 3 +- parcels/tools/_helpers.py | 1 + tests/tools/test_converters.py | 67 ++++++++++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 3f2aac36a..752cea7a2 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -2,6 +2,7 @@ import os import warnings +from datetime import timedelta import numpy as np import xarray as xr @@ -252,7 +253,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis): Z.append(a, axis=axis) zarr.consolidate_metadata(store) - def write(self, pset, time, indices=None): + def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None): """Write all data from one time step to the zarr file, before the particle locations are updated. 
diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py index 24c8ba9c4..9284110b4 100644 --- a/parcels/tools/_helpers.py +++ b/parcels/tools/_helpers.py @@ -61,6 +61,7 @@ def patch_docstring(obj: Callable, extra: str) -> None: obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip() +@np.vectorize def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: """Convert a timedelta to a float in seconds.""" if isinstance(dt, timedelta): diff --git a/tests/tools/test_converters.py b/tests/tools/test_converters.py index f8de9f8fa..03d5f0178 100644 --- a/tests/tools/test_converters.py +++ b/tests/tools/test_converters.py @@ -1,14 +1,65 @@ import cftime import numpy as np +import pytest from parcels.tools.converters import TimeConverter, _get_cftime_datetimes +cf_datetime_classes = [getattr(cftime, c) for c in _get_cftime_datetimes()] +cf_datetime_objects = [c(1990, 1, 1) for c in cf_datetime_classes] -def test_TimeConverter(): - cf_datetime_names = _get_cftime_datetimes() - for cf_datetime in cf_datetime_names: - date = getattr(cftime, cf_datetime)(1990, 1, 1) - assert TimeConverter(date).calendar == date.calendar - assert TimeConverter(None).calendar is None - date_datetime64 = np.datetime64("2001-01-01T12:00") - assert TimeConverter(date_datetime64).calendar == "np_datetime64" + +@pytest.mark.parametrize( + "cf_datetime", + cf_datetime_objects, +) +def test_TimeConverter_cf(cf_datetime): + assert TimeConverter(cf_datetime).calendar == cf_datetime.calendar + assert TimeConverter(cf_datetime).time_origin == cf_datetime + + +def test_TimeConverter_standard(): + dt = np.datetime64("2001-01-01T12:00") + assert TimeConverter(dt).calendar == "np_datetime64" + assert TimeConverter(dt).time_origin == dt + + dt = np.timedelta64(1, "s") + assert TimeConverter(dt).calendar == "np_timedelta64" + assert TimeConverter(dt).time_origin == dt + + assert TimeConverter(0).calendar is None + assert TimeConverter(0).time_origin == 0 + + +def test_TimeConverter_invalid(): 
+ with pytest.raises(ValueError): + TimeConverter("invalid") + + +def test_TimeConverter_reltime_one_day(): + ONE_DAY = 24 * 60 * 60 + first_jan = [c(1990, 1, 1) for c in cf_datetime_classes] + [0] + second_jan = [c(1990, 1, 2) for c in cf_datetime_classes] + [ONE_DAY] + + for time_origin, time in zip(first_jan, second_jan, strict=True): + tc = TimeConverter(time_origin) + assert tc.reltime(time) == ONE_DAY + + with pytest.raises(ValueError): + tc.reltime("invalid") + + +@pytest.mark.parametrize( + "x, y", + [ + (np.datetime64("2001-01-01T12:00"), 0), + (np.timedelta64(1, "s"), 0), + (cftime.DatetimeNoLeap(1990, 1, 1), 0), + (cftime.DatetimeNoLeap(1990, 1, 1), cftime.DatetimeAllLeap(1991, 1, 1)), + (0, 0), + ], +) +def test_TimeConverter_reltime_errors(x, y): + """All of these should raise a ValueError when doing reltime""" + tc = TimeConverter(x) + with pytest.raises(ValueError, match="Cannot subtract 'time'"): + tc.reltime(y) From 7607aec55813b8bce1a5a28aa6789b2fa263dc42 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Oct 2024 18:06:31 +0200 Subject: [PATCH 07/15] annotations timeconverter --- parcels/tools/converters.py | 20 ++++++++++---------- tests/tools/test_converters.py | 16 ++++++---------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index 3911bd676..6335f0f9e 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -1,4 +1,5 @@ -# flake8: noqa: E999 +from __future__ import annotations + import inspect from datetime import timedelta from math import cos, pi @@ -32,14 +33,14 @@ def convert_to_flat_array(var: npt.ArrayLike) -> npt.NDArray: return np.array(var).flatten() -def _get_cftime_datetimes(): +def _get_cftime_datetimes() -> list[str]: # Is there a more elegant way to parse these from cftime? 
cftime_calendars = tuple(x[1].__name__ for x in inspect.getmembers(cftime._cftime, inspect.isclass)) cftime_datetime_names = [ca for ca in cftime_calendars if "Datetime" in ca] return cftime_datetime_names -def _get_cftime_calendars(): +def _get_cftime_calendars() -> list[str]: return [getattr(cftime, cf_datetime)(1990, 1, 1).calendar for cf_datetime in _get_cftime_datetimes()] @@ -48,22 +49,21 @@ class TimeConverter: Parameters ---------- - time_origin : float, integer, numpy.datetime64 or netcdftime.DatetimeNoLeap + time_origin : float, integer, numpy.datetime64 or cftime.DatetimeNoLeap time origin of the class. """ - def __init__(self, time_origin=0): - self.time_origin = 0 if time_origin is None else time_origin + def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0): + self.time_origin = time_origin + self.calendar: str | None = None if isinstance(time_origin, np.datetime64): self.calendar = "np_datetime64" elif isinstance(time_origin, np.timedelta64): self.calendar = "np_timedelta64" - elif isinstance(time_origin, cftime._cftime.datetime): + elif isinstance(time_origin, cftime.datetime): self.calendar = time_origin.calendar - else: - self.calendar = None - def reltime(self, time): + def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float: """Method to compute the difference, in seconds, between a time and the time_origin of the TimeConverter diff --git a/tests/tools/test_converters.py b/tests/tools/test_converters.py index 03d5f0178..fc82067f9 100644 --- a/tests/tools/test_converters.py +++ b/tests/tools/test_converters.py @@ -31,7 +31,7 @@ def test_TimeConverter_standard(): def test_TimeConverter_invalid(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): TimeConverter("invalid") @@ -44,22 +44,18 @@ def test_TimeConverter_reltime_one_day(): tc = TimeConverter(time_origin) assert tc.reltime(time) == ONE_DAY - with pytest.raises(ValueError): - 
tc.reltime("invalid") - @pytest.mark.parametrize( "x, y", [ - (np.datetime64("2001-01-01T12:00"), 0), - (np.timedelta64(1, "s"), 0), - (cftime.DatetimeNoLeap(1990, 1, 1), 0), - (cftime.DatetimeNoLeap(1990, 1, 1), cftime.DatetimeAllLeap(1991, 1, 1)), - (0, 0), + pytest.param(np.datetime64("2001-01-01T12:00"), 0, id="datetime64 float"), + pytest.param(np.timedelta64(1, "s"), 0, id="timedelta64 float"), + pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), 0, id="cftime float"), + pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), cftime.DatetimeAllLeap(1991, 1, 1), id="cftime cftime"), ], ) def test_TimeConverter_reltime_errors(x, y): """All of these should raise a ValueError when doing reltime""" tc = TimeConverter(x) - with pytest.raises(ValueError, match="Cannot subtract 'time'"): + with pytest.raises((ValueError, TypeError)): tc.reltime(y) From 2499ffb54e114cb68dc197745ca7407a335359ce Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 22 Oct 2024 15:20:57 +0200 Subject: [PATCH 08/15] Remove xfail from test Fixes #1737 remove xfail test flag (fixes #1737) --- docs/examples/example_globcurrent.py | 12 +++++------- tests/test_grids.py | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/examples/example_globcurrent.py b/docs/examples/example_globcurrent.py index d4929c01a..b4049256b 100755 --- a/docs/examples/example_globcurrent.py +++ b/docs/examples/example_globcurrent.py @@ -219,23 +219,21 @@ def test__particles_init_time(): assert pset[0].time - pset4[0].time == 0 -@pytest.mark.xfail(reason="Time extrapolation error expected to be thrown", strict=True) @pytest.mark.parametrize("mode", ["scipy", "jit"]) @pytest.mark.parametrize("use_xarray", [True, False]) def test_globcurrent_time_extrapolation_error(mode, use_xarray): fieldset = set_globcurrent_fieldset(use_xarray=use_xarray) - pset = parcels.ParticleSet( fieldset, pclass=ptype[mode], lon=[25], lat=[-35], - time=fieldset.U.time[0] - 
timedelta(days=1).total_seconds(), - ) - - pset.execute( - parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5) + time=fieldset.U.grid.time[0] - timedelta(days=1).total_seconds(), ) + with pytest.raises(parcels.TimeExtrapolationError): + pset.execute( + parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5) + ) @pytest.mark.parametrize("mode", ["scipy", "jit"]) diff --git a/tests/test_grids.py b/tests/test_grids.py index 2d036dd35..5e336af11 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -104,12 +104,12 @@ def sampleTemp(particle, fieldset, time): assert np.all([np.isclose(p.temp0, p.temp1, atol=1e-3) for p in pset]) -@pytest.mark.xfail(reason="Grid cannot be computed using a time vector which is neither float nor int", strict=True) def test_time_format_in_grid(): lon = np.linspace(0, 1, 2, dtype=np.float32) lat = np.linspace(0, 1, 2, dtype=np.float32) time = np.array([np.datetime64("2000-01-01")] * 2) - RectilinearZGrid(lon, lat, time=time) + with pytest.raises(AssertionError, match="Time vector"): + RectilinearZGrid(lon, lat, time=time) def test_avoid_repeated_grids(): From f00fbc43e9cb3177d2bb299b80eda956abfb8e09 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:54:14 +0200 Subject: [PATCH 09/15] remove dead code contributes to #1620 --- parcels/kernel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parcels/kernel.py b/parcels/kernel.py index 3cf7c26b4..d21592716 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -77,7 +77,6 @@ def __init__( self.funccode = funccode self.py_ast = py_ast self.dyn_srcs = [] - self.static_srcs = [] self.src_file = None self.lib_file = None self.log_file = None From 52680be995a976f0c5aa76e0f80291327b51be5d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 23 Oct 2024 17:25:02 +0200 Subject: [PATCH 10/15] refactor cleanup_remove_files contributes to #1620 --- 
parcels/kernel.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/parcels/kernel.py b/parcels/kernel.py index d21592716..0d9cf4401 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -561,9 +561,11 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs): def cleanup_remove_files(lib_file, all_files_array, delete_cfiles): if lib_file is not None: if os.path.isfile(lib_file): # and delete_cfiles - [os.remove(s) for s in [lib_file] if os.path is not None and os.path.exists(s)] - if delete_cfiles and len(all_files_array) > 0: - [os.remove(s) for s in all_files_array if os.path is not None and os.path.exists(s)] + os.remove(lib_file) + if delete_cfiles: + for s in all_files_array: + if os.path.exists(s): + os.remove(s) @staticmethod def cleanup_unload_lib(lib): @@ -600,8 +602,7 @@ def load_fieldset_jit(self, pset): g._load_chunk == g._chunk_loading_requested, g._chunk_loaded_touched, g._load_chunk ) if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel - if not g._load_chunk.flags["C_CONTIGUOUS"]: - g._load_chunk = np.array(g._load_chunk, order="C") + g._load_chunk = np.array(g._load_chunk, order="C") if not g.depth.flags.c_contiguous: g._depth = np.array(g.depth, order="C") if not g.lon.flags.c_contiguous: From 4a5d5f59db1adf5a0b2703bb2c3f4f9f5f622e17 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 23 Oct 2024 17:41:14 +0200 Subject: [PATCH 11/15] patch unit tests --- tests/tools/test_converters.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/tools/test_converters.py b/tests/tools/test_converters.py index fc82067f9..3d35363b1 100644 --- a/tests/tools/test_converters.py +++ b/tests/tools/test_converters.py @@ -30,11 +30,6 @@ def test_TimeConverter_standard(): assert TimeConverter(0).time_origin == 0 -def test_TimeConverter_invalid(): - with pytest.raises(TypeError): - TimeConverter("invalid") - - def 
test_TimeConverter_reltime_one_day(): ONE_DAY = 24 * 60 * 60 first_jan = [c(1990, 1, 1) for c in cf_datetime_classes] + [0] @@ -49,7 +44,6 @@ def test_TimeConverter_reltime_one_day(): "x, y", [ pytest.param(np.datetime64("2001-01-01T12:00"), 0, id="datetime64 float"), - pytest.param(np.timedelta64(1, "s"), 0, id="timedelta64 float"), pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), 0, id="cftime float"), pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), cftime.DatetimeAllLeap(1991, 1, 1), id="cftime cftime"), ], From 338fe9a0df2e9525a01f2e4815b4db7554fce1c5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:01:44 +0200 Subject: [PATCH 12/15] review feedback --- parcels/kernel.py | 3 ++- parcels/tools/_helpers.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/kernel.py b/parcels/kernel.py index 0d9cf4401..46395ad43 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -602,7 +602,8 @@ def load_fieldset_jit(self, pset): g._load_chunk == g._chunk_loading_requested, g._chunk_loaded_touched, g._load_chunk ) if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel - g._load_chunk = np.array(g._load_chunk, order="C") + if not g._load_chunk.flags["C_CONTIGUOUS"]: + g._load_chunk = np.array(g._load_chunk, order="C") if not g.depth.flags.c_contiguous: g._depth = np.array(g.depth, order="C") if not g.lon.flags.c_contiguous: diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py index 9284110b4..24c8ba9c4 100644 --- a/parcels/tools/_helpers.py +++ b/parcels/tools/_helpers.py @@ -61,7 +61,6 @@ def patch_docstring(obj: Callable, extra: str) -> None: obj.__doc__ = f"{obj.__doc__ or ''}{extra}".strip() -@np.vectorize def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float: """Convert a timedelta to a float in seconds.""" if isinstance(dt, timedelta): From 69b109458e53390c4d2ab9c5dad3ad040bd1dc2f Mon Sep 17 00:00:00 2001 
From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 28 Oct 2024 08:50:03 +0100 Subject: [PATCH 13/15] Add type ignore comments Postpone typing to a future refactoring (particularly for reltime) --- parcels/field.py | 2 +- parcels/particlefile.py | 10 +++++----- parcels/tools/converters.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 086881173..c7f5d2eac 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -806,7 +806,7 @@ def from_xarray( lat = da[dimensions["lat"]].values time_origin = TimeConverter(time[0]) - time = time_origin.reltime(time) + time = time_origin.reltime(time) # type: ignore[assignment] grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) kwargs["time_periodic"] = time_periodic diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 752cea7a2..c8482b9e6 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -295,7 +295,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N if self.create_new_zarrfile: if self.chunks is None: self._chunks = (len(ids), 1) - if pset._repeatpclass is not None and self.chunks[0] < 1e4: + if pset._repeatpclass is not None and self.chunks[0] < 1e4: # type: ignore[index] warnings.warn( f"ParticleFile chunks are set to {self.chunks}, but this may lead to " f"a significant slowdown in Parcels when many calls to repeatdt. 
" @@ -303,10 +303,10 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N FileWarning, stacklevel=2, ) - if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): - arrsize = (self._maxids, self.chunks[1]) + if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index] + arrsize = (self._maxids, self.chunks[1]) # type: ignore[index] else: - arrsize = (len(ids), self.chunks[1]) + arrsize = (len(ids), self.chunks[1]) # type: ignore[index] ds = xr.Dataset( attrs=self.metadata, coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, @@ -331,7 +331,7 @@ def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=N data[ids, 0] = pset.particledata.getvardata(var, indices_to_write) dims = ["trajectory", "obs"] ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout]) - ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks + ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index] ds.to_zarr(self.fname, mode="w") self._create_new_zarrfile = False else: diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index 6335f0f9e..3f323b93c 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -63,7 +63,7 @@ def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime. elif isinstance(time_origin, cftime.datetime): self.calendar = time_origin.calendar - def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float: + def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float | npt.NDArray: """Method to compute the difference, in seconds, between a time and the time_origin of the TimeConverter @@ -80,11 +80,11 @@ def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime. 
""" time = time.time_origin if isinstance(time, TimeConverter) else time if self.calendar in ["np_datetime64", "np_timedelta64"]: - return (time - self.time_origin) / np.timedelta64(1, "s") + return (time - self.time_origin) / np.timedelta64(1, "s") # type: ignore elif self.calendar in _get_cftime_calendars(): if isinstance(time, (list, np.ndarray)): try: - return np.array([(t - self.time_origin).total_seconds() for t in time]) + return np.array([(t - self.time_origin).total_seconds() for t in time]) # type: ignore except ValueError: raise ValueError( f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n" @@ -92,14 +92,14 @@ def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime. ) else: try: - return (time - self.time_origin).total_seconds() + return (time - self.time_origin).total_seconds() # type: ignore except ValueError: raise ValueError( f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n" f"Provide 'time' as a {type(self.time_origin)} object?" 
) elif self.calendar is None: - return time - self.time_origin + return time - self.time_origin # type: ignore else: raise RuntimeError(f"Calendar {self.calendar} not implemented in TimeConverter") From d96644bac98a049abf9bc9b8d1a7957b727b6a5f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 07:54:17 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- parcels/particleset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 7b2193c12..ead99ffc0 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1083,7 +1083,7 @@ def execute( if outputdt is not None: _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"]) - + if callbackdt is not None: callbackdt = timedelta_to_float(callbackdt) From 64bda71e9f0137b44ca81da7536739e38b1f165d Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 28 Oct 2024 09:02:58 +0100 Subject: [PATCH 15/15] Update parcels/particleset.py Co-authored-by: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> --- parcels/particleset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index c4e10d4e5..706126066 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1078,7 +1078,7 @@ def execute( outputdt = output_file.outputdt if output_file else np.inf if isinstance(outputdt, timedelta): outputdt = outputdt.total_seconds() - if np.isfinite(outputdt) is not None: + if np.isfinite(outputdt): _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"]) if isinstance(callbackdt, timedelta):