diff --git a/ast.py b/ast.py index d76ed29..ed81d2f 100644 --- a/ast.py +++ b/ast.py @@ -24,7 +24,7 @@ def _fixed(self, node): del node.matched[i] # This flattens rules with a single matched rule. - if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw'): # The last condition fixes small matrices like [1], [1,2], and [1|2]. + if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw', 'opb'): # The last condition fixes small matrices like [1], [1,2], and [1|2]. return self._fixed(node.matched[0]) # This makes left-associative operations left-associative. @@ -39,19 +39,13 @@ def _fixed(self, node): if node.name == 'mui': return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]])) - # This flattens matrix rows into parent matrix rows. - if node.name == 'mrw': - for i in range(len(node.matched) - 1, -1, -1): - if node.matched[i].name == 'mrw': - node.matched[i:] = node.matched[i].matched - return self._fixed(node) - - # This flattens matrix bodies into parent matrix bodies. - if node.name == 'mbd': - for i in range(len(node.matched) - 1, -1, -1): - if node.matched[i].name == 'mbd': - node.matched[i:] = node.matched[i].matched - return self._fixed(node) + # This flattens nested nodes into their parents if their parents are of the same type. + for tpe in ('mrw', 'mbd', 'opb'): + if node.name == tpe: + for i in range(len(node.matched) - 1, -1, -1): + if node.matched[i].name == tpe: + node.matched[i:] = node.matched[i].matched + return self._fixed(node) if isinstance(node, RuleMatch): for i in range(len(node.matched)): diff --git a/common.py b/common.py index 7509805..79371b9 100644 --- a/common.py +++ b/common.py @@ -45,6 +45,7 @@ def _str(self, node, depth=0) -> str: (r'identity', 'OPR'), (r'trnsform', 'OPR'), (r'rref', 'OPR'), + (r'solve', 'OPR'), (r'[a-zA-Z_]+', 'IDT'), (r'=', 'EQL'), (r'\+', 'ADD'), @@ -67,7 +68,7 @@ def _str(self, node, depth=0) -> str: class ImmutableIndexedDict: def __init__(self, data): - self._keys = tuple(item[0].lstrip('^') for item in data) + self._keys = tuple(item[0] for item in data if not item[0].startswith('^')) self._data = {key.lstrip('^'): values for key, values in data} # Caching indices cuts down on runtime. @@ -100,7 +101,8 @@ def key_at(self, i): ('mui', ('pow mul',)), ('mul', ('pow MUL mul',)), ('pow', ('opr POW pow',)), - ('opr', ('OPR LPA add RPA',)), + ('opr', ('OPR LPA opb RPA',)), + ('^opb', ('add CMA opb', 'add')), ('neg', ('ADD num', 'ADD opr')), ('var', ('IDT',)), ('num', ('NUM', 'LPA add RPA')), diff --git a/rules.py b/rules.py index 6476ad7..704af4d 100644 --- a/rules.py +++ b/rules.py @@ -4,7 +4,19 @@ from typing import List from common import Token -from vartypes import VariableValue, NumberValue, MatrixRowValue, MatrixValue, Value +from vartypes import VariableValue, NumberValue, MatrixRowValue, MatrixValue, Value, OperatorBodyValue + + +def flatten(l): + lst = [] + + for elem in l: + if isinstance(elem, list): + lst.append(elem[0]) + else: + lst.append(elem) + + return lst def var(_, tokens: List[Token]) -> VariableValue: @@ -23,6 +35,10 @@ def mbd(values: List[Value], _) -> MatrixValue: return MatrixValue(values) +def opb(values: List[Value], _) -> OperatorBodyValue: + return OperatorBodyValue(values) + + def add(operands: List[Value], operator: Token) -> Value: return {'+': operands[0].add, '-': operands[0].sub}[operator.value](*operands[1:]) @@ -36,7 +52,8 @@ def pow(operands: List[Value], _) -> Value: def opr(operands: List[Value], operator: Token) -> Value: - return getattr(operands[0], operator.value)(*operands[1:]) + args = operands[0].value + return getattr(args[0], operator.value)(*args[1:]) def neg(operands: List[Value], operator: Token) -> Value: @@ -49,6 +66,7 @@ def neg(operands: List[Value], operator: Token) -> Value: 'num': num, 'mrw': mrw, 'mbd': mbd, + 'opb': opb, } # The mapping for all other rules. diff --git a/vartypes.py b/vartypes.py index 38d0c4e..07c77f6 100644 --- a/vartypes.py +++ b/vartypes.py @@ -36,11 +36,10 @@ def div(self, other): def mod(self, other): raise EvaluationException('Operation not defined for ' + self.type) - def pow(self, other): raise EvaluationException('Operation not defined for ' + self.type) - + def sqrt(self): raise EvaluationException('Operation not defined for ' + self.type) @@ -50,6 +49,9 @@ def exp(self): def identity(self): raise EvaluationException('Operation not defined for ' + self.type) + def solve(self, other): + raise EvaluationException('Operation not defined for ' + self.type) + class VariableValue(Value): def __init__(self, data): @@ -224,6 +226,9 @@ def rref(self): def trnsform(self): return MatrixValue(self._rref()[1]) + def solve(self, other): + return NumberValue(2) + class MatrixRowValue(Value): def __init__(self, data): @@ -234,3 +239,10 @@ def __init__(self, data): else: self.value = data + + +class OperatorBodyValue(Value): + def __init__(self, args): + super().__init__() + + self.value = args \ No newline at end of file