Fully implement NumPy advanced indexing for reads #1837

Merged: 24 commits from `advanced-indexing` into `main`, Jan 12, 2025
Commits (24):
6077b0d  Fix tuple indexing in Python <3.9 (tbennun, Dec 26, 2024)
6563077  Correctly compute output shape from advanced and advanced/basic index… (tbennun, Dec 26, 2024)
e680ef6  Add tests (tbennun, Dec 26, 2024)
b700d0c  Add more tests (tbennun, Dec 30, 2024)
e9f0a3a  Even more tests (tbennun, Dec 30, 2024)
4ef1d55  Make broadcast_together/broadcast_to public functions (tbennun, Dec 30, 2024)
2e4eae3  New issue uncovered by new test (tbennun, Dec 30, 2024)
88cdf89  Fix bug in memlet parsing when both ellipsis and newaxis are used (tbennun, Dec 30, 2024)
d46e859  Fix test (tbennun, Dec 30, 2024)
13f7d5d  Add support for new axes in output shape computation, fix test (tbennun, Dec 30, 2024)
a6c43f3  Merge branch 'main' into advanced-indexing (tbennun, Dec 31, 2024)
214e746  Correctness leftovers (tbennun, Jan 2, 2025)
e44beee  Fully implement advanced indexing and reimplement `_array_indirection… (tbennun, Jan 2, 2025)
022405c  Fix tests and remove FIXME (tbennun, Jan 2, 2025)
8ee0db8  Make tests more challenging (tbennun, Jan 2, 2025)
a97eec9  Merge branch 'main' into advanced-indexing (tbennun, Jan 7, 2025)
83361b0  Merge remote-tracking branch 'origin/main' into advanced-indexing (tbennun, Jan 10, 2025)
d361fe5  Fix corner case in pass (tbennun, Jan 10, 2025)
0e66d8a  Fix potential isolated node (tbennun, Jan 10, 2025)
4d03e99  Fix bad validation warning, fix test (tbennun, Jan 10, 2025)
f7341be  Skip write advanced indexing tests (tbennun, Jan 10, 2025)
32b0d57  Add a type hint for `SDFG.constants_prop` (tbennun, Jan 10, 2025)
ea8d8d8  Fix array indirection promotion for multidimensional slices with offs… (tbennun, Jan 10, 2025)
f442867  scal2sym: unsqueeze tasklet accesses only if necessary (tbennun, Jan 12, 2025)
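For reference, the NumPy advanced-indexing semantics that this PR implements for reads, sketched in plain NumPy (not DaCe code) to show how output shapes are computed:

```python
import numpy as np

A = np.arange(60).reshape(3, 4, 5)
i = np.array([0, 2])
j = np.array([1, 3])

# Integer-array (advanced) indices broadcast together; the result starts
# with their broadcast shape, followed by any remaining sliced dimensions.
assert A[i, j].shape == (2, 5)     # adjacent advanced indices
assert A[:, i, j].shape == (3, 2)  # adjacent advanced indices after a slice
assert A[i, :, j].shape == (2, 4)  # separated by a slice: the broadcast
                                   # dimensions move to the front
```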
19 changes: 12 additions & 7 deletions dace/frontend/python/memlet_parser.py
@@ -15,7 +15,6 @@
 
 MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name]
 
-
 if sys.version_info < (3, 8):
     _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
     BytesConstant = ast.Bytes
@@ -107,6 +106,11 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
     idx = 0
     new_idx = 0
     has_ellipsis = False
+
+    # Count new axes
+    num_new_axes = sum(1 for dim in ast_ndslice
+                       if (dim is None or (isinstance(dim, (ast.Constant, NameConstant)) and dim.value is None)))
+
     for dim in ast_ndslice:
         if isinstance(dim, (str, list, slice)):
             dim = ast.Name(id=dim)
@@ -136,7 +140,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
             if has_ellipsis:
                 raise IndexError('an index can only have a single ellipsis ("...")')
             has_ellipsis = True
-            remaining_dims = len(ast_ndslice) - idx - 1
+            remaining_dims = len(ast_ndslice) - num_new_axes - idx - 1
             for j in range(idx, len(ndslice) - remaining_dims):
                 ndslice[j] = (0, array.shape[j] - 1, 1)
             idx += 1
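The subtraction of `num_new_axes` matches NumPy, where `np.newaxis` entries consume no input dimension and therefore must not count toward the dimensions an ellipsis expands over. A plain-NumPy illustration:

```python
import numpy as np

a = np.ones((3, 4, 5))
# '...' expands to the input dimensions not consumed by other indices;
# np.newaxis (None) consumes none, so it is excluded from the count.
assert a[..., 0].shape == (3, 4)           # '...' covers dims 0 and 1
assert a[..., None].shape == (3, 4, 5, 1)  # '...' still covers all 3 dims
assert a[None, ..., 0].shape == (1, 3, 4)  # newaxis, ellipsis, then index
```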
@@ -170,7 +174,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
             if desc.dtype == dtypes.bool:
                 # Boolean array indexing
                 if len(ast_ndslice) > 1:
-                    raise IndexError(f'Invalid indexing into array "{dim.id}". ' 'Only one boolean array is allowed.')
+                    raise IndexError(f'Invalid indexing into array "{dim.id}". Only one boolean array is allowed.')
                 if tuple(desc.shape) != tuple(array.shape):
                     raise IndexError(f'Invalid indexing into array "{dim.id}". '
                                      'Shape of boolean index must match original array.')
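As the error messages above require, boolean array indexing is limited to a single mask whose shape matches the indexed array; in NumPy terms:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
mask = a > 2                   # boolean mask with the same shape as `a`
assert mask.shape == a.shape
assert np.array_equal(a[mask], [3, 4, 5])  # result is flattened to 1-D
```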
@@ -251,9 +255,9 @@ def parse_memlet_subset(array: data.Data,
         # Loop over the N dimensions
         ndslice, offsets, new_extra_dims, arrdims = _fill_missing_slices(das, ast_ndslice, narray, offsets)
         if new_extra_dims and idx != (len(ast_ndslices) - 1):
-            raise NotImplementedError('New axes only implemented for last ' 'slice')
+            raise NotImplementedError('New axes only implemented for last slice')
         if arrdims and len(ast_ndslices) != 1:
-            raise NotImplementedError('Array dimensions not implemented ' 'for consecutive subscripts')
+            raise NotImplementedError('Array dimensions not implemented for consecutive subscripts')
         extra_dims = new_extra_dims
         subset_array.append(_ndslice_to_subset(ndslice))
 
@@ -305,8 +309,9 @@ def ParseMemlet(visitor,
     try:
         subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
     except IndexError:
-        raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. '
-                              f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')
+        raise DaceSyntaxError(
+            visitor, node, 'Failed to parse memlet expression due to dimensionality. '
+            f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')
 
     # If undefined, default number of accesses is the slice size
     if num_accesses is None:
370 changes: 261 additions & 109 deletions dace/frontend/python/newast.py

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions dace/frontend/python/replacements.py
@@ -664,7 +664,7 @@ def _linspace(pv: ProgramVisitor,
     start_shape = sdfg.arrays[start].shape if (isinstance(start, str) and start in sdfg.arrays) else []
     stop_shape = sdfg.arrays[stop].shape if (isinstance(stop, str) and stop in sdfg.arrays) else []
 
-    shape, ranges, outind, ind1, ind2 = _broadcast_together(start_shape, stop_shape)
+    shape, ranges, outind, ind1, ind2 = broadcast_together(start_shape, stop_shape)
     shape_with_axis = _add_axis_to_shape(shape, axis, num)
     ranges_with_axis = _add_axis_to_shape(ranges, axis, ('__sind', f'0:{symbolic.symstr(num)}'))
     if outind:
@@ -1325,10 +1325,10 @@ def _array_array_where(visitor: ProgramVisitor,
     right_shape = right_arr.shape if right_arr else [1]
     cond_shape = cond_arr.shape if cond_arr else [1]
 
-    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
+    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)
 
     # Broadcast condition with broadcasted left+right
-    _, _, _, cond_idx, _ = _broadcast_together(cond_shape, out_shape)
+    _, _, _, cond_idx, _ = broadcast_together(cond_shape, out_shape)
 
     # Fix for Scalars
     if isinstance(left_arr, data.Scalar):
@@ -1464,18 +1464,18 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
     return name
 
 
-def _broadcast_to(target_shape, operand_shape):
+def broadcast_to(target_shape, operand_shape):
     # the difference to normal broadcasting is that the broadcasted shape is the same as the target
     # I was unable to find documentation for this in numpy, so we follow the description from ONNX
-    results = _broadcast_together(target_shape, operand_shape, unidirectional=True)
+    results = broadcast_together(target_shape, operand_shape, unidirectional=True)
 
     # the output_shape should be equal to the target_shape
     assert all(i == o for i, o in zip(target_shape, results[0]))
 
     return results
 
 
-def _broadcast_together(arr1_shape, arr2_shape, unidirectional=False):
+def broadcast_together(arr1_shape, arr2_shape, unidirectional=False):
 
     all_idx_dict, all_idx, a1_idx, a2_idx = {}, [], [], []
 
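The unidirectional mode referenced in the comments mirrors `np.broadcast_to`: only the operand may stretch its size-1 dimensions, while the target shape stays fixed; standard broadcasting is two-sided. A plain-NumPy comparison:

```python
import numpy as np

# Standard (bidirectional) broadcasting: both shapes may stretch.
assert np.broadcast_shapes((3, 1), (1, 4)) == (3, 4)

# Unidirectional broadcasting: the operand stretches to the target shape,
# which never changes.
assert np.broadcast_to(np.ones((3, 1)), (3, 4)).shape == (3, 4)
# np.broadcast_to(np.ones((2, 1)), (3, 4)) raises ValueError, since
# (2, 1) cannot be stretched into (3, 4) without altering the target.
```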
@@ -1523,9 +1523,9 @@ def get_idx(i):
             all_idx_dict[get_idx(i)] = dim1
         else:
             if unidirectional:
-                raise SyntaxError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}")
+                raise IndexError(f"could not broadcast input array from shape {arr2_shape} into shape {arr1_shape}")
             else:
-                raise SyntaxError("operands could not be broadcast together with shapes {}, {}".format(
+                raise IndexError("operands could not be broadcast together with shapes {}, {}".format(
                     arr1_shape, arr2_shape))
 
     def to_string(idx):
@@ -1543,7 +1543,7 @@ def _binop(sdfg: SDFG, state: SDFGState, op1: str, op2: str, opcode: str, opname
     arr1 = sdfg.arrays[op1]
     arr2 = sdfg.arrays[op2]
 
-    out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = _broadcast_together(arr1.shape, arr2.shape)
+    out_shape, all_idx_tup, all_idx, arr1_idx, arr2_idx = broadcast_together(arr1.shape, arr2.shape)
 
     name, _ = sdfg.add_temp_transient(out_shape, restype, arr1.storage)
     state.add_mapped_tasklet("_%s_" % opname,
@@ -1928,7 +1928,7 @@ def _array_array_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
     left_shape = left_arr.shape
     right_shape = right_arr.shape
 
-    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
+    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)
 
     # Fix for Scalars
     if isinstance(left_arr, data.Scalar):
@@ -1996,7 +1996,7 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
     if right_cast is not None:
         tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1])
 
-    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
+    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)
 
     out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage)
 
@@ -2066,7 +2066,7 @@ def _array_sym_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, left
     if right_cast is not None:
         tasklet_args[1] = "{c}({o})".format(c=str(right_cast).replace('::', '.'), o=tasklet_args[1])
 
-    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = _broadcast_together(left_shape, right_shape)
+    (out_shape, all_idx_dict, out_idx, left_idx, right_idx) = broadcast_together(left_shape, right_shape)
 
     out_operand, out_arr = sdfg.add_temp_transient(out_shape, result_type, storage)
 
22 changes: 11 additions & 11 deletions dace/sdfg/sdfg.py
@@ -175,8 +175,7 @@ class InterstateEdge(object):
         loop iterates).
     """
 
-    assignments = Property(dtype=dict,
-                           desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
+    assignments = Property(dtype=dict, desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
     condition = CodeProperty(desc="Transition condition", default=CodeBlock("1"))
     guid = Property(dtype=str, allow_none=False)
 
@@ -214,7 +213,7 @@ def __deepcopy__(self, memo):
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k == 'guid': # Skip ID
+            if k == 'guid':  # Skip ID
                 continue
             setattr(result, k, copy.deepcopy(v, memo))
         return result
@@ -416,7 +415,11 @@ class SDFG(ControlFlowRegion):
 
     name = Property(dtype=str, desc="Name of the SDFG")
     arg_names = ListProperty(element_type=str, desc='Ordered argument names (used for calling conventions).')
-    constants_prop = Property(dtype=dict, default={}, desc="Compile-time constants")
+    constants_prop: Dict[str, Tuple[dt.Data, Any]] = Property(
+        dtype=dict,
+        default={},
+        desc='Compile-time constants. The dictionary maps between a constant name to '
+        'a tuple of its type and the actual constant data.')
     _arrays = Property(dtype=NestedDict,
                        desc="Data descriptors for this SDFG",
                        to_json=_arrays_to_json,
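With the new type hint, the constants mapping can be inspected as below; a minimal sketch, assuming constants are registered through `SDFG.add_constant`:

```python
import dace
import numpy as np

sdfg = dace.SDFG('constants_example')
sdfg.add_constant('alpha', 0.5)
sdfg.add_constant('kernel', np.ones((3, 3), dtype=np.float32))

# constants_prop maps each constant name to (data descriptor, value).
for name, (desc, value) in sdfg.constants_prop.items():
    print(name, desc.dtype, value)
```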
@@ -463,7 +466,8 @@ class SDFG(ControlFlowRegion):
                                desc='Mapping between callback name and its original callback '
                                '(for when the same callback is used with a different signature)')
 
-    using_explicit_control_flow = Property(dtype=bool, default=False,
+    using_explicit_control_flow = Property(dtype=bool,
+                                           default=False,
                                            desc="Whether the SDFG contains explicit control flow constructs")
 
     def __init__(self,
@@ -612,9 +616,7 @@ def from_json(cls, json_obj, context=None):
 
         ret = SDFG(name=attrs['name'], constants=constants_prop, parent=context['sdfg'])
 
-        dace.serialize.set_properties_from_json(ret,
-                                                json_obj,
-                                                ignore_properties={'constants_prop', 'name', 'hash'})
+        dace.serialize.set_properties_from_json(ret, json_obj, ignore_properties={'constants_prop', 'name', 'hash'})
 
         nodelist = []
         for n in nodes:
@@ -742,7 +744,6 @@ def replace_dict(self,
         if symrepl:
             symrepl = {k: v for k, v in symrepl.items() if str(k) != str(v)}
 
-
         symrepl = symrepl or {
             symbolic.pystr_to_symbolic(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v
             for k, v in repldict.items()
@@ -2318,8 +2319,7 @@ def is_loaded(self) -> bool:
         dll = cs.ReloadableDLL(binary_filename, self.name)
         return dll.is_loaded()
 
-    def compile(self, output_file=None, validate=True,
-                return_program_handle=True) -> 'CompiledSDFG':
+    def compile(self, output_file=None, validate=True, return_program_handle=True) -> 'CompiledSDFG':
         """ Compiles a runnable binary from this SDFG.
 
             :param output_file: If not None, copies the output library file to
4 changes: 2 additions & 2 deletions dace/sdfg/utils.py
@@ -1609,9 +1609,9 @@ def traverse_sdfg_with_defined_symbols(
 
     :return: A generator that yields tuples of (state, node in state, currently-defined symbols)
     """
-    # Start with global symbols
+    # Start with global symbols and scalar constants
     symbols = copy.copy(sdfg.symbols)
-    symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()})
+    symbols.update({k: desc.dtype for k, (desc, _) in sdfg.constants_prop.items() if isinstance(desc, dt.Scalar)})
     for desc in sdfg.arrays.values():
        symbols.update({str(s): s.dtype for s in desc.free_symbols})
 
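A usage sketch for the traversal, assuming the generator signature described in the docstring above; the tiny program is illustrative only:

```python
import dace
from dace.sdfg import utils as sdutil

@dace.program
def example(A: dace.float64[10]):
    A[:] = A + 1

sdfg = example.to_sdfg()
sdfg.add_constant('c', 2.0)  # a scalar constant, now visible as a symbol

# Each yielded tuple pairs a state and node with the symbols defined there.
for state, node, defined in sdutil.traverse_sdfg_with_defined_symbols(sdfg):
    print(type(node).__name__, sorted(map(str, defined)))
```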
2 changes: 1 addition & 1 deletion dace/sdfg/validation.py
@@ -257,7 +257,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
     # Ensure that there is a mentioning of constants in either the array or symbol.
     for const_name, (const_type, _) in sdfg.constants_prop.items():
         if const_name in sdfg.arrays:
-            if const_type != sdfg.arrays[const_name].dtype:
+            if const_type.dtype != sdfg.arrays[const_name].dtype:
                 # This should actually be an error, but there is a lots of code that depends on it.
                 warnings.warn(f'Mismatch between constant and data descriptor of "{const_name}", '
                               f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".')
27 changes: 22 additions & 5 deletions dace/transformation/passes/scalar_to_symbol.py
@@ -331,6 +331,25 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle
         self.out_mapping: Dict[str, Tuple[str, subsets.Range]] = {}
         self.do_not_remove: Set[str] = set()
 
+    def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subset) -> subsets.Subset:
+        """
+        Returns the requested range from a subscript node, which consists of the memlet subset composed with the
+        tasklet subset.
+
+        :param node: The subscript node.
+        :param memlet_subset: The memlet subset.
+        :return: The requested range.
+        """
+        arrname, tasklet_slice = astutils.subscript_to_ast_slice(node)
+        arrname = arrname if arrname in self.arrays else None
+        if len(tasklet_slice) < len(memlet_subset):
+            # Unsqueeze all index dimensions from orig_subset into tasklet_subset
+            for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))):
+                if start == end:
+                    tasklet_slice.insert(i, (None, None, None))
+        tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname))
+        return memlet_subset.compose(tasklet_subset)
+
     def visit_Subscript(self, node: ast.Subscript) -> Any:
         # Convert subscript to symbol name
         node = self.generic_visit(node)
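The composition and unsqueezing above follow ordinary view semantics, sketched here as a plain-NumPy analogy:

```python
import numpy as np

A = np.arange(24).reshape(2, 3, 4)

# Composing subsets: indexing a view indexes the original array through
# both subsets, e.g. A[0:2][1] is A[1].
assert np.array_equal(A[0:2][1], A[1])

# Unsqueezing: if the memlet subset fixes a dimension (start == end), the
# tasklet slice has fewer dimensions; re-inserting the fixed index aligns
# the two subsets before composing. A[1] viewed at [2, 3] is A[1, 2, 3].
view = A[1]                    # memlet subset fixes dimension 0
assert view[2, 3] == A[1, 2, 3]
```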
@@ -339,8 +358,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
             new_name = dt.find_new_name(node_name, self.connector_names)
             self.connector_names.add(new_name)
 
-            orig_subset = self.in_edges[node_name].subset
-            subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
+            subset = self._get_requested_range(node, self.in_edges[node_name].subset)
             # Check if range can be collapsed
             if _range_is_promotable(subset, self.defined):
                 self.in_mapping[new_name] = (node_name, subset)
@@ -351,8 +369,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
             new_name = dt.find_new_name(node_name, self.connector_names)
             self.connector_names.add(new_name)
 
-            orig_subset = self.out_edges[node_name].subset
-            subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
+            subset = self._get_requested_range(node, self.out_edges[node_name].subset)
             # Check if range can be collapsed
             if _range_is_promotable(subset, self.defined):
                 self.out_mapping[new_name] = (node_name, subset)
@@ -750,4 +767,4 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]:
         return to_promote or None
 
     def report(self, pass_retval: Set[str]) -> str:
-        return f'Promoted {len(pass_retval)} scalars to symbols.'
+        return f'Promoted {len(pass_retval)} scalars to symbols: {pass_retval}'
(additional file changed; file name not rendered in this view)
@@ -55,7 +55,7 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]:
                 region.remove_branch(branch)
                 removed_branches += 1
         # If the else branch remains, make sure it now has the new negate-all condition.
-        if new_else_cond is not None and region.branches[-1][0] is None:
+        if region.branches and new_else_cond is not None and region.branches[-1][0] is None:
             region._branches[-1] = (new_else_cond, region._branches[-1][1])
 
         if len(region.branches) == 0: