Skip to content

Commit

Permalink
Merge branch 'master' into better-prints
Browse files Browse the repository at this point in the history
  • Loading branch information
VeckoTheGecko committed Oct 28, 2024
2 parents 02d36fc + bbe8448 commit d4ea6a2
Show file tree
Hide file tree
Showing 14 changed files with 174 additions and 92 deletions.
12 changes: 5 additions & 7 deletions docs/examples/example_globcurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,21 @@ def test__particles_init_time():
assert pset[0].time - pset4[0].time == 0


@pytest.mark.xfail(reason="Time extrapolation error expected to be thrown", strict=True)
@pytest.mark.parametrize("mode", ["scipy", "jit"])
@pytest.mark.parametrize("use_xarray", [True, False])
def test_globcurrent_time_extrapolation_error(mode, use_xarray):
fieldset = set_globcurrent_fieldset(use_xarray=use_xarray)

pset = parcels.ParticleSet(
fieldset,
pclass=ptype[mode],
lon=[25],
lat=[-35],
time=fieldset.U.time[0] - timedelta(days=1).total_seconds(),
)

pset.execute(
parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
time=fieldset.U.grid.time[0] - timedelta(days=1).total_seconds(),
)
with pytest.raises(parcels.TimeExtrapolationError):
pset.execute(
parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
)


@pytest.mark.parametrize("mode", ["scipy", "jit"])
Expand Down
42 changes: 25 additions & 17 deletions parcels/field.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import collections
import datetime
import math
import warnings
from collections.abc import Iterable
from ctypes import POINTER, Structure, c_float, c_int, pointer
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import dask.array as da
import numpy as np
Expand All @@ -21,7 +20,7 @@
assert_valid_gridindexingtype,
assert_valid_interp_method,
)
from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr
from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr, timedelta_to_float
from parcels.tools.converters import (
Geographic,
GeographicPolar,
Expand Down Expand Up @@ -152,6 +151,7 @@ class Field:

allow_time_extrapolation: bool
time_periodic: TimePeriodic
_cast_data_dtype: type[np.float32] | type[np.float64]

def __init__(
self,
Expand All @@ -165,16 +165,16 @@ def __init__(
mesh: Mesh = "flat",
timestamps=None,
fieldtype=None,
transpose=False,
vmin=None,
vmax=None,
cast_data_dtype="float32",
time_origin=None,
transpose: bool = False,
vmin: float | None = None,
vmax: float | None = None,
cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
time_origin: TimeConverter | None = None,
interp_method: InterpMethod = "linear",
allow_time_extrapolation: bool | None = None,
time_periodic: TimePeriodic = False,
gridindexingtype: GridIndexingType = "nemo",
to_write=False,
to_write: bool = False,
**kwargs,
):
if kwargs.get("netcdf_decodewarning") is not None:
Expand Down Expand Up @@ -250,8 +250,8 @@ def __init__(
"Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object."
)
if self.time_periodic is not False:
if isinstance(self.time_periodic, datetime.timedelta):
self.time_periodic = self.time_periodic.total_seconds()
self.time_periodic = timedelta_to_float(self.time_periodic)

if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic):
if self.grid.time[-1] - self.grid.time[0] > self.time_periodic:
raise ValueError("Time series provided is longer than the time_periodic parameter")
Expand All @@ -261,11 +261,19 @@ def __init__(

self.vmin = vmin
self.vmax = vmax
self._cast_data_dtype = cast_data_dtype
if self.cast_data_dtype == "float32":
self._cast_data_dtype = np.float32
elif self.cast_data_dtype == "float64":
self._cast_data_dtype = np.float64

match cast_data_dtype:
case "float32":
self._cast_data_dtype = np.float32
case "float64":
self._cast_data_dtype = np.float64
case _:
self._cast_data_dtype = cast_data_dtype

if self.cast_data_dtype not in [np.float32, np.float64]:
raise ValueError(
f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
)

if not self.grid.defer_load:
self.data = self._reshape(self.data, transpose)
Expand Down Expand Up @@ -803,7 +811,7 @@ def from_xarray(
lat = da[dimensions["lat"]].values

time_origin = TimeConverter(time[0])
time = time_origin.reltime(time)
time = time_origin.reltime(time) # type: ignore[assignment]

grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
kwargs["time_periodic"] = time_periodic
Expand Down
6 changes: 3 additions & 3 deletions parcels/fieldfilebuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def close(self):
self.chunk_mapping = None

@classmethod
def add_to_dimension_name_map_global(self, name_map):
def add_to_dimension_name_map_global(cls, name_map):
"""
[externally callable]
This function adds entries to the name map from parcels_dim -> netcdf_dim. This is required if you want to
Expand All @@ -406,9 +406,9 @@ def add_to_dimension_name_map_global(self, name_map):
for pcls_dim_name in name_map.keys():
if isinstance(name_map[pcls_dim_name], list):
for nc_dim_name in name_map[pcls_dim_name]:
self._static_name_maps[pcls_dim_name].append(nc_dim_name)
cls._static_name_maps[pcls_dim_name].append(nc_dim_name)
elif isinstance(name_map[pcls_dim_name], str):
self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
cls._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])

def add_to_dimension_name_map(self, name_map):
"""
Expand Down
4 changes: 2 additions & 2 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def check_velocityfields(U, V, W):

@classmethod
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def parse_wildcards(self, *args, **kwargs):
return self._parse_wildcards(*args, **kwargs)
def parse_wildcards(cls, *args, **kwargs):
return cls._parse_wildcards(*args, **kwargs)

@classmethod
def _parse_wildcards(cls, paths, filenames, var):
Expand Down
9 changes: 5 additions & 4 deletions parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(
self.funccode = funccode
self.py_ast = py_ast
self.dyn_srcs = []
self.static_srcs = []
self.src_file = None
self.lib_file = None
self.log_file = None
Expand Down Expand Up @@ -562,9 +561,11 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
def cleanup_remove_files(lib_file, all_files_array, delete_cfiles):
if lib_file is not None:
if os.path.isfile(lib_file): # and delete_cfiles
[os.remove(s) for s in [lib_file] if os.path is not None and os.path.exists(s)]
if delete_cfiles and len(all_files_array) > 0:
[os.remove(s) for s in all_files_array if os.path is not None and os.path.exists(s)]
os.remove(lib_file)
if delete_cfiles:
for s in all_files_array:
if os.path.exists(s):
os.remove(s)

@staticmethod
def cleanup_unload_lib(lib):
Expand Down
6 changes: 3 additions & 3 deletions parcels/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ def __del__(self):

def __repr__(self):
time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
for var in vars(type(self)):
if var in ["lon_nextloop", "lat_nextloop", "depth_nextloop", "time_nextloop"]:
continue
if type(getattr(type(self), var)) is Variable and getattr(type(self), var).to_write is True:
str += f"{var}={getattr(self, var):f}, "
return str + f"time={time_string})"
p_string += f"{var}={getattr(self, var):f}, "
return p_string + f"time={time_string})"

@classmethod
def add_variable(cls, var, *args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def getPType(self):

def __repr__(self):
time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
for var in self._pcoll.ptype.variables:
if var.name in [
"lon_nextloop",
Expand All @@ -470,8 +470,8 @@ def __repr__(self):
]: # TODO check if time_nextloop is needed (or can work with time-dt?)
continue
if var.to_write is not False and var.name not in ["id", "lon", "lat", "depth", "time"]:
str += f"{var.name}={getattr(self, var.name):f}, "
return str + f"time={time_string})"
p_string += f"{var.name}={getattr(self, var.name):f}, "
return p_string + f"time={time_string})"

def delete(self):
"""Signal the particle for deletion."""
Expand Down
18 changes: 9 additions & 9 deletions parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import parcels
from parcels._compat import MPI
from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private
from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private, timedelta_to_float
from parcels.tools.warnings import FileWarning

__all__ = ["ParticleFile"]
Expand Down Expand Up @@ -48,7 +48,7 @@ class ParticleFile:
"""

def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
self._outputdt = timedelta_to_float(outputdt)
self._chunks = chunks
self._particleset = particleset
self._parcels_mesh = "spherical"
Expand Down Expand Up @@ -263,7 +263,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis):
Z.append(a, axis=axis)
zarr.consolidate_metadata(store)

def write(self, pset, time, indices=None):
def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None):
"""Write all data from one time step to the zarr file,
before the particle locations are updated.
Expand All @@ -274,7 +274,7 @@ def write(self, pset, time, indices=None):
time :
Time at which to write ParticleSet
"""
time = time.total_seconds() if isinstance(time, timedelta) else time
time = timedelta_to_float(time) if time is not None else None

if pset.particledata._ncount == 0:
warnings.warn(
Expand Down Expand Up @@ -305,18 +305,18 @@ def write(self, pset, time, indices=None):
if self.create_new_zarrfile:
if self.chunks is None:
self._chunks = (len(ids), 1)
if pset._repeatpclass is not None and self.chunks[0] < 1e4:
if pset._repeatpclass is not None and self.chunks[0] < 1e4: # type: ignore[index]
warnings.warn(
f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
f"a significant slowdown in Parcels when many calls to repeatdt. "
f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
FileWarning,
stacklevel=2,
)
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
arrsize = (self._maxids, self.chunks[1])
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
else:
arrsize = (len(ids), self.chunks[1])
arrsize = (len(ids), self.chunks[1]) # type: ignore[index]
ds = xr.Dataset(
attrs=self.metadata,
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
Expand All @@ -341,7 +341,7 @@ def write(self, pset, time, indices=None):
data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
dims = ["trajectory", "obs"]
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index]
ds.to_zarr(self.fname, mode="w")
self._create_new_zarrfile = False
else:
Expand Down
40 changes: 21 additions & 19 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from parcels.particle import JITParticle, Variable
from parcels.particledata import ParticleData, ParticleDataIterator
from parcels.particlefile import ParticleFile
from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr
from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr, timedelta_to_float
from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
from parcels.tools.global_statics import get_package_dir
from parcels.tools.loggers import logger
Expand Down Expand Up @@ -189,12 +189,13 @@ def ArrayClass_init(self, *args, **kwargs):
lon.size == kwargs[kwvar].size
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."

self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt
self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None

if self.repeatdt:
if self.repeatdt <= 0:
raise "Repeatdt should be > 0"
raise ValueError("Repeatdt should be > 0")
if time[0] and not np.allclose(time, time[0]):
raise "All Particle.time should be the same when repeatdt is not None"
raise ValueError("All Particle.time should be the same when repeatdt is not None")
self._repeatpclass = pclass
self._repeatkwargs = kwargs
self._repeatkwargs.pop("partition_function", None)
Expand Down Expand Up @@ -986,13 +987,13 @@ def execute(
pyfunc=AdvectionRK4,
pyfunc_inter=None,
endtime=None,
runtime=None,
dt=1.0,
runtime: float | timedelta | np.timedelta64 | None = None,
dt: float | timedelta | np.timedelta64 = 1.0,
output_file=None,
verbose_progress=True,
postIterationCallbacks=None,
callbackdt=None,
delete_cfiles=True,
callbackdt: float | timedelta | np.timedelta64 | None = None,
delete_cfiles: bool = True,
):
"""Execute a given kernel function over the particle set for multiple timesteps.
Expand Down Expand Up @@ -1072,22 +1073,23 @@ def execute(
if self.time_origin.calendar is None:
raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double")
endtime = self.time_origin.reltime(endtime)
if isinstance(runtime, timedelta):
runtime = runtime.total_seconds()
if isinstance(dt, timedelta):
dt = dt.total_seconds()

if runtime is not None:
runtime = timedelta_to_float(runtime)

dt = timedelta_to_float(dt)

if abs(dt) <= 1e-6:
raise ValueError("Time step dt is too small")
if (dt * 1e6) % 1 != 0:
raise ValueError("Output interval should not have finer precision than 1e-6 s")
outputdt = output_file.outputdt if output_file else np.inf
if isinstance(outputdt, timedelta):
outputdt = outputdt.total_seconds()
if outputdt is not None:
outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf

if np.isfinite(outputdt):
_warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"])

if isinstance(callbackdt, timedelta):
callbackdt = callbackdt.total_seconds()
if callbackdt is not None:
callbackdt = timedelta_to_float(callbackdt)

assert runtime is None or runtime >= 0, "runtime must be positive"
assert outputdt is None or outputdt >= 0, "outputdt must be positive"
Expand Down Expand Up @@ -1240,7 +1242,7 @@ def execute(

def _warn_outputdt_release_desync(outputdt: float, release_times: Iterable[float]):
"""Gives the user a warning if the release time isn't a multiple of outputdt."""
if any(t % outputdt != 0 for t in release_times):
if any((np.isfinite(t) and t % outputdt != 0) for t in release_times):
warnings.warn(
"Some of the particles have a start time that is not a multiple of outputdt. "
"This could cause the first output to be at a different time than expected.",
Expand Down
12 changes: 12 additions & 0 deletions parcels/tools/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import textwrap
import warnings
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
from parcels import Field, FieldSet, ParticleSet

Expand Down Expand Up @@ -134,3 +137,12 @@ def fieldset_repr(fieldset: FieldSet) -> str:

def default_repr(obj: Any):
return object.__repr__(obj)


def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
"""Convert a timedelta to a float in seconds."""
if isinstance(dt, timedelta):
return dt.total_seconds()
if isinstance(dt, np.timedelta64):
return float(dt / np.timedelta64(1, "s"))
return float(dt)
Loading

0 comments on commit d4ea6a2

Please sign in to comment.