diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py index 8f6e3a364..eed14264d 100644 --- a/src/core_codemods/combine_startswith_endswith.py +++ b/src/core_codemods/combine_startswith_endswith.py @@ -14,6 +14,11 @@ class CombineStartswithEndswith(BaseCodemod, NameResolutionMixin): def leave_BooleanOperation( self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation ) -> cst.CSTNode: + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return original_node + if self.matches_startswith_endswith_or_pattern(original_node): left_call = cst.ensure_type(updated_node.left, cst.Call) right_call = cst.ensure_type(updated_node.right, cst.Call) @@ -38,12 +43,21 @@ def matches_startswith_endswith_or_pattern( ) -> bool: # Match the pattern: x.startswith("...") or x.startswith("...") # and the same but with endswith - call = m.Call( - func=m.Attribute( - value=m.Name(), attr=m.Name("startswith") | m.Name("endswith") - ), - args=[m.Arg(value=m.SimpleString())], + args = [m.Arg(value=m.SimpleString())] + startswith = m.Call( + func=m.Attribute(value=m.Name(), attr=m.Name("startswith")), + args=args, + ) + endswith = m.Call( + func=m.Attribute(value=m.Name(), attr=m.Name("endswith")), + args=args, ) - return m.matches( - node, m.BooleanOperation(left=call, operator=m.Or(), right=call) + startswith_or = m.BooleanOperation( + left=startswith, operator=m.Or(), right=startswith + ) + endswith_or = m.BooleanOperation(left=endswith, operator=m.Or(), right=endswith) + + return ( + m.matches(node, startswith_or | endswith_or) + and node.left.func.value.value == node.right.func.value.value ) diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index 0e088996c..19305f9d3 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -1,7 +1,6 @@ import pytest from tests.codemods.base_codemod_test import BaseCodemodTest from core_codemods.combine_startswith_endswith import CombineStartswithEndswith -from textwrap import dedent each_func = pytest.mark.parametrize("func", ["startswith", "endswith"]) @@ -22,7 +21,7 @@ def test_combine(self, tmpdir, func): x = "foo" x.{func}(("foo", "f")) """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 @pytest.mark.parametrize( @@ -32,6 +31,8 @@ def test_combine(self, tmpdir, func): "x.startswith(('f', 'foo'))", "x.startswith('foo') and x.startswith('f')", "x.startswith('foo') and x.startswith('f') or True", + "x.startswith('foo') or x.endswith('f')", + "x.startswith('foo') or y.startswith('f')", ], ) def test_no_change(self, tmpdir, code):