diff --git a/mathy_core/rules/multiplicative_inverse.py b/mathy_core/rules/multiplicative_inverse.py index 2bd3186..8942d1f 100644 --- a/mathy_core/rules/multiplicative_inverse.py +++ b/mathy_core/rules/multiplicative_inverse.py @@ -1,22 +1,15 @@ -from typing import Optional, cast +from typing import Optional from ..expressions import ( - AddExpression, ConstantExpression, DivideExpression, - EqualExpression, MathExpression, MultiplyExpression, NegateExpression, - PowerExpression, - SubtractExpression, - VariableExpression, ) from ..rule import BaseRule, ExpressionChangeRule _OP_DIVISION_EXPRESSION = "division-expression" -_OP_DIVISION_VARIABLE = "division-variable" -_OP_DIVISION_COMPLEX_DENOMINATOR = "division-complex-denominator" _OP_DIVISION_NEGATIVE_DENOMINATOR = "division-negative-denominator" @@ -36,24 +29,12 @@ def get_type(self, node: MathExpression) -> Optional[str]: Support different types of tree configurations based on the division operation: - DivisionExpression is a division to be restated as multiplication by reciprocal - - DivisionVariable is a division by a variable - - DivisionComplexDenominator is a division by a complex expression - DivisionNegativeDenominator is a division by a negative term """ is_division = isinstance(node, DivideExpression) if not is_division: return None - # Division by a variable (e.g., (2 + 3z) / z) - if isinstance(node.right, VariableExpression): - return _OP_DIVISION_VARIABLE - - # Division where the denominator is a complex expression (e.g., (x^2 + 4x + 4) / (2x - 2)) - if isinstance(node.right, AddExpression) or isinstance( - node.right, SubtractExpression - ): - return _OP_DIVISION_COMPLEX_DENOMINATOR - # Division where the denominator is negative (e.g., (2 + 3z) / -z) if isinstance(node.right, NegateExpression): return _OP_DIVISION_NEGATIVE_DENOMINATOR @@ -78,22 +59,13 @@ def apply_to(self, node: MathExpression) -> ExpressionChangeRule: DivideExpression(ConstantExpression(1), node.right.clone()), ) - elif tree_type == _OP_DIVISION_VARIABLE: - # For division by a single variable, treat it the same as a general expression - reciprocal = DivideExpression(node.right.clone(), ConstantExpression(-1)) - result = MultiplyExpression(node.left.clone(), reciprocal) - - elif tree_type == _OP_DIVISION_COMPLEX_DENOMINATOR: - result = MultiplyExpression( - node.left.clone(), - DivideExpression(ConstantExpression(1), node.right.clone()), - ) - elif tree_type == _OP_DIVISION_NEGATIVE_DENOMINATOR: # For division by a negative denominator, negate the numerator and use the positive reciprocal result = MultiplyExpression( node.left.clone(), - DivideExpression(ConstantExpression(-1), node.right.get_child().clone()), + DivideExpression( + ConstantExpression(-1), node.right.get_child().clone() + ), ) else: diff --git a/tests/test_rules.py b/tests/test_rules.py index 011bcb4..14b18df 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -1,4 +1,3 @@ -from mathy_core import MathExpression from mathy_core.parser import ExpressionParser from mathy_core.rules import ( AssociativeSwapRule, @@ -7,9 +6,9 @@ ConstantsSimplifyRule, DistributiveFactorOutRule, DistributiveMultiplyRule, + MultiplicativeInverseRule, RestateSubtractionRule, VariableMultiplyRule, - MultiplicativeInverseRule, ) from mathy_core.testing import run_rule_tests @@ -89,36 +88,3 @@ def test_rules_rule_can_apply_to(): ] for action in available_actions: assert type(action.can_apply_to(expression)) == bool - - -def debug_expressions(one: MathExpression, two: MathExpression): - one_inputs = [f"{e.__class__.__name__}" for e in one.to_list()] - two_inputs = [f"{e.__class__.__name__}" for e in two.to_list()] - print("one: ", one.raw, one_inputs) - print("two: ", two.raw, two_inputs) - - -def test_rules_rule_restate_subtraction_corner_case_1(): - parser = ExpressionParser() - expression = parser.parse("4x - 3y + 3x") - - restate = RestateSubtractionRule() - dfo = DistributiveFactorOutRule() - commute = CommutativeSwapRule(preferred=False) - - node = restate.find_node(expression) - assert node is not None, "should find node" - assert restate.can_apply_to(node), "should be able to apply" - change = restate.apply_to(node) - assert change.result is not None, "should get change" - assert change.result.get_root().raw == "4x + -3y + 3x" - - change = commute.apply_to(change.result.get_root()) - assert change.result is not None, "should get change" - node = dfo.find_node(change.result.get_root()) - assert node is not None, "should find node" - assert dfo.can_apply_to(node), "should be able to apply" - change = dfo.apply_to(node) - assert change.result is not None, "should get change" - node = change.result.get_root() - assert node.raw == "(4 + 3) * x + -3y"