Skip to content

Commit

Permalink
Add arange default values (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarna authored Nov 18, 2024
1 parent 9506c28 commit 6b74a68
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 24 deletions.
20 changes: 9 additions & 11 deletions examples/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
def transpose(a):
return np.permute_dims(a, [1, 0])

all_axes = [0, 1]
init(False)

elif backend == "numpy":
Expand All @@ -76,7 +75,6 @@ def transpose(a):
transpose = np.transpose

fini = sync = lambda x=None: None
all_axes = None
else:
raise ValueError(f'Unknown backend: "{backend}"')

Expand Down Expand Up @@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
# set bathymetry
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
# steady state potential energy
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny

# compute time step
alpha = 0.5
h_max = float(np.max(h, all_axes))
h_max = float(np.max(h))
c = (g * h_max) ** 0.5
dt = alpha * dx / c
dt = t_export / int(math.ceil(t_export / dt))
Expand Down Expand Up @@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
t = i * dt

if t >= next_t_export - 1e-8:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_q_max = np.max(q, all_axes)
_total_v = np.sum(e + h, all_axes)
_elev_max = np.max(e)
_u_max = np.max(u)
_q_max = np.max(q)
_total_v = np.sum(e + h)

# potential energy
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
_total_pe = np.sum(_pe, all_axes)
_total_pe = np.sum(_pe)

# kinetic energy
u2 = u * u
v2 = v * v
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
_total_ke = np.sum(_ke, all_axes)
_total_ke = np.sum(_ke)

total_pe = float(_total_pe) * dx * dy
total_ke = float(_total_ke) * dx * dy
Expand Down Expand Up @@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
2
]
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
err_L2 = math.sqrt(float(np.sum(err2)))
info(f"L2 error: {err_L2:7.15e}")

if nx < 128 or ny < 128:
Expand Down
10 changes: 4 additions & 6 deletions examples/wave_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
def transpose(a):
return np.permute_dims(a, [1, 0])

all_axes = [0, 1]
init(False)

elif backend == "numpy":
Expand All @@ -76,7 +75,6 @@ def transpose(a):
transpose = np.transpose

fini = sync = lambda x=None: None
all_axes = None
else:
raise ValueError(f'Unknown backend: "{backend}"')

Expand Down Expand Up @@ -240,9 +238,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
t = i * dt

if t >= next_t_export - 1e-8:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_total_v = np.sum(e + h, all_axes)
_elev_max = np.max(e)
_u_max = np.max(u)
_total_v = np.sum(e + h)

elev_max = float(_elev_max)
u_max = float(_u_max)
Expand Down Expand Up @@ -279,7 +277,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):

e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
err_L2 = math.sqrt(float(np.sum(err2)))
info(f"L2 error: {err_L2:7.5e}")

if nx == 128 and ny == 128 and not benchmark_mode:
Expand Down
20 changes: 16 additions & 4 deletions sharpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,22 @@ def _validate_device(device):
raise ValueError(f"Invalid device string: {device}")


def arange(start, /, end=None, step=1, dtype=int64, device="", team=1):
if end is None:
end = start
start = 0
assert step != 0, "step cannot be zero"
if (end - start) * step < 0:
# invalid range, return empty array
start = end = 0
step = 1
return ndarray(
_csp.Creator.arange(
start, end, step, dtype, _validate_device(device), team
)
)


for func in api.api_categories["Creator"]:
FUNC = func.upper()
if func == "full":
Expand All @@ -114,10 +130,6 @@ def _validate_device(device):
exec(
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
)
elif func == "arange":
exec(
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
)
elif func == "linspace":
exec(
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"
Expand Down
39 changes: 36 additions & 3 deletions test/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,41 @@ def creator(request):
return request.param[0], request.param[1]


def test_arange():
n = 10
a = sp.arange(0, n, 1, dtype=sp.int32, device=device)
assert tuple(a.shape) == (n,)
assert numpy.allclose(sp.to_numpy(a), list(range(n)))


def test_arange2():
n = 10
a = sp.arange(0, n, dtype=sp.int32, device=device)
assert tuple(a.shape) == (n,)
assert numpy.allclose(sp.to_numpy(a), list(range(n)))


def test_arange3():
n = 10
a = sp.arange(n, device=device)
assert tuple(a.shape) == (n,)
assert numpy.allclose(sp.to_numpy(a), list(range(n)))


def test_arange_empty():
n = 10
a = sp.arange(n, 0, device=device)
assert tuple(a.shape) == (0,)
assert numpy.allclose(sp.to_numpy(a), list())


def test_arange_empty2():
n = 10
a = sp.arange(0, n, -1, device=device)
assert tuple(a.shape) == (0,)
assert numpy.allclose(sp.to_numpy(a), list())


def test_create_datatypes(creator, datatype):
shape = (6, 4)
func, expected_value = creator
Expand Down Expand Up @@ -67,9 +102,7 @@ def test_full_invalid_shape():
sp.full(shape, value, dtype=datatype, device=device)


@pytest.mark.parametrize(
"start,end,step", [(0, 10, -1), (0, -10, 1), (0, 99999999999999999999, 1)]
)
@pytest.mark.parametrize("start,end,step", [(0, 99999999999999999999, 1)])
def tests_arange_invalid(start, end, step):
with pytest.raises(TypeError):
sp.arange(start, end, step, dtype=sp.int32, device=device)

0 comments on commit 6b74a68

Please sign in to comment.