Skip to content

Commit

Permalink
Added support for functions to return TupleTypes, which should be unp…
Browse files Browse the repository at this point in the history
…acked during variable assignment.
  • Loading branch information
nrubin29 committed Dec 24, 2017
1 parent 27d7f9d commit b5f86bf
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
15 changes: 11 additions & 4 deletions ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from common import RuleMatch, remove, left_assoc, Token
from rules import rule_value_map, rule_value_operation_map
from vartypes import TupleValue


class Ast:
Expand All @@ -24,7 +25,7 @@ def _fixed(self, node):
del node.matched[i]

# This flattens rules with a single matched rule.
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw', 'opb'): # The last condition fixes small matrices like [1], [1,2], and [1|2].
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw', 'opb', 'asb'): # The last condition fixes small matrices like [1], [1,2], and [1|2].
return self._fixed(node.matched[0])

# This makes left-associative operations left-associative.
Expand All @@ -40,7 +41,7 @@ def _fixed(self, node):
return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]]))

# This flattens nested nodes into their parents if their parents are of the same type.
for tpe in ('mrw', 'mbd', 'opb'):
for tpe in ('mrw', 'mbd', 'opb', 'asb'):
if node.name == tpe:
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == tpe:
Expand All @@ -58,7 +59,7 @@ def evaluate(self, vrs: Dict[str, RuleMatch]):

def _evaluate(self, node, vrs: Dict[str, RuleMatch]):
if node.name == 'asn':
return {node.matched[0].value: node.matched[1]}
return {idt.value: (i, node.matched[1]) for i, idt in enumerate(node.matched[0].matched)}

for token in node.matched:
if isinstance(token, RuleMatch) and not token.value:
Expand All @@ -68,7 +69,13 @@ def _evaluate(self, node, vrs: Dict[str, RuleMatch]):
tokens = [token for token in node.matched if not isinstance(token, RuleMatch)]

if node.matched[0].name == 'IDT':
return self._evaluate(copy.deepcopy(vrs[node.matched[0].value]), vrs)
i, rule = vrs[node.matched[0].value]
result = self._evaluate(copy.deepcopy(rule), vrs)

if isinstance(result, TupleValue):
return result.value[i]

return result

elif node.name in rule_value_map:
return rule_value_map[node.name](values, tokens)
Expand Down
2 changes: 1 addition & 1 deletion calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _match(self, tokens: List[Token], target_rule: str):
return RuleMatch(target_rule, matched), remaining_tokens

idx = rules_map.index(target_rule)
if idx + 1 < len(rules_map):
if idx is not None and idx + 1 < len(rules_map):
return self._match(tokens, rules_map.key_at(idx + 1))

return None, None
5 changes: 3 additions & 2 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ def __len__(self):
return self._len

def index(self, key):
return self._key_indices[key]
return self._key_indices.get(key, None)

def key_at(self, i):
return self._keys[i]


rules_map = ImmutableIndexedDict((
('asn', ('IDT EQL add',)),
('asn', ('asb EQL add',)),
('^asb', ('IDT CMA asb', 'IDT')),
('add', ('mul ADD add', 'mui ADD add',)),
('mui', ('pow mul',)),
('mul', ('pow MUL mul',)),
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
calc = Calculator()

if len(sys.argv) > 1:
print(calc.evaluate(' '.join(sys.argv[1:])))
for line in ' '.join(sys.argv[1:]).split(';'):
print(calc.evaluate(line))

else:
while True:
Expand Down
2 changes: 1 addition & 1 deletion vartypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def qr(self):
for row in range(len(Q)):
Q[row][j] = val[row]

return TupleValue([Q, R])
return TupleValue([MatrixValue(Q), MatrixValue(R)])


class MatrixRowValue(Value):
Expand Down

0 comments on commit b5f86bf

Please sign in to comment.