From 6b74a68eab73c291c2a73cb8b3ea47d96ba1f883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tuomas=20K=C3=A4rn=C3=A4?= Date: Mon, 18 Nov 2024 12:29:41 +0200 Subject: [PATCH] Add arange default values (#20) --- examples/shallow_water.py | 20 +++++++++----------- examples/wave_equation.py | 10 ++++------ sharpy/__init__.py | 20 ++++++++++++++++---- test/test_create.py | 39 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 65 insertions(+), 24 deletions(-) diff --git a/examples/shallow_water.py b/examples/shallow_water.py index a5798a4..35096bb 100644 --- a/examples/shallow_water.py +++ b/examples/shallow_water.py @@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode): def transpose(a): return np.permute_dims(a, [1, 0]) - all_axes = [0, 1] init(False) elif backend == "numpy": @@ -76,7 +75,6 @@ def transpose(a): transpose = np.transpose fini = sync = lambda x=None: None - all_axes = None else: raise ValueError(f'Unknown backend: "{backend}"') @@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly): # set bathymetry h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly) # steady state potential energy - pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny + pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny # compute time step alpha = 0.5 - h_max = float(np.max(h, all_axes)) + h_max = float(np.max(h)) c = (g * h_max) ** 0.5 dt = alpha * dx / c dt = t_export / int(math.ceil(t_export / dt)) @@ -344,14 +342,14 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): t = i * dt if t >= next_t_export - 1e-8: - _elev_max = np.max(e, all_axes) - _u_max = np.max(u, all_axes) - _q_max = np.max(q, all_axes) - _total_v = np.sum(e + h, all_axes) + _elev_max = np.max(e) + _u_max = np.max(u) + _q_max = np.max(q) + _total_v = np.sum(e + h) # potential energy _pe = 0.5 * g * (e + h) * (e - h) + pe_offset - _total_pe = np.sum(_pe, all_axes) + _total_pe = np.sum(_pe) # kinetic energy u2 = u * u @@ -359,7 +357,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :]) v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1]) _ke = 0.5 * (u2_at_t + v2_at_t) * (e + h) - _total_ke = np.sum(_ke, all_axes) + _total_ke = np.sum(_ke) total_pe = float(_total_pe) * dx * dy total_ke = float(_total_ke) * dx * dy @@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): 2 ] err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly - err_L2 = math.sqrt(float(np.sum(err2, all_axes))) + err_L2 = math.sqrt(float(np.sum(err2))) info(f"L2 error: {err_L2:7.15e}") if nx < 128 or ny < 128: diff --git a/examples/wave_equation.py b/examples/wave_equation.py index defd2c6..7561bac 100644 --- a/examples/wave_equation.py +++ b/examples/wave_equation.py @@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode): def transpose(a): return np.permute_dims(a, [1, 0]) - all_axes = [0, 1] init(False) elif backend == "numpy": @@ -76,7 +75,6 @@ def transpose(a): transpose = np.transpose fini = sync = lambda x=None: None - all_axes = None else: raise ValueError(f'Unknown backend: "{backend}"') @@ -240,9 +238,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): t = i * dt if t >= next_t_export - 1e-8: - _elev_max = np.max(e, all_axes) - _u_max = np.max(u, all_axes) - _total_v = np.sum(e + h, all_axes) + _elev_max = np.max(e) + _u_max = np.max(u) + _total_v = np.sum(e + h) elev_max = float(_elev_max) u_max = float(_u_max) @@ -279,7 +277,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2): e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly) err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly - err_L2 = math.sqrt(float(np.sum(err2, all_axes))) + err_L2 = math.sqrt(float(np.sum(err2))) info(f"L2 error: {err_L2:7.5e}") if nx == 128 and ny == 128 and not benchmark_mode: diff --git a/sharpy/__init__.py b/sharpy/__init__.py index cb1cc47..2b18924 100644 --- a/sharpy/__init__.py +++ b/sharpy/__init__.py @@ -96,6 +96,22 @@ def _validate_device(device): raise ValueError(f"Invalid device string: {device}") +def arange(start, /, end=None, step=1, dtype=int64, device="", team=1): + if end is None: + end = start + start = 0 + assert step != 0, "step cannot be zero" + if (end - start) * step < 0: + # invalid range, return empty array + start = end = 0 + step = 1 + return ndarray( + _csp.Creator.arange( + start, end, step, dtype, _validate_device(device), team + ) + ) + + for func in api.api_categories["Creator"]: FUNC = func.upper() if func == "full": @@ -114,10 +130,6 @@ def _validate_device(device): exec( f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))" ) - elif func == "arange": - exec( - f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))" - ) elif func == "linspace": exec( f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))" diff --git a/test/test_create.py b/test/test_create.py index e73049e..f782a97 100644 --- a/test/test_create.py +++ b/test/test_create.py @@ -26,6 +26,41 @@ def creator(request): return request.param[0], request.param[1] +def test_arange(): + n = 10 + a = sp.arange(0, n, 1, dtype=sp.int32, device=device) + assert tuple(a.shape) == (n,) + assert numpy.allclose(sp.to_numpy(a), list(range(n))) + + +def test_arange2(): + n = 10 + a = sp.arange(0, n, dtype=sp.int32, device=device) + assert tuple(a.shape) == (n,) + assert numpy.allclose(sp.to_numpy(a), list(range(n))) + + +def test_arange3(): + n = 10 + a = sp.arange(n, device=device) + assert tuple(a.shape) == (n,) + assert numpy.allclose(sp.to_numpy(a), list(range(n))) + + +def test_arange_empty(): + n = 10 + a = sp.arange(n, 0, device=device) + assert tuple(a.shape) == (0,) + assert numpy.allclose(sp.to_numpy(a), list()) + + +def test_arange_empty2(): + n = 10 + a = sp.arange(0, n, -1, device=device) + assert tuple(a.shape) == (0,) + assert numpy.allclose(sp.to_numpy(a), list()) + + def test_create_datatypes(creator, datatype): shape = (6, 4) func, expected_value = creator @@ -67,9 +102,7 @@ def test_full_invalid_shape(): sp.full(shape, value, dtype=datatype, device=device) -@pytest.mark.parametrize( - "start,end,step", [(0, 10, -1), (0, -10, 1), (0, 99999999999999999999, 1)] -) +@pytest.mark.parametrize("start,end,step", [(0, 99999999999999999999, 1)]) def tests_arange_invalid(start, end, step): with pytest.raises(TypeError): sp.arange(start, end, step, dtype=sp.int32, device=device)