Skip to content

Commit

Permalink
refactor(mesh): rename points -> cells
Browse files Browse the repository at this point in the history
  • Loading branch information
swapneelap committed Nov 17, 2023
1 parent a828938 commit 52ca76c
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 27 deletions.
6 changes: 3 additions & 3 deletions discretisedfield/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3565,7 +3565,7 @@ def _hv_key_dims(self):
key_dims = {
dim: hv_key_dim(coords, unit)
for dim, unit in zip(self.mesh.region.dims, self.mesh.region.units)
if len(coords := getattr(self.mesh.points, dim)) > 1
if len(coords := getattr(self.mesh.cells, dim)) > 1
}
if self.nvdim > 1:
key_dims["vdims"] = hv_key_dim(self.vdims, "")
Expand Down Expand Up @@ -4065,7 +4065,7 @@ def to_xarray(self, name="field", unit=None):

axes = self.mesh.region.dims

data_array_coords = {axis: getattr(self.mesh.points, axis) for axis in axes}
data_array_coords = {axis: getattr(self.mesh.cells, axis) for axis in axes}

geo_units_dict = dict(zip(axes, self.mesh.region.units))

Expand Down Expand Up @@ -4314,7 +4314,7 @@ def _(val, mesh, nvdim, dtype):
value = (
val.to_xarray()
.sel(
**{dim: getattr(mesh.points, dim) for dim in mesh.region.dims},
**{dim: getattr(mesh.cells, dim) for dim in mesh.region.dims},
method="nearest",
)
.data
Expand Down
12 changes: 6 additions & 6 deletions discretisedfield/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def __iter__(self):
yield from map(self.index2point, self.indices)

@property
def points(self):
def cells(self):
"""Midpoints of the cells of the mesh along the three directions.
This method returns a named tuple containing three numpy arrays with
Expand All @@ -512,13 +512,13 @@ def points(self):
>>> cell = (2, 1, 1)
>>> mesh = df.Mesh(region=df.Region(p1=p1, p2=p2), cell=cell)
...
>>> mesh.points.x
>>> mesh.cells.x
array([1., 3., 5., 7., 9.])
"""
points = collections.namedtuple("points", self.region.dims)
cells = collections.namedtuple("cells", self.region.dims)

return points(
return cells(
*(
np.linspace(pmin + cell / 2, pmax - cell / 2, n)
for pmin, pmax, cell, n in zip(
Expand Down Expand Up @@ -2161,8 +2161,8 @@ def coordinate_field(self):
vdim_mapping=dict(zip(self.region.dims, self.region.dims)),
)
for i, dim in enumerate(self.region.dims):
points = self.points # avoid re-computing points
field.array[..., i] = getattr(points, dim).reshape(
cells = self.cells # avoid re-computing cells
field.array[..., i] = getattr(cells, dim).reshape(
tuple(self.n[i] if i == j else 1 for j in range(self.region.ndim))
)

Expand Down
8 changes: 4 additions & 4 deletions discretisedfield/plotting/mpl_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,8 @@ def vector(

multiplier = self._setup_multiplier(multiplier)

points1 = self.field.mesh.points[0] / multiplier
points2 = self.field.mesh.points[1] / multiplier
points1 = self.field.mesh.cells[0] / multiplier
points2 = self.field.mesh.cells[1] / multiplier

values = self.field.array.copy()
self._filter_values(self.field._valid_as_field, values)
Expand Down Expand Up @@ -815,8 +815,8 @@ def contour(

multiplier = self._setup_multiplier(multiplier)

points1 = self.field.mesh.points[0] / multiplier
points2 = self.field.mesh.points[1] / multiplier
points1 = self.field.mesh.cells[0] / multiplier
points2 = self.field.mesh.cells[1] / multiplier

values = self.field.array.copy().reshape(self.field.mesh.n)

Expand Down
8 changes: 4 additions & 4 deletions discretisedfield/tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,10 +1854,10 @@ def test_diff_valid():
f = df.Field(mesh, nvdim=1, value=lambda p: p[0] ** 2, valid=valid)

assert np.allclose(f.diff("x").array[:3], 0)
assert np.allclose(f.diff("x").array[3:6, 0], 2 * f.mesh.points[0][3:6])
assert np.allclose(f.diff("x").array[3:6, 0], 2 * f.mesh.cells[0][3:6])
assert np.allclose(f.diff("x").array[6:], 0)
assert np.allclose(
f.diff("x", restrict2valid=False).array[..., 0], 2 * f.mesh.points[0]
f.diff("x", restrict2valid=False).array[..., 0], 2 * f.mesh.cells[0]
)

# 3d mesh
Expand Down Expand Up @@ -3972,7 +3972,7 @@ def test_to_xarray_valid_args_vector(valid_mesh, value, dtype):
assert np.allclose(fxa.attrs["pmax"], f.mesh.region.pmax)
assert np.allclose(fxa.attrs["tolerance_factor"], f.mesh.region.tolerance_factor)
for i in f.mesh.region.dims:
assert np.array_equal(getattr(f.mesh.points, i), fxa[i].values)
assert np.array_equal(getattr(f.mesh.cells, i), fxa[i].values)
assert fxa[i].attrs["units"] == f.mesh.region.units[f.mesh.region.dims.index(i)]
assert all(fxa["vdims"].values == f.vdims)
assert np.array_equal(f.array, fxa.values)
Expand All @@ -3996,7 +3996,7 @@ def test_to_xarray_valid_args_scalar(valid_mesh, value, dtype):
assert np.allclose(fxa.attrs["pmax"], f.mesh.region.pmax)
assert np.allclose(fxa.attrs["tolerance_factor"], f.mesh.region.tolerance_factor)
for i in f.mesh.region.dims:
assert np.array_equal(getattr(f.mesh.points, i), fxa[i].values)
assert np.array_equal(getattr(f.mesh.cells, i), fxa[i].values)
assert fxa[i].attrs["units"] == f.mesh.region.units[f.mesh.region.dims.index(i)]
assert "vdims" not in fxa.dims
assert np.array_equal(f.array.squeeze(axis=-1), fxa.values)
Expand Down
20 changes: 10 additions & 10 deletions discretisedfield/tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,35 +862,35 @@ def test_region2slice():
mesh.region2slices(df.Region(p1=(-1, 3), p2=(3, 5)))


def test_points():
def test_cells():
# 1d example (ndim=1)
p1 = 0
p2 = 10
cell = 2
mesh = df.Mesh(region=df.Region(p1=p1, p2=p2), cell=cell)

assert np.allclose(mesh.points.x, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)
assert np.allclose(mesh.cells.x, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)

# 3d example (ndim=3)
p1 = (0, 0, 4)
p2 = (10, 6, 0)
cell = (2, 2, 1)
mesh = df.Mesh(region=df.Region(p1=p1, p2=p2), cell=cell)

assert np.allclose(mesh.points.x, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)
assert np.allclose(mesh.points.y, [1.0, 3.0, 5.0], atol=0)
assert np.allclose(mesh.points.z, [0.5, 1.5, 2.5, 3.5], atol=0)
assert np.allclose(mesh.cells.x, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)
assert np.allclose(mesh.cells.y, [1.0, 3.0, 5.0], atol=0)
assert np.allclose(mesh.cells.z, [0.5, 1.5, 2.5, 3.5], atol=0)

# 4d example (ndim=4)
p1 = (0, 0, 4, 4)
p2 = (10, 6, 0, 0)
cell = (2, 2, 1, 1)
mesh = df.Mesh(region=df.Region(p1=p1, p2=p2), cell=cell)

assert np.allclose(mesh.points.x0, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)
assert np.allclose(mesh.points.x1, [1.0, 3.0, 5.0], atol=0)
assert np.allclose(mesh.points.x2, [0.5, 1.5, 2.5, 3.5], atol=0)
assert np.allclose(mesh.points.x3, [0.5, 1.5, 2.5, 3.5], atol=0)
assert np.allclose(mesh.cells.x0, [1.0, 3.0, 5.0, 7.0, 9.0], atol=0)
assert np.allclose(mesh.cells.x1, [1.0, 3.0, 5.0], atol=0)
assert np.allclose(mesh.cells.x2, [0.5, 1.5, 2.5, 3.5], atol=0)
assert np.allclose(mesh.cells.x3, [0.5, 1.5, 2.5, 3.5], atol=0)


def test_vertices():
Expand Down Expand Up @@ -1616,7 +1616,7 @@ def test_coordinate_field(valid_mesh):
index[valid_mesh.region._dim2index(dim)] = slice(None)
# extra index for vector dimension: vector component along the current direction
index = tuple(index) + (valid_mesh.region._dim2index(dim),)
assert np.allclose(cfield.array[index], getattr(valid_mesh.points, dim), atol=0)
assert np.allclose(cfield.array[index], getattr(valid_mesh.cells, dim), atol=0)


def test_sel_convert_intput():
Expand Down

0 comments on commit 52ca76c

Please sign in to comment.