Skip to content

Commit

Permalink
Merge branch 'master' into test_DeprecationWarning
Browse files Browse the repository at this point in the history
  • Loading branch information
kzqureshi committed Apr 16, 2024
2 parents d27e50f + 7e306b4 commit a1f3e9f
Show file tree
Hide file tree
Showing 18 changed files with 176 additions and 182 deletions.
30 changes: 11 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
default_language_version:
python: python3.8

exclude: 'dev'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand All @@ -12,26 +14,16 @@ repos:
# args: [--markdown-linebreak-ext=md]
# exclude: 'discretisedfield/tests/test_sample/.*'

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
hooks:
- id: nbqa-isort # isort inside Jupyter notebooks

- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: [flake8-rst-docstrings] #, flake8-docstrings]

- repo: https://github.com/psf/black
rev: 23.9.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
- id: black-jupyter
# Run the linter.
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]

# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0
Expand Down
17 changes: 8 additions & 9 deletions discretisedfield/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
"""Finite-difference fields."""
import importlib.metadata
import os
import pathlib

import matplotlib.pyplot as plt
import pytest

from . import tools
from .field import Field
from .field_rotator import FieldRotator
from .interact import interact
from .line import Line
from .mesh import Mesh
from .operators import integrate
from .region import Region
from . import tools # noqa: F401
from .field import Field # noqa: F401
from .field_rotator import FieldRotator # noqa: F401
from .interact import interact # noqa: F401
from .line import Line # noqa: F401
from .mesh import Mesh # noqa: F401
from .operators import integrate # noqa: F401
from .region import Region # noqa: F401

# Enable default plotting style.
plt.style.use(pathlib.Path(__file__).parent / "plotting" / "plotting-style.mplstyle")
Expand Down
171 changes: 83 additions & 88 deletions discretisedfield/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def update_field_values(self, value):
.. seealso:: :py:func:`~discretisedfield.Field.array`
"""
self.array = _as_array(value, self.mesh, self.nvdim, dtype=self.dtype)
self.array = self._as_array(value, self.mesh, self.nvdim, dtype=self.dtype)

@property
def vdims(self):
Expand Down Expand Up @@ -457,7 +457,7 @@ def array(self):

@array.setter
def array(self, val):
self._array = _as_array(val, self.mesh, self.nvdim, dtype=self.dtype)
self._array = self._as_array(val, self.mesh, self.nvdim, dtype=self.dtype)

@property
def norm(self):
Expand Down Expand Up @@ -539,7 +539,7 @@ def norm(self, val):
out=np.zeros_like(self.array),
where=self.norm.array != 0.0,
)
self.array *= _as_array(val, self.mesh, nvdim=1, dtype=None)
self.array *= self._as_array(val, self.mesh, nvdim=1, dtype=None)

@property
def valid(self):
Expand All @@ -561,10 +561,10 @@ def valid(self, valid):
valid = ~np.isclose(self.norm.array, 0)
else:
valid = True
# Using _as_array creates an array with shape (*mesh.n, 1).
# Using self._as_array creates an array with shape (*mesh.n, 1).
# We only want a shape of mesh.n so we can directly use it
# to index field.array i.e. field.array[field.valid].
self._valid = _as_array(valid, self.mesh, nvdim=1, dtype=bool)[..., 0]
self._valid = self._as_array(valid, self.mesh, nvdim=1, dtype=bool)[..., 0]

@property
def _valid_as_field(self):
Expand Down Expand Up @@ -3577,7 +3577,7 @@ def _hv_key_dims(self):
key_dims = {
dim: hv_key_dim(coords, unit)
for dim, unit in zip(self.mesh.region.dims, self.mesh.region.units)
if len(coords := getattr(self.mesh.points, dim)) > 1
if len(coords := getattr(self.mesh.cells, dim)) > 1
}
if self.nvdim > 1:
key_dims["vdims"] = hv_key_dim(self.vdims, "")
Expand Down Expand Up @@ -4057,14 +4057,12 @@ def to_xarray(self, name="field", unit=None):
>>> xa = field.to_xarray()
>>> xa
<xarray.DataArray 'field' (x: 10, y: 10, z: 10, vdims: 3)>
...
<xarray.DataArray 'field' (x: 10, y: 10, z: 10, vdims: 3)>...
3. Select values of `x` component
>>> xa.sel(vdims='x')
<xarray.DataArray 'field' (x: 10, y: 10, z: 10)>
...
<xarray.DataArray 'field' (x: 10, y: 10, z: 10)>...
"""
if not isinstance(name, str):
Expand All @@ -4077,7 +4075,7 @@ def to_xarray(self, name="field", unit=None):

axes = self.mesh.region.dims

data_array_coords = {axis: getattr(self.mesh.points, axis) for axis in axes}
data_array_coords = {axis: getattr(self.mesh.cells, axis) for axis in axes}

geo_units_dict = dict(zip(axes, self.mesh.region.units))

Expand Down Expand Up @@ -4186,8 +4184,7 @@ def from_xarray(cls, xa):
... p2=[21., 21., 21.],
... nvdim=3),)
>>> xa
<xarray.DataArray 'mag' (x: 20, y: 20, z: 20, vdims: 3)>
...
<xarray.DataArray 'mag' (x: 20, y: 20, z: 20, vdims: 3)>...
2. Create Field from DataArray
Expand Down Expand Up @@ -4270,53 +4267,87 @@ def from_xarray(cls, xa):
mesh=mesh, nvdim=nvdim, value=val, vdims=vdims, dtype=xa.values.dtype
)

@functools.singledispatchmethod
def _as_array(self, val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")

@functools.singledispatch
def _as_array(val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")


# to avoid str being interpreted as iterable
@_as_array.register(str)
def _(val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")
# to avoid str being interpreted as iterable
@_as_array.register(str)
def _(self, val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")

@_as_array.register(numbers.Complex)
@_as_array.register(collections.abc.Iterable)
def _(self, val, mesh, nvdim, dtype):
if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0:
raise ValueError(
f"Wrong dimension 1 provided for value; expected dimension is {nvdim}"
)

@_as_array.register(numbers.Complex)
@_as_array.register(collections.abc.Iterable)
def _(val, mesh, nvdim, dtype):
if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0:
raise ValueError(
f"Wrong dimension 1 provided for value; expected dimension is {nvdim}"
if isinstance(val, collections.abc.Iterable):
if nvdim == 1 and np.array_equal(np.shape(val), mesh.n):
return np.expand_dims(val, axis=-1)
elif np.shape(val)[-1] != nvdim:
raise ValueError(
f"Wrong dimension {len(val)} provided for value; expected dimension"
f" is {nvdim}."
)
dtype = dtype or max(np.asarray(val).dtype, np.float64)
return np.full((*mesh.n, nvdim), val, dtype=dtype)

@_as_array.register(collections.abc.Callable)
def _(self, val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
array = np.empty((*mesh.n, nvdim), dtype=dtype)
for index, point in zip(mesh.indices, mesh):
# Conversion to array and reshaping is required for numpy >= 1.24
# and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur
# for 1d vector fields.
array[index] = np.asarray(val(point)).reshape(nvdim)
return array

@_as_array.register(dict)
def _(self, val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
dtype = dtype or np.float64
fill_value = (
val["default"]
if "default" in val and not callable(val["default"])
else np.nan
)
array = np.full((*mesh.n, nvdim), fill_value, dtype=dtype)

if isinstance(val, collections.abc.Iterable):
if nvdim == 1 and np.array_equal(np.shape(val), mesh.n):
return np.expand_dims(val, axis=-1)
elif np.shape(val)[-1] != nvdim:
raise ValueError(
f"Wrong dimension {len(val)} provided for value; expected dimension is"
f" {nvdim}."
)
dtype = dtype or max(np.asarray(val).dtype, np.float64)
return np.full((*mesh.n, nvdim), val, dtype=dtype)
for subregion in reversed(mesh.subregions.keys()):
# subregions can overlap, first subregion takes precedence
try:
submesh = mesh[subregion]
subval = val[subregion]
except KeyError:
continue # subregion not in val when implicitly set via "default"
else:
slices = mesh.region2slices(submesh.region)
array[slices] = self._as_array(subval, submesh, nvdim, dtype)

if np.any(np.isnan(array)):
# not all subregion keys specified and 'default' is missing or callable
if "default" not in val:
raise KeyError(
"Key 'default' required if not all subregion keys are specified."
)
subval = val["default"]
for idx in np.argwhere(np.isnan(array[..., 0])):
# only spatial indices required -> array[..., 0]
# conversion to array and reshaping similar to "callable" implementation
array[idx] = np.asarray(subval(mesh.index2point(idx))).reshape(nvdim)

@_as_array.register(collections.abc.Callable)
def _(val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
array = np.empty((*mesh.n, nvdim), dtype=dtype)
for index, point in zip(mesh.indices, mesh):
# Conversion to array and reshaping is required for numpy >= 1.24
# and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur
# for 1d vector fields.
array[index] = np.asarray(val(point)).reshape(nvdim)
return array
return array


@_as_array.register(Field)
def _(val, mesh, nvdim, dtype):
# We cannot register to self (or df.Field) inside the class
@Field._as_array.register(Field)
def _(self, val, mesh, nvdim, dtype):
if mesh.region not in val.mesh.region:
raise ValueError(
f"{val.mesh.region} of the provided field does not "
Expand All @@ -4326,7 +4357,7 @@ def _(val, mesh, nvdim, dtype):
value = (
val.to_xarray()
.sel(
**{dim: getattr(mesh.points, dim) for dim in mesh.region.dims},
**{dim: getattr(mesh.cells, dim) for dim in mesh.region.dims},
method="nearest",
)
.data
Expand All @@ -4335,39 +4366,3 @@ def _(val, mesh, nvdim, dtype):
# xarray dataarrays for scalar data are three dimensional
return value.reshape(*mesh.n, -1)
return value


@_as_array.register(dict)
def _(val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
dtype = dtype or np.float64
fill_value = (
val["default"] if "default" in val and not callable(val["default"]) else np.nan
)
array = np.full((*mesh.n, nvdim), fill_value, dtype=dtype)

for subregion in reversed(mesh.subregions.keys()):
# subregions can overlap, first subregion takes precedence
try:
submesh = mesh[subregion]
subval = val[subregion]
except KeyError:
continue # subregion not in val when implicitly set via "default"
else:
slices = mesh.region2slices(submesh.region)
array[slices] = _as_array(subval, submesh, nvdim, dtype)

if np.any(np.isnan(array)):
# not all subregion keys specified and 'default' is missing or callable
if "default" not in val:
raise KeyError(
"Key 'default' required if not all subregion keys are specified."
)
subval = val["default"]
for idx in np.argwhere(np.isnan(array[..., 0])):
# only spatial indices required -> array[..., 0]
# conversion to array and reshaping similar to "callable" implementation
array[idx] = np.asarray(subval(mesh.index2point(idx))).reshape(nvdim)

return array
1 change: 0 additions & 1 deletion discretisedfield/html/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re

import jinja2
Expand Down
8 changes: 4 additions & 4 deletions discretisedfield/io/ovf.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def _from_ovf(cls, filename):
ovf_v2 = b"2.0" in next(f)
for line in f:
line = line.decode("utf-8")
if line.startswith("# Begin: Data"):
mode = line.split()[3]
if mode == "Binary":
if line.lower().startswith("# begin: data"):
mode = line.split()[3].lower()
if mode == "binary":
nbytes = int(line.split()[-1])
break
information = line[1:].split(":") # remove leading `#`
Expand All @@ -170,7 +170,7 @@ def _from_ovf(cls, filename):
nodes = math.prod(int(header[f"{key}nodes"]) for key in "xyz")

# >>> READ DATA <<<
if mode == "Binary":
if mode == "binary":
# OVF2 uses little-endian and OVF1 uses big-endian
format = f'{"<" if ovf_v2 else ">"}{"d" if nbytes == 8 else "f"}'

Expand Down
Loading

0 comments on commit a1f3e9f

Please sign in to comment.