Skip to content

Commit

Permalink
Simplified codemod and adjusted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva committed Dec 12, 2023
1 parent 92e3212 commit f29b592
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 106 deletions.
115 changes: 30 additions & 85 deletions src/core_codemods/numpy_nan_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,80 +25,9 @@ class NumpyNanEquality(BaseCodemod, NameResolutionMixin):

np_nan = "numpy.nan"

def _is_np_nan(
self, left: NodeWithTrueName, right: NodeWithTrueName
) -> Optional[tuple[cst.CSTNode, cst.CSTNode]]:
if self.np_nan == left.name:
return left.node, right.node
if self.np_nan == right.name:
return right.node, left.node
return None

def _build_conjunction(
self, expressions: list[cst.BaseExpression], index: int
def _build_nan_comparison(
self, nan_node, node, preprend_not, lpar, rpar
) -> cst.BaseExpression:
if index == 0:
return expressions[0]
return cst.BooleanOperation(
left=self._build_conjunction(expressions, index - 1),
right=expressions[index],
operator=cst.And(),
)

def _build_new_comparison(self, expression_triples) -> cst.BaseExpression:
conjunction_expression = []
before: Optional[cst.Comparison] = None
for left, right, operator in expression_triples:
maybe_has_nan = self._is_np_nan(left, right)
if maybe_has_nan and isinstance(operator, cst.Equal | cst.NotEqual):
nan_node, node = maybe_has_nan
if before:
conjunction_expression.append(before)
before = None
conjunction_expression.append(
self._build_nan_comparison(
nan_node, node, isinstance(operator, cst.NotEqual)
)
)
else:
if before:
before.comparisons.append(
cst.ComparisonTarget(operator=operator, comparator=right.node)
)
else:
before = cst.Comparison(
left=left.node,
comparisons=[
cst.ComparisonTarget(
operator=operator, comparator=right.node
)
],
)
if before:
conjunction_expression.append(before)
return self._build_conjunction(
conjunction_expression, len(conjunction_expression) - 1
)

def _break_into_triples(self, comparison: cst.Comparison):
"""
Breaks a comparison expression into triples.
For example, the expression a == b == c is equivalent to a == b and b == c. This method will break it into [(a,b,==), (b,c,==)].
"""
left = NodeWithTrueName(
node=comparison.left, name=self.find_base_name(comparison.left)
)
# the first always exists
ct = comparison.comparisons[0]
right = NodeWithTrueName(ct.comparator, self.find_base_name(ct.comparator))
triples = [(left, right, ct.operator)]
for ct in comparison.comparisons[1:]:
left = right
right = NodeWithTrueName(ct.comparator, self.find_base_name(ct.comparator))
triples.append((left, right, ct.operator))
return triples

def _build_nan_comparison(self, nan_node, node, preprend_not):
if maybe_numpy_alias := self.find_alias_for_import_in_node("numpy", nan_node):
call = cst.parse_expression(f"{maybe_numpy_alias}.isnan()")
else:
Expand All @@ -107,21 +36,37 @@ def _build_nan_comparison(self, nan_node, node, preprend_not):
call = cst.parse_expression("numpy.isnan()")
call = call.with_changes(args=[cst.Arg(value=node)])
if preprend_not:
return UnaryOperation(operator=cst.Not(), expression=call)
return call
return UnaryOperation(
operator=cst.Not(), expression=call, lpar=lpar, rpar=rpar
)
return call.with_changes(lpar=lpar, rpar=rpar)

def _is_np_nan_eq(self, left: cst.BaseExpression, target: cst.ComparisonTarget):
if isinstance(target.operator, cst.Equal | cst.NotEqual):
right = target.comparator
left_name = self.find_base_name(left)
right_name = self.find_base_name(right)
if self.np_nan == left_name:
return left, right
if self.np_nan == right_name:
return right, left
return None

def leave_Comparison(
self, original_node: cst.Comparison, updated_node: cst.Comparison
) -> cst.BaseExpression:
if self.filter_by_path_includes_or_excludes(self.node_position(original_node)):
comparison_triples = self._break_into_triples(original_node)
for left, right, operator in comparison_triples:
if matchers.matches(
operator, matchers.Equal() | matchers.NotEqual()
) and self.np_nan in (
left.name,
right.name,
):
self.report_change(original_node)
return self._build_new_comparison(comparison_triples)
match original_node:
case cst.Comparison(comparisons=[cst.ComparisonTarget() as target]):
maybe_nan_eq = self._is_np_nan_eq(original_node.left, target)
if maybe_nan_eq:
nan_node, node = maybe_nan_eq
self.report_change(original_node)
return self._build_nan_comparison(
nan_node,
node,
isinstance(target.operator, cst.NotEqual),
lpar=original_node.lpar,
rpar=original_node.rpar,
)
return updated_node
65 changes: 44 additions & 21 deletions tests/codemods/test_numpy_nan_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,77 +37,100 @@ def test_simple_inequality(self, tmpdir):
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_from_numpy(self, tmpdir):
def test_simple_inequality_2(self, tmpdir):
input_code = """\
from numpy import nan
if a == nan:
import numpy
if not (a != numpy.nan):
pass
"""
expected = """\
import numpy
if numpy.isnan(a):
if not (not numpy.isnan(a)):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_simple_left(self, tmpdir):
def test_simple_parenthesis(self, tmpdir):
input_code = """\
import numpy
if numpy.nan == a:
if ( a == numpy.nan ):
pass
"""
expected = """\
import numpy
if numpy.isnan(a):
if ( numpy.isnan(a) ):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_alias(self, tmpdir):
def test_conjunction(self, tmpdir):
input_code = """\
import numpy as np
if a == np.nan:
import numpy
if a != numpy.nan and b!= numpy.nan:
pass
"""
expected = """\
import numpy as np
if np.isnan(a):
import numpy
if not numpy.isnan(a) and not numpy.isnan(b):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 2

def test_from_numpy(self, tmpdir):
input_code = """\
from numpy import nan
if a == nan:
pass
"""
expected = """\
import numpy
if numpy.isnan(a):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_multiple_comparisons(self, tmpdir):
def test_simple_left(self, tmpdir):
input_code = """\
import numpy as np
if a == np.nan == b:
import numpy
if numpy.nan == a:
pass
"""
expected = """\
import numpy as np
if np.isnan(a) and np.isnan(b):
import numpy
if numpy.isnan(a):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_multiple_comparisons_preserve_longest(self, tmpdir):
def test_alias(self, tmpdir):
input_code = """\
import numpy as np
if a == np.nan == b == c == d <= e:
if a == np.nan:
pass
"""
expected = """\
import numpy as np
if np.isnan(a) and np.isnan(b) and b == c == d <= e:
if np.isnan(a):
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(expected))
assert len(self.file_context.codemod_changes) == 1

def test_multiple_comparisons(self, tmpdir):
input_code = """\
import numpy as np
if a == np.nan == b == c == d <= e:
pass
"""
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))
assert len(self.file_context.codemod_changes) == 0

def test_not_numpy(self, tmpdir):
input_code = """\
import not_numpy as np
Expand Down

0 comments on commit f29b592

Please sign in to comment.