Skip to content

Commit

Permalink
Added support for multiple arguments to a function. The function is c…
Browse files Browse the repository at this point in the history
…alled on the first argument with all of the other arguments supplied as parameters.
  • Loading branch information
nrubin29 committed Nov 25, 2017
1 parent 31577de commit b3f637c
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 20 deletions.
22 changes: 8 additions & 14 deletions ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _fixed(self, node):
del node.matched[i]

# This flattens rules with a single matched rule.
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw'): # The last condition fixes small matrices like [1], [1,2], and [1|2].
if len(node.matched) is 1 and isinstance(node.matched[0], RuleMatch) and node.name not in ('mbd', 'mrw', 'opb'): # The last condition fixes small matrices like [1], [1,2], and [1|2].
return self._fixed(node.matched[0])

# This makes left-associative operations left-associative.
Expand All @@ -39,19 +39,13 @@ def _fixed(self, node):
if node.name == 'mui':
return self._fixed(RuleMatch('mul', [node.matched[0], Token('MUL', '*'), node.matched[1]]))

# This flattens matrix rows into parent matrix rows.
if node.name == 'mrw':
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == 'mrw':
node.matched[i:] = node.matched[i].matched
return self._fixed(node)

# This flattens matrix bodies into parent matrix bodies.
if node.name == 'mbd':
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == 'mbd':
node.matched[i:] = node.matched[i].matched
return self._fixed(node)
# This flattens nested nodes into their parents if their parents are of the same type.
for tpe in ('mrw', 'mbd', 'opb'):
if node.name == tpe:
for i in range(len(node.matched) - 1, -1, -1):
if node.matched[i].name == tpe:
node.matched[i:] = node.matched[i].matched
return self._fixed(node)

if isinstance(node, RuleMatch):
for i in range(len(node.matched)):
Expand Down
6 changes: 4 additions & 2 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _str(self, node, depth=0) -> str:
(r'identity', 'OPR'),
(r'trnsform', 'OPR'),
(r'rref', 'OPR'),
(r'solve', 'OPR'),
(r'[a-zA-Z_]+', 'IDT'),
(r'=', 'EQL'),
(r'\+', 'ADD'),
Expand All @@ -67,7 +68,7 @@ def _str(self, node, depth=0) -> str:

class ImmutableIndexedDict:
def __init__(self, data):
self._keys = tuple(item[0].lstrip('^') for item in data)
self._keys = tuple(item[0] for item in data if not item[0].startswith('^'))
self._data = {key.lstrip('^'): values for key, values in data}

# Caching indices cuts down on runtime.
Expand Down Expand Up @@ -100,7 +101,8 @@ def key_at(self, i):
('mui', ('pow mul',)),
('mul', ('pow MUL mul',)),
('pow', ('opr POW pow',)),
('opr', ('OPR LPA add RPA',)),
('opr', ('OPR LPA opb RPA',)),
('^opb', ('add CMA opb', 'add')),
('neg', ('ADD num', 'ADD opr')),
('var', ('IDT',)),
('num', ('NUM', 'LPA add RPA')),
Expand Down
22 changes: 20 additions & 2 deletions rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
from typing import List

from common import Token
from vartypes import VariableValue, NumberValue, MatrixRowValue, MatrixValue, Value
from vartypes import VariableValue, NumberValue, MatrixRowValue, MatrixValue, Value, OperatorBodyValue


def flatten(l):
lst = []

for elem in l:
if isinstance(elem, list):
lst.append(elem[0])
else:
lst.append(elem)

return lst


def var(_, tokens: List[Token]) -> VariableValue:
Expand All @@ -23,6 +35,10 @@ def mbd(values: List[Value], _) -> MatrixValue:
return MatrixValue(values)


def opb(values: List[Value], _) -> OperatorBodyValue:
return OperatorBodyValue(values)


def add(operands: List[Value], operator: Token) -> Value:
return {'+': operands[0].add, '-': operands[0].sub}[operator.value](*operands[1:])

Expand All @@ -36,7 +52,8 @@ def pow(operands: List[Value], _) -> Value:


def opr(operands: List[Value], operator: Token) -> Value:
return getattr(operands[0], operator.value)(*operands[1:])
args = operands[0].value
return getattr(args[0], operator.value)(*args[1:])


def neg(operands: List[Value], operator: Token) -> Value:
Expand All @@ -49,6 +66,7 @@ def neg(operands: List[Value], operator: Token) -> Value:
'num': num,
'mrw': mrw,
'mbd': mbd,
'opb': opb,
}

# The mapping for all other rules.
Expand Down
16 changes: 14 additions & 2 deletions vartypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def div(self, other):

def mod(self, other):
raise EvaluationException('Operation not defined for ' + self.type)


def pow(self, other):
raise EvaluationException('Operation not defined for ' + self.type)

def sqrt(self):
raise EvaluationException('Operation not defined for ' + self.type)

Expand All @@ -50,6 +49,9 @@ def exp(self):
def identity(self):
raise EvaluationException('Operation not defined for ' + self.type)

def solve(self, other):
raise EvaluationException('Operation not defined for ' + self.type)


class VariableValue(Value):
def __init__(self, data):
Expand Down Expand Up @@ -224,6 +226,9 @@ def rref(self):
def trnsform(self):
return MatrixValue(self._rref()[1])

def solve(self, other):
return NumberValue(2)


class MatrixRowValue(Value):
def __init__(self, data):
Expand All @@ -234,3 +239,10 @@ def __init__(self, data):

else:
self.value = data


class OperatorBodyValue(Value):
def __init__(self, args):
super().__init__()

self.value = args

0 comments on commit b3f637c

Please sign in to comment.