Skip to content

Commit

Permalink
Corrections to MapAt for Muti-dimensions...
Browse files Browse the repository at this point in the history
The following changes implement pretty much all of what Combinatorical
V0.09 needs. We have much better compliance to WMA MapAt[]
  • Loading branch information
rocky committed Dec 21, 2024
1 parent b052251 commit 58c2267
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 24 deletions.
10 changes: 9 additions & 1 deletion mathics/builtin/functional/apply_fns_to_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,17 @@ class MapAt(Builtin):
>> MapAt[f, <|"a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4|>, -3]
= {a -> 1, b -> f[2], c -> 3, d -> 4}
Use the operator form of MapAt:
Use the operator form of 'MapAt':
>> MapAt[f, 1][{a, b, c, d}]
= {f[a], b, c, d}
A vector position of a multi-dimensional array can be supplied:
>> MapAt[1&, {{0, 0}, {0, 0}}, {1, 1}]
= {{1, 0}, {0, 0}}
Lists of vector position of a multi-dimensional array can be supplied too:
>> MapAt[1&, {{0, 0}, {0, 0}}, {{1, 2}, {2, 1}}]
= {{0, 1}, {1, 0}}
"""

rules = {
Expand Down
137 changes: 114 additions & 23 deletions mathics/eval/functional/apply_fns_to_lists.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,148 @@
"""
Evaluation routines for mathics.builtin.functional.appy_fns_to_lists
"""
from mathics.core.atoms import Integer

from typing import Iterable, Optional, Union

from mathics.core.atoms import Integer, Integer1
from mathics.core.element import BaseElement
from mathics.core.evaluation import Evaluation
from mathics.core.exceptions import PartRangeError
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.symbols import Symbol
from mathics.core.symbols import SymbolTrue
from mathics.core.systemsymbols import SymbolMapAt, SymbolRule
from mathics.eval.testing_expressions import eval_ArrayQ


def eval_MapAt(f, expr, args, evaluation: Evaluation):
m = len(expr.elements)
new_elements = list(expr.elements)
def eval_MapAt(
f: BaseElement, expr: ListExpression, args, evaluation: Evaluation
) -> Optional[ListExpression]:
"""
evaluation routine for MapAt[]
"""

def map_at_replace_one(i: int):
def map_at_replace_one(elements: Iterable, index: ListExpression, i: int) -> list:
"""
Perform a single MapAt[] replacement for elements[i].
Global "f" is used to compute the replacement value,
and if there is an error, "expr" is used in the error message.
"""
m = len(elements)
if 1 <= i <= m:
j = i - 1
elif -m <= i <= -1:
j = m + i
else:
evaluation.message("MapAt", "partw", ListExpression(Integer(i)), expr)
evaluation.message("MapAt", "partw", index, expr)
raise PartRangeError
replace_element = new_elements[j]
if hasattr(replace_element, "head") and replace_element.head is Symbol(
"System`Rule"
):
new_elements = list(elements)
replace_element = elements[j]
if hasattr(replace_element, "head") and replace_element.head is SymbolRule:
new_elements[j] = Expression(
SymbolRule,
replace_element.elements[0],
Expression(f, replace_element.elements[1]),
)
else:
new_elements[j] = Expression(f, replace_element)
return new_elements

def map_at_replace_level(
elements: list,
remaining_indices: Union[tuple, Integer],
orig_index: ListExpression,
) -> list:
"""Recursive routine to replace remaining indices inside elements which is a portion at some level of
expr.elements.
``elements`` holds the ListExpression list for the portion of the
top-level ListExpression where we need to still index into.
Some part of the original ListExpression may have already been traversed.
``remaining_indices`` gives the list of indices we still have to index into,
these will be a suffix ``orig_index``.
``orig_index`` is used for error reporting.
"""
if isinstance(remaining_indices, Integer):
remaining_indices = (remaining_indices,)

i_expr = remaining_indices[0]

if not isinstance(i_expr, Integer):
evaluation.message(
"MapAt", "psl", args, Expression(SymbolMapAt, f, expr, args)
)
raise PartRangeError
i = i_expr.value
m = len(elements)
if 1 <= i <= m:
j = i - 1
elif -m <= i <= -1:
j = m + i
else:
evaluation.message("MapAt", "partw", orig_index, expr)
raise PartRangeError

next_level_elements = elements[j]
if len(remaining_indices) == 1:
if isinstance(next_level_elements, ListExpression):
# TODO: Check type of [0].value
new_list_expr = map_at_replace_one(
next_level_elements.elements, orig_index, remaining_indices[0].value
)
else:
new_list_expr = map_at_replace_one(
elements, orig_index, remaining_indices[0].value
)
return new_list_expr
elif not isinstance(next_level_elements, ListExpression):
# We have run out of nesting for indexing.
evaluation.message("MapAt", "partw", orig_index, expr)
raise PartRangeError

else:
# len(remaining_indices) > 1 and isinstance(next_level_elements, ListExpression)
elements[j] = ListExpression(
*map_at_replace_level(
list(next_level_elements.elements),
remaining_indices[1:],
orig_index,
)
)
return elements

try:
if isinstance(args, Integer):
map_at_replace_one(args.value)
return ListExpression(*new_elements)
elif isinstance(args, Expression):
new_list_expr = map_at_replace_one(
list(expr.elements), ListExpression(args), args.value
)
# Is args a vector?
elif eval_ArrayQ(args, Integer1, None, evaluation) is SymbolTrue:
new_list_expr = map_at_replace_level(
list(expr.elements), args.elements, args
)
# Until we can find what's causing Expression, SymbolConstant List from being a ListExpression,
# we will include Expression below...
elif isinstance(args, (ListExpression, Expression)):
new_list_expr = list(expr.elements)
for item in args.elements:
# Get value for arg in expr.elemnts
# Replace value
if (
isinstance(item, Expression)
and len(item.elements) == 1
and isinstance(item.elements[0], Integer)
):
map_at_replace_one(item.elements[0].value)
return ListExpression(*new_elements)
if isinstance(item, ListExpression):
new_list_expr = map_at_replace_level(
new_list_expr, item.elements, item
)
else:
new_list_expr = map_at_replace_level(
new_list_expr, item.elements, ListExpression(item)
)
return ListExpression(*new_list_expr)
else:
evaluation.message(
"MapAt", "psl", args, Expression(SymbolMapAt, f, expr, args)
)
raise PartRangeError
return ListExpression(*new_list_expr)
except PartRangeError:
# A message was issued where the error occurred
return

0 comments on commit 58c2267

Please sign in to comment.