Skip to content

Commit

Permalink
chore: cleanup from review
Browse files Browse the repository at this point in the history
  • Loading branch information
justindujardin committed Feb 5, 2024
1 parent 06f3341 commit f1b5bc2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 67 deletions.
36 changes: 4 additions & 32 deletions mathy_core/rules/multiplicative_inverse.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
from typing import Optional, cast
from typing import Optional

from ..expressions import (
AddExpression,
ConstantExpression,
DivideExpression,
EqualExpression,
MathExpression,
MultiplyExpression,
NegateExpression,
PowerExpression,
SubtractExpression,
VariableExpression,
)
from ..rule import BaseRule, ExpressionChangeRule

_OP_DIVISION_EXPRESSION = "division-expression"
_OP_DIVISION_VARIABLE = "division-variable"
_OP_DIVISION_COMPLEX_DENOMINATOR = "division-complex-denominator"
_OP_DIVISION_NEGATIVE_DENOMINATOR = "division-negative-denominator"


Expand All @@ -36,24 +29,12 @@ def get_type(self, node: MathExpression) -> Optional[str]:
Support different types of tree configurations based on the division operation:
- DivisionExpression is a division to be restated as multiplication by reciprocal
- DivisionVariable is a division by a variable
- DivisionComplexDenominator is a division by a complex expression
- DivisionNegativeDenominator is a division by a negative term
"""
is_division = isinstance(node, DivideExpression)
if not is_division:
return None

# Division by a variable (e.g., (2 + 3z) / z)
if isinstance(node.right, VariableExpression):
return _OP_DIVISION_VARIABLE

# Division where the denominator is a complex expression (e.g., (x^2 + 4x + 4) / (2x - 2))
if isinstance(node.right, AddExpression) or isinstance(
node.right, SubtractExpression
):
return _OP_DIVISION_COMPLEX_DENOMINATOR

# Division where the denominator is negative (e.g., (2 + 3z) / -z)
if isinstance(node.right, NegateExpression):
return _OP_DIVISION_NEGATIVE_DENOMINATOR
Expand All @@ -78,22 +59,13 @@ def apply_to(self, node: MathExpression) -> ExpressionChangeRule:
DivideExpression(ConstantExpression(1), node.right.clone()),
)

elif tree_type == _OP_DIVISION_VARIABLE:
# For division by a single variable, treat it the same as a general expression
reciprocal = DivideExpression(node.right.clone(), ConstantExpression(-1))
result = MultiplyExpression(node.left.clone(), reciprocal)

elif tree_type == _OP_DIVISION_COMPLEX_DENOMINATOR:
result = MultiplyExpression(
node.left.clone(),
DivideExpression(ConstantExpression(1), node.right.clone()),
)

elif tree_type == _OP_DIVISION_NEGATIVE_DENOMINATOR:
# For division by a negative denominator, negate the numerator and use the positive reciprocal
result = MultiplyExpression(
node.left.clone(),
DivideExpression(ConstantExpression(-1), node.right.get_child().clone()),
DivideExpression(
ConstantExpression(-1), node.right.get_child().clone()
),
)

else:
Expand Down
36 changes: 1 addition & 35 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from mathy_core import MathExpression
from mathy_core.parser import ExpressionParser
from mathy_core.rules import (
AssociativeSwapRule,
Expand All @@ -7,9 +6,9 @@
ConstantsSimplifyRule,
DistributiveFactorOutRule,
DistributiveMultiplyRule,
MultiplicativeInverseRule,
RestateSubtractionRule,
VariableMultiplyRule,
MultiplicativeInverseRule,
)
from mathy_core.testing import run_rule_tests

Expand Down Expand Up @@ -89,36 +88,3 @@ def test_rules_rule_can_apply_to():
]
for action in available_actions:
assert type(action.can_apply_to(expression)) == bool


def debug_expressions(one: MathExpression, two: MathExpression):
one_inputs = [f"{e.__class__.__name__}" for e in one.to_list()]
two_inputs = [f"{e.__class__.__name__}" for e in two.to_list()]
print("one: ", one.raw, one_inputs)
print("two: ", two.raw, two_inputs)


def test_rules_rule_restate_subtraction_corner_case_1():
parser = ExpressionParser()
expression = parser.parse("4x - 3y + 3x")

restate = RestateSubtractionRule()
dfo = DistributiveFactorOutRule()
commute = CommutativeSwapRule(preferred=False)

node = restate.find_node(expression)
assert node is not None, "should find node"
assert restate.can_apply_to(node), "should be able to apply"
change = restate.apply_to(node)
assert change.result is not None, "should get change"
assert change.result.get_root().raw == "4x + -3y + 3x"

change = commute.apply_to(change.result.get_root())
assert change.result is not None, "should get change"
node = dfo.find_node(change.result.get_root())
assert node is not None, "should find node"
assert dfo.can_apply_to(node), "should be able to apply"
change = dfo.apply_to(node)
assert change.result is not None, "should get change"
node = change.result.get_root()
assert node.raw == "(4 + 3) * x + -3y"

0 comments on commit f1b5bc2

Please sign in to comment.