From 6077b0da6fb1af6b9a0d7a79f518957d133752f9 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 26 Dec 2024 15:23:41 -0800 Subject: [PATCH 01/21] Fix tuple indexing in Python <3.9 --- dace/frontend/python/newast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index e625a004a9..33813a8d9d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -5178,7 +5178,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: if isinstance(res, (ast.Constant, NumConstant)): res = res.value elif sys.version_info < (3, 9) and isinstance(s, ast.Index): - res = self._parse_subscript_slice(s.value) + res = self._parse_subscript_slice(s.value, multidim=multidim) elif isinstance(s, ast.Slice): lower = s.lower if isinstance(lower, ast.AST): From 656307710d78f629346fd6ae2a868249b00cb17c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 26 Dec 2024 15:24:32 -0800 Subject: [PATCH 02/21] Correctly compute output shape from advanced and advanced/basic indexing combinations --- dace/frontend/python/newast.py | 120 ++++++++++++++++++++++----------- 1 file changed, 82 insertions(+), 38 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 33813a8d9d..9f07347146 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -40,7 +40,7 @@ # register replacements in oprepo import dace.frontend.python.replacements -from dace.frontend.python.replacements import _sym_type, _broadcast_to +from dace.frontend.python.replacements import _sym_type, _broadcast_to, _broadcast_together # Type hints Size = Union[int, dace.symbolic.symbol] @@ -449,11 +449,10 @@ def add_indirection_subgraph(sdfg: SDFG, for i, r in enumerate(memlet.subset): if i in nonsqz_dims: mapped_rng.append(r) - ind_entry, ind_exit = graph.add_map('indirection', { - '__i%d' % i: '%s:%s+1:%s' % (s, e, t) - for i, (s, e, t) in enumerate(mapped_rng) - }, - debuginfo=pvisitor.current_lineinfo) + ind_entry, ind_exit = graph.add_map( + 'indirection', {'__i%d' % i: '%s:%s+1:%s' % (s, e, t) + for i, (s, e, t) in enumerate(mapped_rng)}, + debuginfo=pvisitor.current_lineinfo) inp_base_path.insert(0, ind_entry) out_base_path.append(ind_exit) @@ -1341,10 +1340,9 @@ def defined(self): result.update(self.sdfg.arrays) # MPI-related stuff - result.update({ - v: self.sdfg.process_grids[v] - for k, v in self.variables.items() if v in self.sdfg.process_grids - }) + result.update( + {v: self.sdfg.process_grids[v] + for k, v in self.variables.items() if v in self.sdfg.process_grids}) try: from mpi4py import MPI result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) @@ -2587,14 +2585,11 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): # Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE) langInf = None side_effects = None - if isinstance(node, ast.FunctionDef) and \ - hasattr(node, 'decorator_list') and \ - isinstance(node.decorator_list, list) and \ - len(node.decorator_list) > 0 and \ - hasattr(node.decorator_list[0], 'args') and \ - isinstance(node.decorator_list[0].args, list) and \ - len(node.decorator_list[0].args) > 0 and \ - hasattr(node.decorator_list[0].args[0], 'value'): + if isinstance(node, ast.FunctionDef) and hasattr(node, 'decorator_list') and isinstance( + node.decorator_list, + list) and len(node.decorator_list) > 0 and hasattr(node.decorator_list[0], 'args') and isinstance( + node.decorator_list[0].args, list) and len(node.decorator_list[0].args) > 0 and hasattr( + node.decorator_list[0].args[0], 'value'): langArg = node.decorator_list[0].args[0].value langInf = dtypes.Language[langArg] @@ -3898,10 +3893,10 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no # Map internal SDFG symbols by adding keyword arguments symbols = sdfg.used_symbols(all_symbols=False) try: - mapping = infer_symbols_from_datadescriptor(sdfg, { - k: self.sdfg.arrays[v] - for k, v in args if v in self.sdfg.arrays - }, set(sym.arg for sym in node.keywords if sym.arg in symbols)) + mapping = infer_symbols_from_datadescriptor( + sdfg, {k: self.sdfg.arrays[v] + for k, v in args if v in self.sdfg.arrays}, + set(sym.arg for sym in node.keywords if sym.arg in symbols)) except ValueError as ex: raise DaceSyntaxError(self, node, str(ex)) if len(mapping) == 0: # Default to same-symbol mapping @@ -4772,8 +4767,7 @@ def visit_With(self, node: ast.With, is_async=False): else: name = self.name - tasklet, inputs, outputs, sdfg_inp, sdfg_out = \ - self._parse_tasklet(state, node, name) + tasklet, inputs, outputs, sdfg_inp, sdfg_out = self._parse_tasklet(state, node, name) # Add memlets inputs = {k: (state, v, set()) for k, v in inputs.items()} @@ -5365,6 +5359,65 @@ def make_slice(self, arrname: str, rng: subsets.Range): rnode, wnode, Memlet.simple(array, rng, num_accesses=rng.num_elements(), other_subset_str=other_subset)) return tmp, other_subset + def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletExpr) -> List[symbolic.SymbolicType]: + """ + Creates the output shape of a slicing operation with advanced indexing. + + :param aname: The name of the array being sliced. + :param expr: The MemletExpr object representing the slicing operation. + :return: A list of symbolic dimensions representing the output shape. + """ + # The output shape is the shape of all contiguous advanced indexing arrays, after broadcasting with each other + # Start with all basic indexing dimensions, setting advanced indexing dimensions to None + output_shape = [s if i not in expr.arrdims else None for i, s in enumerate(expr.subset.size())] + # If any advanced indexing is found, mark any scalar dimension as advanced indices too + if expr.arrdims: + output_shape = [None if rng[0] == rng[1] else s for s, rng in zip(output_shape, expr.subset.ndrange())] + + # Mark every dimension that starts with None as an advanced indexing "chunk" + advanced_dims = [ + i for i, s in enumerate(output_shape) if s is None and (i == 0 or output_shape[i - 1] is not None) + ] + # If there is more than one contiguous advanced indexing chunk, move all advanced indices to the beginning + prefix_dims = len(advanced_dims) > 1 + if prefix_dims: + output_shape = [None] + [s for s in output_shape if s is not None] + dim_position = 0 + else: + dim_position = advanced_dims[0] + + # Contract contiguous None dimensions that appear multiple times in a row + output_shape = [ + s for i, s in enumerate(output_shape) if s is not None or i == 0 or output_shape[i - 1] is not None + ] + + # Broadcast all advanced indexing expressions together + chunk_shape = None + # Get the advanced indexing expressions + for i, arrname in expr.arrdims.items(): + if isinstance(arrname, str): # Array or constant + if arrname in self.sdfg.arrays: + desc = self.sdfg.arrays[arrname] + elif arrname in self.sdfg.constants: + desc = self.sdfg.constants[arrname] + else: + raise NameError(f'Array "{arrname}" used in indexing "{aname}" not found') + shape = desc.shape + else: # Literal list or tuple, add as constant and use shape + arrname = [v if isinstance(v, Number) else self._parse_value(v) for v in arrname] + carr = numpy.array(arrname, dtype=dtypes.typeclass(int).type) + shape = carr.shape + + if chunk_shape is not None: + chunk_shape, *_ = _broadcast_together(shape, chunk_shape) + else: + chunk_shape = tuple(shape) + + # Replace the advanced indexing dimensions with the broadcasted shape + output_shape = output_shape[:dim_position] + list(chunk_shape) + output_shape[dim_position + 1:] + + return output_shape + def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) -> str: aname = rnode.data idesc = self.sdfg.arrays[aname] @@ -5375,32 +5428,23 @@ def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) raise IndexError('New axes unsupported when array indices are used') # Create output shape dimensions based on the sizes of the arrays - output_shape = None + output_shape = self._create_output_shape_from_advanced_indexing(aname, expr) + # Create constants for array indices constant_indices: Dict[int, str] = {} for i, arrname in expr.arrdims.items(): if isinstance(arrname, str): # Array or constant - if arrname in self.sdfg.arrays: - desc = self.sdfg.arrays[arrname] - elif arrname in self.sdfg.constants: - desc = self.sdfg.constants[arrname] + if arrname in self.sdfg.constants: constant_indices[i] = arrname + elif arrname in self.sdfg.arrays: + pass else: raise NameError(f'Array "{arrname}" used in indexing "{aname}" not found') - shape = desc.shape else: # Literal list or tuple, add as constant and use shape arrname = [v if isinstance(v, Number) else self._parse_value(v) for v in arrname] carr = numpy.array(arrname, dtype=dtypes.typeclass(int).type) cname = self.sdfg.find_new_constant(f'__ind{i}_{aname}') self.sdfg.add_constant(cname, carr) constant_indices[i] = cname - shape = carr.shape - - if output_shape is not None and tuple(shape) != output_shape: - raise IndexError(f'Mismatch in array index shapes in access of ' - f'"{aname}": {arrname} (shape {shape}) ' - f'does not match existing shape {output_shape}') - elif output_shape is None: - output_shape = tuple(shape) # Check subset shapes for matching the array shapes input_index = [] From e680ef60b5ae7beb8d2bc8519348741d5288b5e4 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 26 Dec 2024 15:27:39 -0800 Subject: [PATCH 03/21] Add tests --- tests/numpy/advanced_indexing_test.py | 107 ++++++++++++++++++++++++++ tests/numpy/split_test.py | 19 +++++ 2 files changed, 126 insertions(+) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index fd5e92d34d..a163f368ff 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -324,6 +324,107 @@ def indexing_test(A: dace.float64[N, N, N]): assert np.allclose(A, ref) +@pytest.mark.parametrize('contiguous', (False, True)) +def test_multidim_tuple_index(contiguous): + + if contiguous: + + @dace.program + def indexing_test(A: dace.float64[N, M]): + return A[:, (1, 2, 3)] + else: + + @dace.program + def indexing_test(A: dace.float64[N, M]): + return A[:, (1, 3, 0)] + + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (N, 3) + + A = np.random.rand(20, 10) + if contiguous: + ref = A[:, (1, 2, 3)] + else: + ref = A[:, (1, 3, 0)] + + res = indexing_test(A) + + assert np.allclose(res, ref) + + +def test_multidim_tuple_index_longer(): + + @dace.program + def indexing_test(A: dace.float64[N, M]): + return A[:, (1, 2, 3, 4, 5, 7)] + + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (N, 6) + + A = np.random.rand(20, 10) + ref = A[:, (1, 2, 3, 4, 5, 7)] + + res = indexing_test(A) + + assert np.allclose(res, ref) + + +def test_multidim_tuple_multidim_index(): + + @dace.program + def indexing_test(A: dace.float64[N, M, N]): + return A[:, (1, 2, 3, 4, 5, 7), (0, 1)] + + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (N, 6, 2) + + A = np.random.rand(20, 10, 20) + ref = A[:, (1, 2, 3, 4, 5, 7), (0, 1)] + + res = indexing_test(A) + + assert np.allclose(res, ref) + + +def test_advanced_index_broadcasting(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N], indices: dace.int32[3, 3]): + return A[indices, (1, 2, 4), :] + + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (3, 3, N) + + A = np.random.rand(20, 10, 20) + indices = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) + ref = A[indices, (1, 2, 4), :] + + res = indexing_test(A, indices) + + assert np.allclose(res, ref) + + +def test_combining_basic_and_advanced_indexing(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): + return A[:5, indices, indices2, ..., 1:3, 4] + + n = 6 + A = np.random.rand(n, n, n, n, n, n, n) + indices = np.random.randint(0, n, size=(3, 3)) + indices2 = np.random.randint(0, n, size=(3, 3, 3)) + ref = A[:5, indices, indices2, ..., 1:3, 4] + + # Advanced indexing dimensions should be prepended to the shape + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (3, 3, 3, 5, N, N, 2) + + res = indexing_test(A, indices, indices2) + + assert np.allclose(res, ref) + + if __name__ == '__main__': test_flat() test_flat_noncontiguous() @@ -348,3 +449,9 @@ def indexing_test(A: dace.float64[N, N, N]): test_out_index_intarr_multidim() test_advanced_indexing_syntax(False) test_advanced_indexing_syntax(True) + test_multidim_tuple_index(False) + test_multidim_tuple_index(True) + test_multidim_tuple_index_longer() + test_multidim_tuple_multidim_index() + test_advanced_index_broadcasting() + test_combining_basic_and_advanced_indexing() diff --git a/tests/numpy/split_test.py b/tests/numpy/split_test.py index e4088754e8..1fee72fc6d 100644 --- a/tests/numpy/split_test.py +++ b/tests/numpy/split_test.py @@ -127,6 +127,24 @@ def test_dsplit_4d(): return a, b, c +def test_compiletime_split(): + + @dace.program + def tester(x, y, in_indices: dace.compiletime, out_index: dace.compiletime): + x0, x1, x2, x3, x4, x5 = np.split(x[:, in_indices], 6, axis=1) + factor = 1 / 12 + o = out_index + y[:, o:o + 1] = factor * (-(x1 + x2) + (x0 + x1) - (x0 + x4) + (x3 + x4) + (x2 + x5) - (x3 + x5)) + + x = np.random.rand(1000, 8) + y = np.empty_like(x) + tester(x, y, (1, 2, 3, 4, 5, 7), 0) + ref = np.empty_like(y) + tester.f(x, ref, (1, 2, 3, 4, 5, 7), 0) + + assert np.allclose(y, ref) + + if __name__ == "__main__": test_split() test_uneven_split_fail() @@ -140,3 +158,4 @@ def test_dsplit_4d(): test_vsplit() test_hsplit() test_dsplit_4d() + test_compiletime_split() From b700d0cd2a2cd74605fb509caa69a1aa34f9c190 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 29 Dec 2024 16:47:52 -0800 Subject: [PATCH 04/21] Add more tests --- tests/numpy/advanced_indexing_test.py | 35 ++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index a163f368ff..926e2b3682 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -288,7 +288,6 @@ def indexing_test(A: dace.float64[N, N, N], B: dace.float64[N, N], indices: dace assert np.allclose(A, ref) - def test_out_index_intarr_multidim(): @dace.program @@ -303,6 +302,20 @@ def indexing_test(A: dace.float64[N, N, N], indices: dace.int32[M]): assert np.allclose(A, ref) +def test_out_index_intarr_multidim_range(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N], indices: dace.int32[M]): + A[1:2, indices, 3:10] = 2 + + A = np.random.rand(20, 20, 20) + indices = [1, 10, 15] + ref = np.copy(A) + ref[1:2, indices, 3:10] = 2 + indexing_test(A, indices, M=3) + + assert np.allclose(A, ref) + @pytest.mark.parametrize('tuple_index', (False, True)) def test_advanced_indexing_syntax(tuple_index): @@ -425,6 +438,24 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 assert np.allclose(res, ref) +def test_combining_basic_and_advanced_indexing_write(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): + A[:5, indices, indices2, ..., 1:3, 4] = 2 + + n = 6 + A = np.random.rand(n, n, n, n, n, n, n) + indices = np.random.randint(0, n, size=(3, 3)) + indices2 = np.random.randint(0, n, size=(3, 3, 3)) + ref = np.copy(A) + A[:5, indices, indices2, ..., 1:3, 4] = 2 + + # Advanced indexing dimensions should be prepended to the shape + res = indexing_test(A, indices, indices2) + + assert np.allclose(res, ref) + if __name__ == '__main__': test_flat() test_flat_noncontiguous() @@ -447,6 +478,7 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_out_index_intarr_aug() test_out_index_intarr_aug_bcast() test_out_index_intarr_multidim() + test_out_index_intarr_multidim_range() test_advanced_indexing_syntax(False) test_advanced_indexing_syntax(True) test_multidim_tuple_index(False) @@ -455,3 +487,4 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_multidim_tuple_multidim_index() test_advanced_index_broadcasting() test_combining_basic_and_advanced_indexing() + test_combining_basic_and_advanced_indexing_write() From e9f0a3a9f69279de03c5362604cee49af8b43d01 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 29 Dec 2024 18:11:03 -0800 Subject: [PATCH 05/21] Even more tests --- tests/numpy/advanced_indexing_test.py | 69 +++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 926e2b3682..15e9a9665f 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -288,6 +288,7 @@ def indexing_test(A: dace.float64[N, N, N], B: dace.float64[N, N], indices: dace assert np.allclose(A, ref) + def test_out_index_intarr_multidim(): @dace.program @@ -302,6 +303,7 @@ def indexing_test(A: dace.float64[N, N, N], indices: dace.int32[M]): assert np.allclose(A, ref) + def test_out_index_intarr_multidim_range(): @dace.program @@ -383,20 +385,23 @@ def indexing_test(A: dace.float64[N, M]): def test_multidim_tuple_multidim_index(): + with pytest.raises(SyntaxError, match='could not be broadcast together'): - @dace.program - def indexing_test(A: dace.float64[N, M, N]): - return A[:, (1, 2, 3, 4, 5, 7), (0, 1)] + @dace.program + def indexing_test(A: dace.float64[N, M, N]): + return A[:, (1, 2, 3, 4, 5, 7), (0, 1)] - sdfg = indexing_test.to_sdfg() - assert tuple(sdfg.arrays['__return'].shape) == (N, 6, 2) + indexing_test.to_sdfg() - A = np.random.rand(20, 10, 20) - ref = A[:, (1, 2, 3, 4, 5, 7), (0, 1)] - res = indexing_test(A) +def test_multidim_tuple_multidim_index_write(): + with pytest.raises(SyntaxError, match='could not be broadcast together'): - assert np.allclose(res, ref) + @dace.program + def indexing_test(A: dace.float64[N, M, N]): + A[:, (1, 2, 3, 4, 5, 7), (0, 1)] = 2 + + indexing_test.to_sdfg() def test_advanced_index_broadcasting(): @@ -456,6 +461,49 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 assert np.allclose(res, ref) + +def test_combining_basic_and_advanced_indexing_with_newaxes(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): + return A[None, :5, indices, indices2, ..., 1:3, 4, np.newaxis] + + n = 6 + A = np.random.rand(n, n, n, n, n, n, n) + indices = np.random.randint(0, n, size=(3, 3)) + indices2 = np.random.randint(0, n, size=(3, 3, 3)) + ref = A[None, :5, indices, indices2, ..., 1:3, 4, np.newaxis] + + # Advanced indexing dimensions should be prepended to the shape + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (3, 3, 3, 1, 5, N, N, 2, 1) + + res = indexing_test(A, indices, indices2) + + assert np.allclose(res, ref) + + +def test_combining_basic_and_advanced_indexing_with_newaxes_2(): + + @dace.program + def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): + return A[None, :5, indices, indices2, ..., 1:3, np.newaxis] + + n = 6 + A = np.random.rand(n, n, n, n, n, n, n) + indices = np.random.randint(0, n, size=(3, 3)) + indices2 = np.random.randint(0, n, size=(3, 3, 3)) + ref = A[None, :5, indices, indices2, ..., 1:3, np.newaxis] + + # Advanced indexing dimensions should be prepended to the shape + sdfg = indexing_test.to_sdfg() + assert tuple(sdfg.arrays['__return'].shape) == (1, 5, 3, 3, 3, N, N, 2, 1) + + res = indexing_test(A, indices, indices2) + + assert np.allclose(res, ref) + + if __name__ == '__main__': test_flat() test_flat_noncontiguous() @@ -485,6 +533,9 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_multidim_tuple_index(True) test_multidim_tuple_index_longer() test_multidim_tuple_multidim_index() + test_multidim_tuple_multidim_index_write() test_advanced_index_broadcasting() test_combining_basic_and_advanced_indexing() test_combining_basic_and_advanced_indexing_write() + test_combining_basic_and_advanced_indexing_with_newaxes() + test_combining_basic_and_advanced_indexing_with_newaxes_2() From 4ef1d55e6048eed4a2b624c75e0006bf3e8c8111 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 29 Dec 2024 23:33:29 -0800 Subject: [PATCH 06/21] Make broadcast_together/broadcast_to public functions --- dace/frontend/python/newast.py | 8 ++++---- dace/frontend/python/replacements.py | 24 ++++++++++++------------ tests/numpy/advanced_indexing_test.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 9f07347146..c985fd7bd7 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -40,7 +40,7 @@ # register replacements in oprepo import dace.frontend.python.replacements -from dace.frontend.python.replacements import _sym_type, _broadcast_to, _broadcast_together +from dace.frontend.python.replacements import _sym_type, broadcast_to, broadcast_together # Type hints Size = Union[int, dace.symbolic.symbol] @@ -2711,7 +2711,7 @@ def _add_assignment(self, if (indirect_indices or boolarr or len(ssize) != len(osize) or any(inequal_symbols(s, o) for s, o in zip(ssize, osize)) or op): - _, all_idx_tuples, _, _, inp_idx = _broadcast_to(squeezed.size(), op_subset.size()) + _, all_idx_tuples, _, _, inp_idx = broadcast_to(squeezed.size(), op_subset.size()) idx = iter(i for i, _ in all_idx_tuples) target_index = ','.join( @@ -2927,7 +2927,7 @@ def _add_aug_assignment(self, wsqz = sqz_wsub.squeeze() sqz_rsub = copy.deepcopy(rtarget_subset) rsqz = sqz_rsub.squeeze() - _, all_idx_tuples, _, out_idx, inp_idx = _broadcast_to(sqz_wsub.size(), sqz_osub.size()) + _, all_idx_tuples, _, out_idx, inp_idx = broadcast_to(sqz_wsub.size(), sqz_osub.size()) # Re-add squeezed dimensions from original subset so that memlets match original arrays osqueezed = [i for i in range(len(op_subset)) if i not in osqz] wsqueezed = [i for i in range(len(wtarget_subset)) if i not in wsqz] @@ -5409,7 +5409,7 @@ def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletEx shape = carr.shape if chunk_shape is not None: - chunk_shape, *_ = _broadcast_together(shape, chunk_shape) + chunk_shape, *_ = broadcast_together(shape, chunk_shape) else: chunk_shape = tuple(shape) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index c5b3e3b2a2..1c8baf891f 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -664,7 +664,7 @@ def _linspace(pv: ProgramVisitor, start_shape = sdfg.arrays[start].shape if (isinstance(start, str) and start in sdfg.arrays) else [] stop_shape = sdfg.arrays[stop].shape if (isinstance(stop, str) and stop in sdfg.arrays) else [] - shape, ranges, outind, ind1, ind2 = _broadcast_together(start_shape, stop_shape) + shape, ranges, outind, ind1, ind2 = broadcast_together(start_shape, stop_shape) shape_with_axis = _add_axis_to_shape(shape, axis, num) ranges_with_axis = _add_axis_to_shape(ranges, axis, ('__sind', f'0:{symbolic.symstr(num)}')) if outind: @@ -1278,10 +1278,10 @@ def _array_array_where(visitor: ProgramVisitor, right_shape = right_arr.shape if right_arr else [1] cond_shape = cond_arr.shape if cond_arr else [1] - (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape) + (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape) # Broadcast condition with broadcasted left+right - _, _, _, cond_idx, _ = _broadcast_together(cond_shape, out_shape) + _, _, _, cond_idx, _ = broadcast_together(cond_shape, out_shape) # Fix for Scalars if isinstance(left_arr, data.Scalar): @@ -1356,10 +1356,10 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str): return name -def _broadcast_to(target_shape, operand_shape): +def broadcast_to(target_shape, operand_shape): # the difference to normal broadcasting is that the broadcasted shape is the same as the target # I was unable to find documentation for this in numpy, so we follow the description from ONNX - results = _broadcast_together(target_shape, operand_shape, unidirectional=True) + results = broadcast_together(target_shape, operand_shape, unidirectional=True) # the output_shape should be equal to the target_shape assert all(i == o for i, o in zip(target_shape, results[0])) @@ -1367,7 +1367,7 @@ def _broadcast_to(target_shape, operand_shape): return results -def _broadcast_together(arr1_shape, arr2_shape, unidirectional=False): +def broadcast_together(arr1_shape, arr2_shape, unidirectional=False): all_idx_dict, all_idx, a1_idx, a2_idx = {}, [], [], [] @@ -1415,9 +1415,9 @@ def get_idx(i): all_idx_dict[get_idx(i)] = dim1 else: if unidirectional: - raise SyntaxError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}") + raise IndexError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}") else: - raise SyntaxError("operands could not be broadcast together with shapes {}, {}".format( + raise IndexError("operands could not be broadcast together with shapes {}, {}".format( arr1_shape, arr2_shape)) def to_string(idx): @@ -1435,7 +1435,7 @@ def _binop(sdfg: SDFG, state: SDFGState, op1: str, op2: str, opcode: str, opname arr1 = sdfg.arrays[op1] arr2 = sdfg.arrays[op2] - out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = _broadcast_together(arr1.shape, arr2.shape) + out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = broadcast_together(arr1.shape, arr2.shape) name, _ = sdfg.add_temp_transient(out_shape, restype, arr1.storage) state.add_mapped_tasklet("_%s_" % opname, @@ -1816,7 +1816,7 @@ def _array_array_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le left_shape = left_arr.shape right_shape = right_arr.shape - (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape) + (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape) # Fix for Scalars if isinstance(left_arr, data.Scalar): @@ -1884,7 +1884,7 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le if right_cast is not None: tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1]) - (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape) + (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape) out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage) @@ -1954,7 +1954,7 @@ def _array_sym_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, left if right_cast is not None: tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1]) - (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape) + (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape) out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 15e9a9665f..f47d889497 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -385,7 +385,7 @@ def indexing_test(A: dace.float64[N, M]): def test_multidim_tuple_multidim_index(): - with pytest.raises(SyntaxError, match='could not be broadcast together'): + with pytest.raises(IndexError, match='could not be broadcast together'): @dace.program def indexing_test(A: dace.float64[N, M, N]): @@ -395,7 +395,7 @@ def indexing_test(A: dace.float64[N, M, N]): def test_multidim_tuple_multidim_index_write(): - with pytest.raises(SyntaxError, match='could not be broadcast together'): + with pytest.raises(IndexError, match='could not be broadcast together'): @dace.program def indexing_test(A: dace.float64[N, M, N]): From 2e4eae31474b74939ef7eb991a33b046dc07298f Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 29 Dec 2024 23:55:38 -0800 Subject: [PATCH 07/21] New issue uncovered by new test --- tests/numpy/advanced_indexing_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index f47d889497..6ed583493e 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -46,6 +46,17 @@ def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): assert np.allclose(A[1:5, ..., 0], res) +def test_ellipsis_and_newaxis(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + return A[None, 1:5, ..., 0] + + A = np.random.rand(5, 5, 5, 5, 5) + res = indexing_test(A) + assert np.allclose(A[None, 1:5, ..., 0], res) + + def test_aug_implicit(): @dace.program @@ -508,6 +519,7 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_flat() test_flat_noncontiguous() test_ellipsis() + test_ellipsis_and_newaxis() test_aug_implicit() test_ellipsis_aug() test_newaxis() From 88cdf896df55adec8eaab886fc08a990c1f9c1b3 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 30 Dec 2024 00:02:44 -0800 Subject: [PATCH 08/21] Fix bug in memlet parsing when both ellipsis and newaxis are used --- dace/frontend/python/memlet_parser.py | 19 ++++++++++++------- tests/numpy/advanced_indexing_test.py | 12 ++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index a95bf82046..af6d9eebcb 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -15,7 +15,6 @@ MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name] - if sys.version_info < (3, 8): _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) BytesConstant = ast.Bytes @@ -107,6 +106,11 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): idx = 0 new_idx = 0 has_ellipsis = False + + # Count new axes + num_new_axes = sum(1 for dim in ast_ndslice + if (dim is None or (isinstance(dim, (ast.Constant, NameConstant)) and dim.value is None))) + for dim in ast_ndslice: if isinstance(dim, (str, list, slice)): dim = ast.Name(id=dim) @@ -136,7 +140,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): if has_ellipsis: raise IndexError('an index can only have a single ellipsis ("...")') has_ellipsis = True - remaining_dims = len(ast_ndslice) - idx - 1 + remaining_dims = len(ast_ndslice) - num_new_axes - idx - 1 for j in range(idx, len(ndslice) - remaining_dims): ndslice[j] = (0, array.shape[j] - 1, 1) idx += 1 @@ -170,7 +174,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): if desc.dtype == dtypes.bool: # Boolean array indexing if len(ast_ndslice) > 1: - raise IndexError(f'Invalid indexing into array "{dim.id}". ' 'Only one boolean array is allowed.') + raise IndexError(f'Invalid indexing into array "{dim.id}". Only one boolean array is allowed.') if tuple(desc.shape) != tuple(array.shape): raise IndexError(f'Invalid indexing into array "{dim.id}". ' 'Shape of boolean index must match original array.') @@ -251,9 +255,9 @@ def parse_memlet_subset(array: data.Data, # Loop over the N dimensions ndslice, offsets, new_extra_dims, arrdims = _fill_missing_slices(das, ast_ndslice, narray, offsets) if new_extra_dims and idx != (len(ast_ndslices) - 1): - raise NotImplementedError('New axes only implemented for last ' 'slice') + raise NotImplementedError('New axes only implemented for last slice') if arrdims and len(ast_ndslices) != 1: - raise NotImplementedError('Array dimensions not implemented ' 'for consecutive subscripts') + raise NotImplementedError('Array dimensions not implemented for consecutive subscripts') extra_dims = new_extra_dims subset_array.append(_ndslice_to_subset(ndslice)) @@ -304,8 +308,9 @@ def ParseMemlet(visitor, try: subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice) except IndexError: - raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. ' - f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}') + raise DaceSyntaxError( + visitor, node, 'Failed to parse memlet expression due to dimensionality. ' + f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}') # If undefined, default number of accesses is the slice size if num_accesses is None: diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 6ed583493e..24e84c33fa 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -57,6 +57,17 @@ def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): assert np.allclose(A[None, 1:5, ..., 0], res) +def test_ellipsis_and_newaxis_2(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + return A[None, 1:5, ..., None, 2] + + A = np.random.rand(5, 5, 5, 5, 5) + res = indexing_test(A) + assert np.allclose(A[None, 1:5, ..., None, 2], res) + + def test_aug_implicit(): @dace.program @@ -520,6 +531,7 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_flat_noncontiguous() test_ellipsis() test_ellipsis_and_newaxis() + test_ellipsis_and_newaxis_2() test_aug_implicit() test_ellipsis_aug() test_newaxis() From d46e85934970a02b4d7a9726602ba1d29b61b8c0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 30 Dec 2024 00:12:16 -0800 Subject: [PATCH 09/21] Fix test --- tests/numpy/advanced_indexing_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 24e84c33fa..b360a4c7fd 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -519,7 +519,7 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 # Advanced indexing dimensions should be prepended to the shape sdfg = indexing_test.to_sdfg() - assert tuple(sdfg.arrays['__return'].shape) == (1, 5, 3, 3, 3, N, N, 2, 1) + assert tuple(sdfg.arrays['__return'].shape) == (1, 5, 3, 3, 3, N, N, N, 2, 1) res = indexing_test(A, indices, indices2) From 13f7d5d4a51075cbb2403fc3433f316e8bed3723 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 30 Dec 2024 00:39:45 -0800 Subject: [PATCH 10/21] Add support for new axes in output shape computation, fix test --- dace/frontend/python/newast.py | 11 +++++++++++ tests/python_frontend/callee_autodetect_test.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c985fd7bd7..f121b1a2ae 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3439,6 +3439,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays, **self.defined}) expr: MemletExpr = ParseMemlet(self, defined_arrays, true_target, nslice) + + # TODO: Use _create_output_shape_from_advanced_indexing rng = expr.subset if isinstance(rng, subsets.Indices): rng = subsets.Range.from_indices(rng) @@ -5386,6 +5388,15 @@ def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletEx else: dim_position = advanced_dims[0] + # Add new axes + for new_axis in expr.new_axes: + if prefix_dims: + output_shape.insert(new_axis + 1, 1) + else: + output_shape.insert(new_axis, 1) + if new_axis <= dim_position: + dim_position += 1 + # Contract contiguous None dimensions that appear multiple times in a row output_shape = [ s for i, s in enumerate(output_shape) if s is not None or i == 0 or output_shape[i - 1] is not None diff --git a/tests/python_frontend/callee_autodetect_test.py b/tests/python_frontend/callee_autodetect_test.py index 6fb786982a..1d8528f9fc 100644 --- a/tests/python_frontend/callee_autodetect_test.py +++ b/tests/python_frontend/callee_autodetect_test.py @@ -325,7 +325,7 @@ def outer(a: dace.float64[20]): a_ref = A * 2 if decorated: - with pytest.raises(SyntaxError): + with pytest.raises(IndexError): outer(A) else: outer(A) From 214e7469db785689787b6c51e809c6d9c94b22ed Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 1 Jan 2025 20:32:40 -0800 Subject: [PATCH 11/21] Correctness leftovers --- dace/frontend/python/newast.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index f121b1a2ae..2df6981b15 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2496,7 +2496,7 @@ def visit_While(self, node: ast.While): if astr not in self.defined: raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols if not a scalar - if (astr not in self.sdfg.symbols and astr not in self.variables): + if (astr not in self.sdfg.symbols and astr not in self.variables and astr not in self.sdfg.arrays): self.sdfg.add_symbol(astr, atom.dtype) # Handle else clause @@ -3439,7 +3439,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays, **self.defined}) expr: MemletExpr = ParseMemlet(self, defined_arrays, true_target, nslice) - + # TODO: Use _create_output_shape_from_advanced_indexing rng = expr.subset if isinstance(rng, subsets.Indices): @@ -5255,7 +5255,8 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): # Obtain array/tuple node_parsed = self._gettype(node.value) - if len(node_parsed) > 1: + if len(node_parsed) > 1 or (len(node_parsed) == 1 and isinstance(node_parsed[0], tuple) + and node_parsed[0][1] in ('symbol', 'NumConstant')): # If the value is a tuple of constants (e.g., array.shape) and the # slice is constant, return the value itself nslice = self.visit(node.slice) From e44beeecf6dfb45ca2dc036bf056be9bb6af36b9 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 1 Jan 2025 20:33:45 -0800 Subject: [PATCH 12/21] Fully implement advanced indexing and reimplement `_array_indirection_subgraph` to use it for reading --- dace/frontend/python/newast.py | 265 ++++++++++++++++++++++----------- 1 file changed, 180 insertions(+), 85 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2df6981b15..8627bda6dd 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -5362,9 +5362,9 @@ def make_slice(self, arrname: str, rng: subsets.Range): rnode, wnode, Memlet.simple(array, rng, num_accesses=rng.num_elements(), other_subset_str=other_subset)) return tmp, other_subset - def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletExpr) -> List[symbolic.SymbolicType]: + def _compute_output_shape_from_advanced_indexing(self, aname: str, expr: MemletExpr) -> List[symbolic.SymbolicType]: """ - Creates the output shape of a slicing operation with advanced indexing. + Computes the output shape of a slicing operation with advanced indexing. :param aname: The name of the array being sliced. :param expr: The MemletExpr object representing the slicing operation. @@ -5406,7 +5406,7 @@ def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletEx # Broadcast all advanced indexing expressions together chunk_shape = None # Get the advanced indexing expressions - for i, arrname in expr.arrdims.items(): + for _, arrname in expr.arrdims.items(): if isinstance(arrname, str): # Array or constant if arrname in self.sdfg.arrays: desc = self.sdfg.arrays[arrname] @@ -5430,97 +5430,192 @@ def _create_output_shape_from_advanced_indexing(self, aname: str, expr: MemletEx return output_shape - def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) -> str: - aname = rnode.data - idesc = self.sdfg.arrays[aname] + def _create_memlets_from_advanced_indexing( + self, aname: str, expr: MemletExpr) -> Tuple[Dict[str, subsets.Range], Memlet, Memlet, List[Memlet]]: + """ + Creates the input memlets and index expression of a slicing operation with advanced indexing. + Returns four elements: a dictionary mapping an index name (e.g., ``__i0``) to the range to access, a memlet + representing the array access for reading, a memlet representing the array access for writing, and a list of + memlets representing any other input index arrays. - if expr.new_axes: - # NOTE: Matching behavior with numpy would be to append all new - # axes in the end - raise IndexError('New axes unsupported when array indices are used') + This method also creates new constant arrays for indexing arrays used in advanced indexing expressions. - # Create output shape dimensions based on the sizes of the arrays - output_shape = self._create_output_shape_from_advanced_indexing(aname, expr) - # Create constants for array indices - constant_indices: Dict[int, str] = {} - for i, arrname in expr.arrdims.items(): - if isinstance(arrname, str): # Array or constant - if arrname in self.sdfg.constants: - constant_indices[i] = arrname - elif arrname in self.sdfg.arrays: - pass + :param aname: The name of the array being sliced. + :param expr: The MemletExpr object representing the slicing operation. + :return: A tuple of (index mapping, input array memlet, output array memlet, input index memlets). + """ + # Compute input subset and index mapping for the basic indexing expressions (ranges and scalars) + ndrange = expr.subset.ndrange() + output_ndrange = [(dace.symbol(f'__i{i}'), dace.symbol(f'__i{i}'), 1) if rng[0] != rng[1] else (0, 0, 1) + for i, rng in enumerate(ndrange)] + # Create the input index subset by offsetting the output index subset + input_subset = subsets.Range([(rb + ind * rs, rb + ind * rs, 1) + for (rb, _, rs), (ind, _, _) in zip(ndrange, output_ndrange)]) + index_mapping = { + f'__i{i}': (0, s - 1, 1) + for i, (s, rng) in enumerate(zip(expr.subset.size(), ndrange)) if rng[0] != rng[1] + } + index_memlets: List[Memlet] = [] + + # Fast path: no advanced indexing + if not expr.arrdims: + for new_axis in reversed(expr.new_axes): + output_ndrange.insert(new_axis, (0, 0, 1)) + return ( + index_mapping, + Memlet(data=aname, subset=input_subset), + Memlet(data=aname, subset=subsets.Range(output_ndrange)), + index_memlets, + ) + + # The output shape is the shape of all contiguous advanced indexing arrays, after broadcasting with each other + + # Start with all basic indexing dimensions, setting advanced indexing dimensions to None + output_shape = [s if i not in expr.arrdims else None for i, s in enumerate(expr.subset.size())] + + # If any advanced indexing is found, mark any scalar dimension as advanced indices too + output_shape = [None if rng[0] == rng[1] else s for s, rng in zip(output_shape, expr.subset.ndrange())] + + # Create output subset based on the shape + output_ndrange = [None if output_shape[i] is None else rng for i, rng in enumerate(output_ndrange)] + + # Mark every dimension that starts with None as an advanced indexing "chunk" + advanced_dims = [ + i for i, s in enumerate(output_shape) if s is None and (i == 0 or output_shape[i - 1] is not None) + ] + # If there is more than one contiguous advanced indexing chunk, move all advanced indices to the beginning + prefix_dims = len(advanced_dims) > 1 + if prefix_dims: + output_shape = [None] + [s for s in output_shape if s is not None] + output_ndrange = [None] + [rng for rng in output_ndrange if rng is not None] + dim_position = 0 + else: + dim_position = advanced_dims[0] + + # Add new axes + for new_axis in reversed(expr.new_axes): + if prefix_dims: + output_shape.insert(new_axis + 1, 1) + output_ndrange.insert(new_axis + 1, (0, 0, 1)) + else: + output_shape.insert(new_axis, 1) + output_ndrange.insert(new_axis, (0, 0, 1)) + if new_axis <= dim_position: + dim_position += 1 + + # Contract contiguous None dimensions that appear multiple times in a row + output_shape = [ + s for i, s in enumerate(output_shape) if s is not None or i == 0 or output_shape[i - 1] is not None + ] + output_ndrange = [ + rng for i, rng in enumerate(output_ndrange) + if rng is not None or i == 0 or output_ndrange[i - 1] is not None + ] + + # Broadcast all advanced indexing expressions together + advidx_shape = None + out_idx = None + advidx_arrays = {} + # Get the advanced indexing expressions + for i, idxarrname in expr.arrdims.items(): + if isinstance(idxarrname, str): # Array or constant + if idxarrname in self.sdfg.arrays: + desc = self.sdfg.arrays[idxarrname] + elif idxarrname in self.sdfg.constants: + desc = self.sdfg.constants[idxarrname] else: - raise NameError(f'Array "{arrname}" used in indexing "{aname}" not found') + raise NameError(f'Array "{idxarrname}" used in indexing "{aname}" not found') + shape = desc.shape else: # Literal list or tuple, add as constant and use shape - arrname = [v if isinstance(v, Number) else self._parse_value(v) for v in arrname] - carr = numpy.array(arrname, dtype=dtypes.typeclass(int).type) + idxarr = [v if isinstance(v, Number) else self._parse_value(v) for v in idxarrname] + carr = numpy.array(idxarr, dtype=dtypes.typeclass(int).type) cname = self.sdfg.find_new_constant(f'__ind{i}_{aname}') self.sdfg.add_constant(cname, carr) - constant_indices[i] = cname - - # Check subset shapes for matching the array shapes - input_index = [] - i0 = symbolic.pystr_to_symbolic('__i0') - for i, elem in enumerate(expr.subset.size()): - if i in expr.arrdims: - input_index.append((0, elem - 1, 1)) - continue - if len(output_shape) > 1: - raise IndexError('Combining multidimensional array indices and ' - 'numeric subsets is unsupported (array ' - f'"{aname}").') - if (elem, ) != output_shape: - # TODO(later): Properly broadcast multiple (and missing) shapes - raise IndexError(f'Mismatch in array index shapes in access of ' - f'"{aname}": Subset {expr.subset[i]} ' - f'does not match existing shape {output_shape}') - - # Since there can only be one-dimensional outputs if arrays and - # subsets are both involved, express memlet as a function of _i0 - rb, _, rs = expr.subset[i] - input_index.append((rb + i0 * rs, rb + i0 * rs, 1)) + self.sdfg.arrays[cname] = self.sdfg.constants_prop[cname][0] + self.sdfg.arrays[cname].transient = True + idxarrname = cname + shape = carr.shape - outname, _ = self.sdfg.add_temp_transient(output_shape, idesc.dtype) + # Set the actual name of the advanced indexing array + advidx_arrays[i] = (idxarrname, shape) + + # Loop once to get the broadcasted shape + if advidx_shape is not None: + advidx_shape, _, out_idx, *_ = broadcast_together(shape, advidx_shape) + else: + advidx_shape = tuple(shape) + out_idx = ','.join([f'__i{i}' for i in range(len(shape))]) + + # Rename indices to avoid conflicts + out_idx = out_idx.replace('__i', '__ind') + advidx_index = [] + # Set the index mapping for the broadcasted array + for idx, s in zip(out_idx.split(','), advidx_shape): + index_mapping[idx.strip()] = (0, s - 1, 1) + symidx = symbolic.symbol(idx.strip()) + advidx_index.append((symidx, symidx, 1)) + + # Loop over the advanced indexing expressions again to create the input memlets + for i, (idxarrname, shape) in advidx_arrays.items(): + # Remove the original index dimension from index mapping + del index_mapping[f'__i{i}'] + + # NOTE: The indices can be multi-dimensional + _, _, out_idx, arr_idx, _ = broadcast_together(shape, advidx_shape) + + # Create the input memlet for this advanced indexing array based on broadcasting rules + arr_idx = arr_idx.replace('__i', '__ind').split(',') + arr_subset = subsets.Range([(symbolic.symbol(idx.strip()), symbolic.symbol(idx.strip()), 1) + for idx in arr_idx]) + index_memlets.append(Memlet(data=idxarrname, subset=arr_subset)) + + # Set the subset of the input/output array to be the entire array + input_subset[i] = ndrange[i] - # Make slice subgraph - input shape dimensions are len(expr.subset) and - # output shape dimensions are len(output_shape) + # Replace the advanced indexing dimensions with the broadcasted shape + output_shape = output_shape[:dim_position] + list(advidx_shape) + output_shape[dim_position + 1:] + output_ndrange = output_ndrange[:dim_position] + advidx_index + output_ndrange[dim_position + 1:] + + return ( + index_mapping, + Memlet(data=aname, subset=input_subset, volume=1), + Memlet(data=aname, subset=subsets.Range(output_ndrange), volume=1), + index_memlets, + ) - # Make map with output shape - state = self.current_state - wnode = state.add_write(outname) - maprange = [(f'__i{i}', f'0:{s}') for i, s in enumerate(output_shape)] - me, mx = state.add_map('indirect_slice', maprange, debuginfo=self.current_lineinfo) - - # Make indirection tasklet for array-index dimensions - array_indices = set(expr.arrdims.keys()) - set(constant_indices.keys()) - output_str = ', '.join(ind for ind, _ in maprange) - access_str = ', '.join( - [f'__inp{i}' if i in array_indices else f'{cname}[{output_str}]' for i in expr.arrdims.keys()]) - t = state.add_tasklet('indirection', {'__arr'} | set(f'__inp{i}' for i in array_indices), {'__out'}, - f'__out = __arr[{access_str}]') - - # Offset input memlet according to offset and stride if fixed, or - # entire array with volume 1 if array-index - input_subset = subsets.Range(input_index) - state.add_edge_pair(me, t, rnode, Memlet(data=aname, subset=input_subset, volume=1), internal_connector='__arr') - # Add array-index memlets - for dim in array_indices: - arrname = expr.arrdims[dim] - arrnode = state.add_read(arrname) - state.add_edge_pair(me, - t, - arrnode, - Memlet(data=arrname, subset=subsets.Range([(ind, ind, 1) for ind, _ in maprange])), - internal_connector=f'__inp{dim}') - - # Output matches the output shape exactly - output_index = subsets.Range([(ind, ind, 1) for ind, _ in maprange]) - state.add_edge_pair(mx, - t, - wnode, - Memlet(data=outname, subset=output_index), - external_memlet=Memlet(data=outname), - internal_connector='__out') + def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) -> str: + aname = rnode.data + idesc = self.sdfg.arrays[aname] + + # Create output shape dimensions based on the sizes of the arrays + output_shape = self._compute_output_shape_from_advanced_indexing(aname, expr) + index_mapping, input_memlet, output_memlet, index_memlets = self._create_memlets_from_advanced_indexing( + aname, expr) + + # Create an output array with the right shape + outname, _ = self.sdfg.add_temp_transient(output_shape, idesc.dtype) + output_memlet.data = outname + + # Make slice subgraph - a mapped tasklet with the proper dimensions + + # Compute index expression string + access_expr = [f'__inp{i}' for i in range(len(expr.arrdims))] + access_str = ', '.join(access_expr) + + # Make mapped tasklet with the proper dimensions + self.current_state.add_mapped_tasklet( + 'indirection', + index_mapping, + inputs={ + '__arr': input_memlet, + **{f'__inp{i}': m + for i, m in enumerate(index_memlets)} + }, + outputs={'__out': output_memlet}, + code=f'__out = __arr[{access_str}]', + external_edges=True, + debuginfo=self.current_lineinfo, + ) return outname From 022405ca5864ab38ae14c405587e0774fdf26478 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 1 Jan 2025 20:34:19 -0800 Subject: [PATCH 13/21] Fix tests and remove FIXME --- tests/numpy/advanced_indexing_test.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index b360a4c7fd..a2f47ab3ab 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -161,10 +161,9 @@ def indexing_test(A: dace.float64[20, 10, 30], indices: dace.int32[3]): return A[indices, 2:7:2, [15, 10, 1]] A = np.random.rand(20, 10, 30) - indices = [1, 10, 15] + indices = np.array([1, 10, 15], dtype=np.int32) res = indexing_test(A, indices) - # FIXME: NumPy behavior is unclear in this case - assert np.allclose(np.diag(A[indices, 2:7:2, [15, 10, 1]]), res) + assert np.allclose(A[indices, 2:7:2, [15, 10, 1]], res) def test_index_intarr_nd(): @@ -429,7 +428,7 @@ def indexing_test(A: dace.float64[N, M, N]): def test_advanced_index_broadcasting(): @dace.program - def indexing_test(A: dace.float64[N, N, N], indices: dace.int32[3, 3]): + def indexing_test(A: dace.float64[N, M, N], indices: dace.int32[3, 3]): return A[indices, (1, 2, 4), :] sdfg = indexing_test.to_sdfg() @@ -452,8 +451,8 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 n = 6 A = np.random.rand(n, n, n, n, n, n, n) - indices = np.random.randint(0, n, size=(3, 3)) - indices2 = np.random.randint(0, n, size=(3, 3, 3)) + indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) + indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) ref = A[:5, indices, indices2, ..., 1:3, 4] # Advanced indexing dimensions should be prepended to the shape @@ -473,8 +472,8 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 n = 6 A = np.random.rand(n, n, n, n, n, n, n) - indices = np.random.randint(0, n, size=(3, 3)) - indices2 = np.random.randint(0, n, size=(3, 3, 3)) + indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) + indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) ref = np.copy(A) A[:5, indices, indices2, ..., 1:3, 4] = 2 @@ -492,8 +491,8 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 n = 6 A = np.random.rand(n, n, n, n, n, n, n) - indices = np.random.randint(0, n, size=(3, 3)) - indices2 = np.random.randint(0, n, size=(3, 3, 3)) + indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) + indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) ref = A[None, :5, indices, indices2, ..., 1:3, 4, np.newaxis] # Advanced indexing dimensions should be prepended to the shape @@ -513,8 +512,8 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 n = 6 A = np.random.rand(n, n, n, n, n, n, n) - indices = np.random.randint(0, n, size=(3, 3)) - indices2 = np.random.randint(0, n, size=(3, 3, 3)) + indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) + indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) ref = A[None, :5, indices, indices2, ..., 1:3, np.newaxis] # Advanced indexing dimensions should be prepended to the shape From 8ee0db8d3109d661a8c7692a81a5181954b4dfcd Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 1 Jan 2025 20:37:46 -0800 Subject: [PATCH 14/21] Make tests more challenging --- tests/numpy/advanced_indexing_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index a2f47ab3ab..a4c438d189 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -487,13 +487,13 @@ def test_combining_basic_and_advanced_indexing_with_newaxes(): @dace.program def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): - return A[None, :5, indices, indices2, ..., 1:3, 4, np.newaxis] + return A[None, :5, indices, indices2, ..., 1:6:3, 4, np.newaxis] n = 6 A = np.random.rand(n, n, n, n, n, n, n) indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) - ref = A[None, :5, indices, indices2, ..., 1:3, 4, np.newaxis] + ref = A[None, :5, indices, indices2, ..., 1:6:3, 4, np.newaxis] # Advanced indexing dimensions should be prepended to the shape sdfg = indexing_test.to_sdfg() @@ -508,13 +508,13 @@ def test_combining_basic_and_advanced_indexing_with_newaxes_2(): @dace.program def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): - return A[None, :5, indices, indices2, ..., 1:3, np.newaxis] + return A[None, :5, indices, indices2, ..., 1:6:3, np.newaxis] n = 6 A = np.random.rand(n, n, n, n, n, n, n) indices = np.random.randint(0, n, size=(3, 3)).astype(np.int32) indices2 = np.random.randint(0, n, size=(3, 3, 3)).astype(np.int32) - ref = A[None, :5, indices, indices2, ..., 1:3, np.newaxis] + ref = A[None, :5, indices, indices2, ..., 1:6:3, np.newaxis] # Advanced indexing dimensions should be prepended to the shape sdfg = indexing_test.to_sdfg() From d361fe52c1887674e0e1ac9631b17322fb75fa0e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 12:35:24 -0800 Subject: [PATCH 15/21] Fix corner case in pass --- .../passes/simplification/prune_empty_conditional_branches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py index a492a9a65c..8783520b63 100644 --- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -55,7 +55,7 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]: region.remove_branch(branch) removed_branches += 1 # If the else branch remains, make sure it now has the new negate-all condition. - if new_else_cond is not None and region.branches[-1][0] is None: + if region.branches and new_else_cond is not None and region.branches[-1][0] is None: region._branches[-1] = (new_else_cond, region._branches[-1][1]) if len(region.branches) == 0: From 0e66d8ac40e31c5b9e9ebc1105b801ded205b106 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 12:52:44 -0800 Subject: [PATCH 16/21] Fix potential isolated node --- dace/frontend/python/newast.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2d3cd16fc5..62023f1d1b 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -5656,6 +5656,7 @@ def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) code=f'__out = __arr[{access_str}]', external_edges=True, debuginfo=self.current_lineinfo, + input_nodes={rnode.data: rnode}, ) return outname From 4d03e9999072a551b8669fba7bac9796bba54704 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 12:55:52 -0800 Subject: [PATCH 17/21] Fix bad validation warning, fix test --- dace/sdfg/validation.py | 2 +- tests/numpy/split_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 7ad8ff20e1..03ba55d5d4 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -257,7 +257,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Ensure that there is a mentioning of constants in either the array or symbol. for const_name, (const_type, _) in sdfg.constants_prop.items(): if const_name in sdfg.arrays: - if const_type != sdfg.arrays[const_name].dtype: + if const_type.dtype != sdfg.arrays[const_name].dtype: # This should actually be an error, but there is a lots of code that depends on it. warnings.warn(f'Mismatch between constant and data descriptor of "{const_name}", ' f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".') diff --git a/tests/numpy/split_test.py b/tests/numpy/split_test.py index 1fee72fc6d..ef783768b3 100644 --- a/tests/numpy/split_test.py +++ b/tests/numpy/split_test.py @@ -137,9 +137,9 @@ def tester(x, y, in_indices: dace.compiletime, out_index: dace.compiletime): y[:, o:o + 1] = factor * (-(x1 + x2) + (x0 + x1) - (x0 + x4) + (x3 + x4) + (x2 + x5) - (x3 + x5)) x = np.random.rand(1000, 8) - y = np.empty_like(x) + y = np.zeros_like(x) tester(x, y, (1, 2, 3, 4, 5, 7), 0) - ref = np.empty_like(y) + ref = np.zeros_like(y) tester.f(x, ref, (1, 2, 3, 4, 5, 7), 0) assert np.allclose(y, ref) From f7341be9e43e301106a3b630d93d25f95f092e25 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 12:56:12 -0800 Subject: [PATCH 18/21] Skip write advanced indexing tests --- tests/numpy/advanced_indexing_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index cac8ad7f5f..d7c3504067 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -429,6 +429,7 @@ def indexing_test(A: dace.float64[N, M, N]): indexing_test.to_sdfg() +@pytest.mark.skip("Combined basic and advanced indexing with writes is not supported") def test_multidim_tuple_multidim_index_write(): with pytest.raises(IndexError, match='could not be broadcast together'): @@ -477,7 +478,7 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 assert np.allclose(res, ref) - +@pytest.mark.skip("Combined basic and advanced indexing with writes is not supported") def test_combining_basic_and_advanced_indexing_write(): @dace.program @@ -571,9 +572,9 @@ def indexing_test(A: dace.float64[N, N, N, N, N, N, N], indices: dace.int32[3, 3 test_multidim_tuple_index(True) test_multidim_tuple_index_longer() test_multidim_tuple_multidim_index() - test_multidim_tuple_multidim_index_write() + # test_multidim_tuple_multidim_index_write() test_advanced_index_broadcasting() test_combining_basic_and_advanced_indexing() - test_combining_basic_and_advanced_indexing_write() + # test_combining_basic_and_advanced_indexing_write() test_combining_basic_and_advanced_indexing_with_newaxes() test_combining_basic_and_advanced_indexing_with_newaxes_2() From 32b0d57d3d5ebb1c14d8f570d306054d85d6c6ae Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 12:56:44 -0800 Subject: [PATCH 19/21] Add a type hint for `SDFG.constants_prop` --- dace/sdfg/sdfg.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 33e5b255a9..09b2325d1c 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -175,8 +175,7 @@ class InterstateEdge(object): loop iterates). """ - assignments = Property(dtype=dict, - desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')") + assignments = Property(dtype=dict, desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')") condition = CodeProperty(desc="Transition condition", default=CodeBlock("1")) guid = Property(dtype=str, allow_none=False) @@ -214,7 +213,7 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == 'guid': # Skip ID + if k == 'guid': # Skip ID continue setattr(result, k, copy.deepcopy(v, memo)) return result @@ -416,7 +415,11 @@ class SDFG(ControlFlowRegion): name = Property(dtype=str, desc="Name of the SDFG") arg_names = ListProperty(element_type=str, desc='Ordered argument names (used for calling conventions).') - constants_prop = Property(dtype=dict, default={}, desc="Compile-time constants") + constants_prop: Dict[str, Tuple[dt.Data, Any]] = Property( + dtype=dict, + default={}, + desc='Compile-time constants. The dictionary maps between a constant name to ' + 'a tuple of its type and the actual constant data.') _arrays = Property(dtype=NestedDict, desc="Data descriptors for this SDFG", to_json=_arrays_to_json, @@ -463,7 +466,8 @@ class SDFG(ControlFlowRegion): desc='Mapping between callback name and its original callback ' '(for when the same callback is used with a different signature)') - using_explicit_control_flow = Property(dtype=bool, default=False, + using_explicit_control_flow = Property(dtype=bool, + default=False, desc="Whether the SDFG contains explicit control flow constructs") def __init__(self, @@ -612,9 +616,7 @@ def from_json(cls, json_obj, context=None): ret = SDFG(name=attrs['name'], constants=constants_prop, parent=context['sdfg']) - dace.serialize.set_properties_from_json(ret, - json_obj, - ignore_properties={'constants_prop', 'name', 'hash'}) + dace.serialize.set_properties_from_json(ret, json_obj, ignore_properties={'constants_prop', 'name', 'hash'}) nodelist = [] for n in nodes: @@ -742,7 +744,6 @@ def replace_dict(self, if symrepl: symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)} - symrepl = symrepl or { symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v for k, v in repldict.items() @@ -2318,8 +2319,7 @@ def is_loaded(self) -> bool: dll = cs.ReloadableDLL(binary_filename, self.name) return dll.is_loaded() - def compile(self, output_file=None, validate=True, - return_program_handle=True) -> 'CompiledSDFG': + def compile(self, output_file=None, validate=True, return_program_handle=True) -> 'CompiledSDFG': """ Compiles a runnable binary from this SDFG. :param output_file: If not None, copies the output library file to From ea8d8d81a131e5abf1080ad410d306dd97622892 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Jan 2025 14:52:22 -0800 Subject: [PATCH 20/21] Fix array indirection promotion for multidimensional slices with offset dimensions --- dace/sdfg/utils.py | 4 +-- .../transformation/passes/scalar_to_symbol.py | 26 +++++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 8160b1de72..14c7728e61 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1609,9 +1609,9 @@ def traverse_sdfg_with_defined_symbols( :return: A generator that yields tuples of (state, node in state, currently-defined symbols) """ - # Start with global symbols + # Start with global symbols and scalar constants symbols = copy.copy(sdfg.symbols) - symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) + symbols.update({k: desc.dtype for k, (desc, _) in sdfg.constants_prop.items() if isinstance(desc, dt.Scalar)}) for desc in sdfg.arrays.values(): symbols.update({str(s): s.dtype for s in desc.free_symbols}) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 43cd45146d..e77833113e 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -331,6 +331,24 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle self.out_mapping: Dict[str, Tuple[str, subsets.Range]] = {} self.do_not_remove: Set[str] = set() + def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subset) -> subsets.Subset: + """ + Returns the requested range from a subscript node, which consists of the memlet subset composed with the + tasklet subset. + + :param node: The subscript node. + :param memlet_subset: The memlet subset. + :return: The requested range. + """ + arrname, tasklet_slice = astutils.subscript_to_ast_slice(node) + arrname = arrname if arrname in self.arrays else None + # Unsqueeze all index dimensions from orig_subset into tasklet_subset + for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): + if start == end: + tasklet_slice.insert(i, (None, None, None)) + tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname)) + return memlet_subset.compose(tasklet_subset) + def visit_Subscript(self, node: ast.Subscript) -> Any: # Convert subscript to symbol name node = self.generic_visit(node) @@ -339,8 +357,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: new_name = dt.find_new_name(node_name, self.connector_names) self.connector_names.add(new_name) - orig_subset = self.in_edges[node_name].subset - subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1])) + subset = self._get_requested_range(node, self.in_edges[node_name].subset) # Check if range can be collapsed if _range_is_promotable(subset, self.defined): self.in_mapping[new_name] = (node_name, subset) @@ -351,8 +368,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: new_name = dt.find_new_name(node_name, self.connector_names) self.connector_names.add(new_name) - orig_subset = self.out_edges[node_name].subset - subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1])) + subset = self._get_requested_range(node, self.out_edges[node_name].subset) # Check if range can be collapsed if _range_is_promotable(subset, self.defined): self.out_mapping[new_name] = (node_name, subset) @@ -750,4 +766,4 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: return to_promote or None def report(self, pass_retval: Set[str]) -> str: - return f'Promoted {len(pass_retval)} scalars to symbols.' + return f'Promoted {len(pass_retval)} scalars to symbols: {pass_retval}' From f442867712ae9ca9604fc9bea1e4dcb6ebfa6ff6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 11 Jan 2025 23:16:57 -0800 Subject: [PATCH 21/21] scal2sym: unsqueeze tasklet accesses only if necessary --- dace/transformation/passes/scalar_to_symbol.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index e77833113e..050cc9f44e 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -342,10 +342,11 @@ def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subse """ arrname, tasklet_slice = astutils.subscript_to_ast_slice(node) arrname = arrname if arrname in self.arrays else None - # Unsqueeze all index dimensions from orig_subset into tasklet_subset - for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): - if start == end: - tasklet_slice.insert(i, (None, None, None)) + if len(tasklet_slice) < len(memlet_subset): + # Unsqueeze all index dimensions from orig_subset into tasklet_subset + for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): + if start == end: + tasklet_slice.insert(i, (None, None, None)) tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname)) return memlet_subset.compose(tasklet_subset)