diff --git a/docs/examples/example_globcurrent.py b/docs/examples/example_globcurrent.py
index d4929c01a..b4049256b 100755
--- a/docs/examples/example_globcurrent.py
+++ b/docs/examples/example_globcurrent.py
@@ -219,23 +219,21 @@ def test__particles_init_time():
     assert pset[0].time - pset4[0].time == 0
 
 
-@pytest.mark.xfail(reason="Time extrapolation error expected to be thrown", strict=True)
 @pytest.mark.parametrize("mode", ["scipy", "jit"])
 @pytest.mark.parametrize("use_xarray", [True, False])
 def test_globcurrent_time_extrapolation_error(mode, use_xarray):
     fieldset = set_globcurrent_fieldset(use_xarray=use_xarray)
-
     pset = parcels.ParticleSet(
         fieldset,
         pclass=ptype[mode],
         lon=[25],
         lat=[-35],
-        time=fieldset.U.time[0] - timedelta(days=1).total_seconds(),
-    )
-
-    pset.execute(
-        parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
+        time=fieldset.U.grid.time[0] - timedelta(days=1).total_seconds(),
     )
+    with pytest.raises(parcels.TimeExtrapolationError):
+        pset.execute(
+            parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
+        )
 
 
 @pytest.mark.parametrize("mode", ["scipy", "jit"])
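
Note: the change above replaces a strict xfail marker with an explicit pytest.raises check. The distinction matters: a strict xfail passes as long as anything in the test fails, whereas pytest.raises pins the expected exception to the single call that should throw, so broken setup still fails the test. A minimal sketch of the pattern (the test body is a placeholder, not Parcels code):

    import pytest

    def test_expected_error_pattern():
        denominator = 0  # setup runs and is checked as usual
        with pytest.raises(ZeroDivisionError):
            1 / denominator  # only this call is required (and allowed) to raise
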
diff --git a/parcels/field.py b/parcels/field.py
index 267337c8b..a8a79dfc6 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -1,11 +1,10 @@
 import collections
-import datetime
 import math
 import warnings
 from collections.abc import Iterable
 from ctypes import POINTER, Structure, c_float, c_int, pointer
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import dask.array as da
 import numpy as np
@@ -21,7 +20,7 @@
     assert_valid_gridindexingtype,
     assert_valid_interp_method,
 )
-from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr
+from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr, timedelta_to_float
 from parcels.tools.converters import (
     Geographic,
     GeographicPolar,
@@ -152,6 +151,7 @@ class Field:
 
     allow_time_extrapolation: bool
     time_periodic: TimePeriodic
+    _cast_data_dtype: type[np.float32] | type[np.float64]
 
     def __init__(
         self,
@@ -165,16 +165,16 @@ def __init__(
         mesh: Mesh = "flat",
         timestamps=None,
         fieldtype=None,
-        transpose=False,
-        vmin=None,
-        vmax=None,
-        cast_data_dtype="float32",
-        time_origin=None,
+        transpose: bool = False,
+        vmin: float | None = None,
+        vmax: float | None = None,
+        cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
+        time_origin: TimeConverter | None = None,
         interp_method: InterpMethod = "linear",
         allow_time_extrapolation: bool | None = None,
         time_periodic: TimePeriodic = False,
         gridindexingtype: GridIndexingType = "nemo",
-        to_write=False,
+        to_write: bool = False,
         **kwargs,
     ):
         if kwargs.get("netcdf_decodewarning") is not None:
@@ -250,8 +250,8 @@ def __init__(
                 "Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object."
             )
         if self.time_periodic is not False:
-            if isinstance(self.time_periodic, datetime.timedelta):
-                self.time_periodic = self.time_periodic.total_seconds()
+            self.time_periodic = timedelta_to_float(self.time_periodic)
+
             if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic):
                 if self.grid.time[-1] - self.grid.time[0] > self.time_periodic:
                     raise ValueError("Time series provided is longer than the time_periodic parameter")
@@ -261,11 +261,19 @@ def __init__(
 
         self.vmin = vmin
         self.vmax = vmax
-        self._cast_data_dtype = cast_data_dtype
-        if self.cast_data_dtype == "float32":
-            self._cast_data_dtype = np.float32
-        elif self.cast_data_dtype == "float64":
-            self._cast_data_dtype = np.float64
+
+        match cast_data_dtype:
+            case "float32":
+                self._cast_data_dtype = np.float32
+            case "float64":
+                self._cast_data_dtype = np.float64
+            case _:
+                self._cast_data_dtype = cast_data_dtype
+
+        if self.cast_data_dtype not in [np.float32, np.float64]:
+            raise ValueError(
+                f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
+            )
 
         if not self.grid.defer_load:
             self.data = self._reshape(self.data, transpose)
@@ -803,7 +811,7 @@ def from_xarray(
         lat = da[dimensions["lat"]].values
         time_origin = TimeConverter(time[0])
-        time = time_origin.reltime(time)
+        time = time_origin.reltime(time)  # type: ignore[assignment]
 
         grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
         kwargs["time_periodic"] = time_periodic
 
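
Note: Field.__init__ now normalises cast_data_dtype through a match statement and rejects anything that does not resolve to np.float32 or np.float64. A self-contained sketch of that normalisation (normalize_cast_data_dtype is a hypothetical name; in the diff the logic is inline):

    import numpy as np

    def normalize_cast_data_dtype(cast_data_dtype):
        # String aliases map to numpy scalar types; any other value is passed
        # through unchanged and must already be np.float32 or np.float64.
        match cast_data_dtype:
            case "float32":
                dtype = np.float32
            case "float64":
                dtype = np.float64
            case _:
                dtype = cast_data_dtype
        if dtype not in [np.float32, np.float64]:
            raise ValueError(f"Unsupported cast_data_dtype {dtype!r}. Choose either: 'float32' or 'float64'")
        return dtype

    assert normalize_cast_data_dtype("float64") is np.float64
    assert normalize_cast_data_dtype(np.float32) is np.float32
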
diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py
index d1a8d199a..0a42f6f00 100644
--- a/parcels/fieldfilebuffer.py
+++ b/parcels/fieldfilebuffer.py
@@ -388,7 +388,7 @@ def close(self):
         self.chunk_mapping = None
 
     @classmethod
-    def add_to_dimension_name_map_global(self, name_map):
+    def add_to_dimension_name_map_global(cls, name_map):
         """
         [externally callable] This function adds entries to the name map from parcels_dim -> netcdf_dim.
         This is required if you want to
@@ -406,9 +406,9 @@ def add_to_dimension_name_map_global(self, name_map):
         for pcls_dim_name in name_map.keys():
             if isinstance(name_map[pcls_dim_name], list):
                 for nc_dim_name in name_map[pcls_dim_name]:
-                    self._static_name_maps[pcls_dim_name].append(nc_dim_name)
+                    cls._static_name_maps[pcls_dim_name].append(nc_dim_name)
             elif isinstance(name_map[pcls_dim_name], str):
-                self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
+                cls._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
 
     def add_to_dimension_name_map(self, name_map):
         """
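
Note: the self -> cls rename in add_to_dimension_name_map_global (and in fieldset.py below) is behaviour-preserving: the first argument of a @classmethod is always the class object, so the old code worked but read as if an instance were involved. A quick illustration (Registry is a made-up class):

    class Registry:
        _entries = []  # shared, class-level state, like _static_name_maps

        @classmethod
        def add(cls, item):
            # cls is the Registry class itself, never an instance,
            # which is why the conventional parameter name is cls.
            cls._entries.append(item)

    Registry.add("x")  # callable on the class; no instance required
    assert Registry._entries == ["x"]
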
diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index 4f01f291c..9a7fd39da 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -347,8 +347,8 @@ def check_velocityfields(U, V, W):
 
     @classmethod
     @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
-    def parse_wildcards(self, *args, **kwargs):
-        return self._parse_wildcards(*args, **kwargs)
+    def parse_wildcards(cls, *args, **kwargs):
+        return cls._parse_wildcards(*args, **kwargs)
 
     @classmethod
     def _parse_wildcards(cls, paths, filenames, var):
diff --git a/parcels/kernel.py b/parcels/kernel.py
index 3cf7c26b4..46395ad43 100644
--- a/parcels/kernel.py
+++ b/parcels/kernel.py
@@ -77,7 +77,6 @@ def __init__(
         self.funccode = funccode
         self.py_ast = py_ast
         self.dyn_srcs = []
-        self.static_srcs = []
         self.src_file = None
         self.lib_file = None
         self.log_file = None
@@ -562,9 +561,11 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
     def cleanup_remove_files(lib_file, all_files_array, delete_cfiles):
         if lib_file is not None:
             if os.path.isfile(lib_file):  # and delete_cfiles
-                [os.remove(s) for s in [lib_file] if os.path is not None and os.path.exists(s)]
-            if delete_cfiles and len(all_files_array) > 0:
-                [os.remove(s) for s in all_files_array if os.path is not None and os.path.exists(s)]
+                os.remove(lib_file)
+            if delete_cfiles:
+                for s in all_files_array:
+                    if os.path.exists(s):
+                        os.remove(s)
 
     @staticmethod
     def cleanup_unload_lib(lib):
diff --git a/parcels/particle.py b/parcels/particle.py
index 728e0deaa..87dce269f 100644
--- a/parcels/particle.py
+++ b/parcels/particle.py
@@ -201,13 +201,13 @@ def __del__(self):
 
     def __repr__(self):
         time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
-        str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
+        p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
         for var in vars(type(self)):
             if var in ["lon_nextloop", "lat_nextloop", "depth_nextloop", "time_nextloop"]:
                 continue
             if type(getattr(type(self), var)) is Variable and getattr(type(self), var).to_write is True:
-                str += f"{var}={getattr(self, var):f}, "
-        return str + f"time={time_string})"
+                p_string += f"{var}={getattr(self, var):f}, "
+        return p_string + f"time={time_string})"
 
     @classmethod
     def add_variable(cls, var, *args, **kwargs):
diff --git a/parcels/particledata.py b/parcels/particledata.py
index 6ccd55061..3eb725d03 100644
--- a/parcels/particledata.py
+++ b/parcels/particledata.py
@@ -460,7 +460,7 @@ def getPType(self):
 
     def __repr__(self):
         time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
-        str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
+        p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
         for var in self._pcoll.ptype.variables:
             if var.name in [
                 "lon_nextloop",
@@ -470,8 +470,8 @@ def __repr__(self):
             ]:  # TODO check if time_nextloop is needed (or can work with time-dt?)
                 continue
             if var.to_write is not False and var.name not in ["id", "lon", "lat", "depth", "time"]:
-                str += f"{var.name}={getattr(self, var.name):f}, "
-        return str + f"time={time_string})"
+                p_string += f"{var.name}={getattr(self, var.name):f}, "
+        return p_string + f"time={time_string})"
 
     def delete(self):
         """Signal the particle for deletion."""
diff --git a/parcels/particlefile.py b/parcels/particlefile.py
index bf2226d47..4ab406b4a 100644
--- a/parcels/particlefile.py
+++ b/parcels/particlefile.py
@@ -10,7 +10,7 @@
 import parcels
 from parcels._compat import MPI
-from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private
+from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private, timedelta_to_float
 from parcels.tools.warnings import FileWarning
 
 __all__ = ["ParticleFile"]
 
@@ -48,7 +48,7 @@ class ParticleFile:
     """
 
    def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
-        self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
+        self._outputdt = timedelta_to_float(outputdt)
         self._chunks = chunks
         self._particleset = particleset
         self._parcels_mesh = "spherical"
@@ -263,7 +263,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis):
             Z.append(a, axis=axis)
         zarr.consolidate_metadata(store)
 
-    def write(self, pset, time, indices=None):
+    def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None):
         """Write all data from one time step to the zarr file,
         before the particle locations are updated.
 
@@ -274,7 +274,7 @@ def write(self, pset, time, indices=None):
         time :
             Time at which to write ParticleSet
         """
-        time = time.total_seconds() if isinstance(time, timedelta) else time
+        time = timedelta_to_float(time) if time is not None else None
 
         if pset.particledata._ncount == 0:
             warnings.warn(
@@ -305,7 +305,7 @@ def write(self, pset, time, indices=None):
         if self.create_new_zarrfile:
             if self.chunks is None:
                 self._chunks = (len(ids), 1)
-            if pset._repeatpclass is not None and self.chunks[0] < 1e4:
+            if pset._repeatpclass is not None and self.chunks[0] < 1e4:  # type: ignore[index]
                 warnings.warn(
                     f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
                     f"a significant slowdown in Parcels when many calls to repeatdt. "
@@ -313,10 +313,10 @@ def write(self, pset, time, indices=None):
                     FileWarning,
                     stacklevel=2,
                 )
-            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
-                arrsize = (self._maxids, self.chunks[1])
+            if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):  # type: ignore[index]
+                arrsize = (self._maxids, self.chunks[1])  # type: ignore[index]
             else:
-                arrsize = (len(ids), self.chunks[1])
+                arrsize = (len(ids), self.chunks[1])  # type: ignore[index]
             ds = xr.Dataset(
                 attrs=self.metadata,
                 coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
@@ -341,7 +341,7 @@ def write(self, pset, time, indices=None):
                     data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
                     dims = ["trajectory", "obs"]
                     ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
-                    ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
+                    ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks  # type: ignore[index]
             ds.to_zarr(self.fname, mode="w")
             self._create_new_zarrfile = False
         else:
diff --git a/parcels/particleset.py b/parcels/particleset.py
index 207b851be..ffe7eb652 100644
--- a/parcels/particleset.py
+++ b/parcels/particleset.py
@@ -27,7 +27,7 @@
 from parcels.particle import JITParticle, Variable
 from parcels.particledata import ParticleData, ParticleDataIterator
 from parcels.particlefile import ParticleFile
-from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr
+from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr, timedelta_to_float
 from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
 from parcels.tools.global_statics import get_package_dir
 from parcels.tools.loggers import logger
@@ -189,12 +189,13 @@ def ArrayClass_init(self, *args, **kwargs):
                 lon.size == kwargs[kwvar].size
             ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
 
-        self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt
+        self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None
+
         if self.repeatdt:
             if self.repeatdt <= 0:
-                raise "Repeatdt should be > 0"
+                raise ValueError("Repeatdt should be > 0")
             if time[0] and not np.allclose(time, time[0]):
-                raise "All Particle.time should be the same when repeatdt is not None"
+                raise ValueError("All Particle.time should be the same when repeatdt is not None")
             self._repeatpclass = pclass
             self._repeatkwargs = kwargs
             self._repeatkwargs.pop("partition_function", None)
@@ -986,13 +987,13 @@ def execute(
         pyfunc=AdvectionRK4,
         pyfunc_inter=None,
         endtime=None,
-        runtime=None,
-        dt=1.0,
+        runtime: float | timedelta | np.timedelta64 | None = None,
+        dt: float | timedelta | np.timedelta64 = 1.0,
         output_file=None,
         verbose_progress=True,
         postIterationCallbacks=None,
-        callbackdt=None,
-        delete_cfiles=True,
+        callbackdt: float | timedelta | np.timedelta64 | None = None,
+        delete_cfiles: bool = True,
     ):
         """Execute a given kernel function over the particle set for multiple timesteps.
 
@@ -1072,22 +1073,23 @@ def execute(
             if self.time_origin.calendar is None:
                 raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double")
             endtime = self.time_origin.reltime(endtime)
-        if isinstance(runtime, timedelta):
-            runtime = runtime.total_seconds()
-        if isinstance(dt, timedelta):
-            dt = dt.total_seconds()
+
+        if runtime is not None:
+            runtime = timedelta_to_float(runtime)
+
+        dt = timedelta_to_float(dt)
+
         if abs(dt) <= 1e-6:
             raise ValueError("Time step dt is too small")
         if (dt * 1e6) % 1 != 0:
             raise ValueError("Output interval should not have finer precision than 1e-6 s")
 
-        outputdt = output_file.outputdt if output_file else np.inf
-        if isinstance(outputdt, timedelta):
-            outputdt = outputdt.total_seconds()
-        if outputdt is not None:
+        outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf
+
+        if np.isfinite(outputdt):
             _warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"])
-        if isinstance(callbackdt, timedelta):
-            callbackdt = callbackdt.total_seconds()
+        if callbackdt is not None:
+            callbackdt = timedelta_to_float(callbackdt)
 
         assert runtime is None or runtime >= 0, "runtime must be positive"
         assert outputdt is None or outputdt >= 0, "outputdt must be positive"
@@ -1240,7 +1242,7 @@ def execute(
 
 def _warn_outputdt_release_desync(outputdt: float, release_times: Iterable[float]):
     """Gives the user a warning if the release time isn't a multiple of outputdt."""
-    if any(t % outputdt != 0 for t in release_times):
+    if any((np.isfinite(t) and t % outputdt != 0) for t in release_times):
         warnings.warn(
             "Some of the particles have a start time that is not a multiple of outputdt. "
             "This could cause the first output to be at a different time than expected.",
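
Note: two small but real fixes here. First, in Python 3 raising a bare string is itself a TypeError, so the old `raise "Repeatdt should be > 0"` could never surface its intended message. Second, a modulo involving a non-finite operand never compares equal to zero, so without the new np.isfinite guards the release-time warning fired spuriously. Both are quick to verify:

    import numpy as np

    try:
        raise "Repeatdt should be > 0"  # what the old code effectively did
    except TypeError as e:
        print(e)  # exceptions must derive from BaseException

    print(5.0 % np.inf)            # 5.0 -> t % outputdt == t when outputdt is inf
    print(np.inf % 3600.0)         # nan -> a non-finite release time
    print((np.inf % 3600.0) != 0)  # True -> would have warned before the guards
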
diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py
index 89a299cde..9078690cc 100644
--- a/parcels/tools/_helpers.py
+++ b/parcels/tools/_helpers.py
@@ -6,8 +6,11 @@
 import textwrap
 import warnings
 from collections.abc import Callable
+from datetime import timedelta
 from typing import TYPE_CHECKING, Any
 
+import numpy as np
+
 if TYPE_CHECKING:
     from parcels import Field, FieldSet, ParticleSet
 
@@ -134,3 +137,12 @@ def fieldset_repr(fieldset: FieldSet) -> str:
 
 def default_repr(obj: Any):
     return object.__repr__(obj)
+
+
+def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
+    """Convert a timedelta to a float in seconds."""
+    if isinstance(dt, timedelta):
+        return dt.total_seconds()
+    if isinstance(dt, np.timedelta64):
+        return float(dt / np.timedelta64(1, "s"))
+    return float(dt)
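
Note: timedelta_to_float is the single conversion point for the three duration types now accepted throughout. Usage (expected values in the comments follow directly from the definition above; the module is private, which is why only Parcels internals and its tests import it):

    from datetime import timedelta
    import numpy as np
    from parcels.tools._helpers import timedelta_to_float

    print(timedelta_to_float(timedelta(hours=1)))       # 3600.0
    print(timedelta_to_float(np.timedelta64(90, "m")))  # 5400.0
    print(timedelta_to_float(42))                       # 42.0; plain numbers pass through
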
diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py
index 3911bd676..3f323b93c 100644
--- a/parcels/tools/converters.py
+++ b/parcels/tools/converters.py
@@ -1,4 +1,5 @@
-# flake8: noqa: E999
+from __future__ import annotations
+
 import inspect
 from datetime import timedelta
 from math import cos, pi
@@ -32,14 +33,14 @@ def convert_to_flat_array(var: npt.ArrayLike) -> npt.NDArray:
     return np.array(var).flatten()
 
 
-def _get_cftime_datetimes():
+def _get_cftime_datetimes() -> list[str]:
     # Is there a more elegant way to parse these from cftime?
     cftime_calendars = tuple(x[1].__name__ for x in inspect.getmembers(cftime._cftime, inspect.isclass))
     cftime_datetime_names = [ca for ca in cftime_calendars if "Datetime" in ca]
     return cftime_datetime_names
 
 
-def _get_cftime_calendars():
+def _get_cftime_calendars() -> list[str]:
     return [getattr(cftime, cf_datetime)(1990, 1, 1).calendar for cf_datetime in _get_cftime_datetimes()]
 
 
@@ -48,22 +49,21 @@ class TimeConverter:
 
     Parameters
     ----------
-    time_origin : float, integer, numpy.datetime64 or netcdftime.DatetimeNoLeap
+    time_origin : float, integer, numpy.datetime64 or cftime.DatetimeNoLeap
         time origin of the class.
     """
 
-    def __init__(self, time_origin=0):
-        self.time_origin = 0 if time_origin is None else time_origin
+    def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0):
+        self.time_origin = time_origin
+        self.calendar: str | None = None
         if isinstance(time_origin, np.datetime64):
             self.calendar = "np_datetime64"
         elif isinstance(time_origin, np.timedelta64):
             self.calendar = "np_timedelta64"
-        elif isinstance(time_origin, cftime._cftime.datetime):
+        elif isinstance(time_origin, cftime.datetime):
             self.calendar = time_origin.calendar
-        else:
-            self.calendar = None
 
-    def reltime(self, time):
+    def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float | npt.NDArray:
         """Method to compute the difference, in seconds, between a time and the time_origin
         of the TimeConverter
 
@@ -80,11 +80,11 @@ def reltime(self, time):
         """
         time = time.time_origin if isinstance(time, TimeConverter) else time
         if self.calendar in ["np_datetime64", "np_timedelta64"]:
-            return (time - self.time_origin) / np.timedelta64(1, "s")
+            return (time - self.time_origin) / np.timedelta64(1, "s")  # type: ignore
         elif self.calendar in _get_cftime_calendars():
             if isinstance(time, (list, np.ndarray)):
                 try:
-                    return np.array([(t - self.time_origin).total_seconds() for t in time])
+                    return np.array([(t - self.time_origin).total_seconds() for t in time])  # type: ignore
                 except ValueError:
                     raise ValueError(
                         f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n"
@@ -92,14 +92,14 @@ def reltime(self, time):
                     )
             else:
                 try:
-                    return (time - self.time_origin).total_seconds()
+                    return (time - self.time_origin).total_seconds()  # type: ignore
                 except ValueError:
                     raise ValueError(
                         f"Cannot subtract 'time' (a {type(time)} object) from a {self.calendar} calendar.\n"
                         f"Provide 'time' as a {type(self.time_origin)} object?"
                     )
         elif self.calendar is None:
-            return time - self.time_origin
+            return time - self.time_origin  # type: ignore
         else:
             raise RuntimeError(f"Calendar {self.calendar} not implemented in TimeConverter")
 
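
Note: the TimeConverter changes are behaviour-preserving (calendar is initialised up front, and the public cftime.datetime replaces the private cftime._cftime.datetime); the annotations document what reltime accepts and returns. A small sketch mirroring the new tests below:

    import numpy as np
    from parcels.tools.converters import TimeConverter

    tc = TimeConverter(np.datetime64("2000-01-01"))
    print(tc.calendar)                              # "np_datetime64"
    print(tc.reltime(np.datetime64("2000-01-02")))  # 86400.0 (seconds)

    tc0 = TimeConverter(0)      # no calendar: plain relative seconds
    print(tc0.calendar)         # None
    print(tc0.reltime(3600.0))  # 3600.0
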
diff --git a/tests/test_grids.py b/tests/test_grids.py
index 2d036dd35..5e336af11 100644
--- a/tests/test_grids.py
+++ b/tests/test_grids.py
@@ -104,12 +104,12 @@ def sampleTemp(particle, fieldset, time):
     assert np.all([np.isclose(p.temp0, p.temp1, atol=1e-3) for p in pset])
 
 
-@pytest.mark.xfail(reason="Grid cannot be computed using a time vector which is neither float nor int", strict=True)
 def test_time_format_in_grid():
     lon = np.linspace(0, 1, 2, dtype=np.float32)
     lat = np.linspace(0, 1, 2, dtype=np.float32)
     time = np.array([np.datetime64("2000-01-01")] * 2)
-    RectilinearZGrid(lon, lat, time=time)
+    with pytest.raises(AssertionError, match="Time vector"):
+        RectilinearZGrid(lon, lat, time=time)
 
 
 def test_avoid_repeated_grids():
diff --git a/tests/tools/test_converters.py b/tests/tools/test_converters.py
index f8de9f8fa..3d35363b1 100644
--- a/tests/tools/test_converters.py
+++ b/tests/tools/test_converters.py
@@ -1,14 +1,55 @@
 import cftime
 import numpy as np
+import pytest
 
 from parcels.tools.converters import TimeConverter, _get_cftime_datetimes
 
+cf_datetime_classes = [getattr(cftime, c) for c in _get_cftime_datetimes()]
+cf_datetime_objects = [c(1990, 1, 1) for c in cf_datetime_classes]
 
-def test_TimeConverter():
-    cf_datetime_names = _get_cftime_datetimes()
-    for cf_datetime in cf_datetime_names:
-        date = getattr(cftime, cf_datetime)(1990, 1, 1)
-        assert TimeConverter(date).calendar == date.calendar
-    assert TimeConverter(None).calendar is None
-    date_datetime64 = np.datetime64("2001-01-01T12:00")
-    assert TimeConverter(date_datetime64).calendar == "np_datetime64"
+
+@pytest.mark.parametrize(
+    "cf_datetime",
+    cf_datetime_objects,
+)
+def test_TimeConverter_cf(cf_datetime):
+    assert TimeConverter(cf_datetime).calendar == cf_datetime.calendar
+    assert TimeConverter(cf_datetime).time_origin == cf_datetime
+
+
+def test_TimeConverter_standard():
+    dt = np.datetime64("2001-01-01T12:00")
+    assert TimeConverter(dt).calendar == "np_datetime64"
+    assert TimeConverter(dt).time_origin == dt
+
+    dt = np.timedelta64(1, "s")
+    assert TimeConverter(dt).calendar == "np_timedelta64"
+    assert TimeConverter(dt).time_origin == dt
+
+    assert TimeConverter(0).calendar is None
+    assert TimeConverter(0).time_origin == 0
+
+
+def test_TimeConverter_reltime_one_day():
+    ONE_DAY = 24 * 60 * 60
+    first_jan = [c(1990, 1, 1) for c in cf_datetime_classes] + [0]
+    second_jan = [c(1990, 1, 2) for c in cf_datetime_classes] + [ONE_DAY]
+
+    for time_origin, time in zip(first_jan, second_jan, strict=True):
+        tc = TimeConverter(time_origin)
+        assert tc.reltime(time) == ONE_DAY
+
+
+@pytest.mark.parametrize(
+    "x, y",
+    [
+        pytest.param(np.datetime64("2001-01-01T12:00"), 0, id="datetime64 float"),
+        pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), 0, id="cftime float"),
+        pytest.param(cftime.DatetimeNoLeap(1990, 1, 1), cftime.DatetimeAllLeap(1991, 1, 1), id="cftime cftime"),
+    ],
+)
+def test_TimeConverter_reltime_errors(x, y):
+    """All of these should raise a ValueError or TypeError when doing reltime"""
+    tc = TimeConverter(x)
+    with pytest.raises((ValueError, TypeError)):
+        tc.reltime(y)
diff --git a/tests/tools/test_helpers.py b/tests/tools/test_helpers.py
index c3499b55a..1403b679f 100644
--- a/tests/tools/test_helpers.py
+++ b/tests/tools/test_helpers.py
@@ -1,7 +1,10 @@
+from datetime import timedelta
+
+import numpy as np
 import pytest
 
 import parcels.tools._helpers as helpers
-from parcels.tools._helpers import deprecated, deprecated_made_private
+from parcels.tools._helpers import deprecated, deprecated_made_private, timedelta_to_float
 
 
 def test_format_list_items_multiline():
@@ -64,3 +67,20 @@ def some_function(x, y):
 
     some_function(1, 2)
     assert "deprecated::" in some_function.__doc__
+
+
+@pytest.mark.parametrize(
+    "input, expected",
+    [
+        (timedelta(days=1), 24 * 60 * 60),
+        (np.timedelta64(1, "D"), 24 * 60 * 60),
+        (3600.0, 3600.0),
+    ],
+)
+def test_timedelta_to_float(input, expected):
+    assert timedelta_to_float(input) == expected
+
+
+def test_timedelta_to_float_exceptions():
+    with pytest.raises((ValueError, TypeError)):
+        timedelta_to_float("invalid_type")