Skip to content

Commit

Permalink
tests: Add tests referring to commits 3edf20 and 118ec4, which add th…
Browse files Browse the repository at this point in the history
…e increment option in the SparseTimeFunction inject method
  • Loading branch information
fffarias committed May 24, 2022
1 parent e3d57c1 commit 67d3f34
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import sin, floor
from math import sin, floor, prod

import numpy as np
import pytest
Expand Down Expand Up @@ -53,6 +53,18 @@ def time_points(grid, ranges, npoints, name='points', nt=10):
return points


def time_grid_points(grid, name='points', nt=10):
"""Create a SparseTimeFunction field with coordinates
filled in by all grid points"""
npoints = prod(grid.shape)
a = SparseTimeFunction(name=name, grid=grid, npoint=npoints, nt=nt)
dims = tuple([np.linspace(0., 1., d) for d in grid.shape])
for i in range(len(grid.shape)):
a.coordinates.data[:,i] = np.meshgrid(*dims)[i].flatten()

return a


def a(shape=(11, 11)):
grid = Grid(shape=shape)
a = Function(name='a', grid=grid)
Expand Down Expand Up @@ -417,6 +429,31 @@ def test_inject_time_shift(shape, coords, result, npoints=19):
assert np.allclose(a.data[indices], result, rtol=1.e-5)


@pytest.mark.parametrize('shape, result, increment', [
((10, 10), 1., False),
((10, 10), 5., True),
((10, 10, 10), 1., False),
((10, 10, 10), 5., True)
])
def test_inject_time_increment(shape, result, increment):
"""Test the increment option in the SparseTimeFunction's
injection method. The increment=False option is
expected to work only at points located on the grid,
where no interpolation needed.
"""
a = unit_box_time(shape=shape)
a.data[:] = 0.
p = time_grid_points(a.grid, name='points', nt=10)

expr = p.inject(a, Float(1.), increment=increment)

Operator(expr)(a=a)

assert np.allclose(a.data, result*np.ones(a.grid.shape), rtol=1.e-5)




@pytest.mark.parametrize('shape, coords, result', [
((11, 11), [(.05, .95), (.45, .45)], 1.),
((11, 11, 11), [(.05, .95), (.45, .45), (.45, .45)], 0.5)
Expand Down

0 comments on commit 67d3f34

Please sign in to comment.