diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 253a2d251..1a1a6736b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,8 +11,7 @@ repos: src/core_codemods/docs/.*| src/codemodder/dependency.py | integration_tests/.*| - tests/codemods/test_remove_debug_breakpoint.py | - tests/test_codemodder.py + tests/.* )$ - id: check-added-large-files - repo: https://github.com/psf/black diff --git a/integration_tests/test_combine_startswith_endswith.py b/integration_tests/test_combine_startswith_endswith.py new file mode 100644 index 000000000..9ad639f61 --- /dev/null +++ b/integration_tests/test_combine_startswith_endswith.py @@ -0,0 +1,16 @@ +from core_codemods.combine_startswith_endswith import CombineStartswithEndswith +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestCombineStartswithEndswith(BaseIntegrationTest): + codemod = CombineStartswithEndswith + code_path = "tests/samples/combine_startswith_endswith.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, [(1, 'if x.startswith(("foo", "bar")):\n')] + ) + expected_diff = '--- \n+++ \n@@ -1,3 +1,3 @@\n x = \'foo\'\n-if x.startswith("foo") or x.startswith("bar"):\n+if x.startswith(("foo", "bar")):\n print("Yes")\n' + expected_line_change = "2" + change_description = CombineStartswithEndswith.CHANGE_DESCRIPTION diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index 0c50b66dc..7d53f97a8 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -194,6 +194,10 @@ class DocMetadata: importance="Medium", guidance_explained="Breakpoints are generally used only for debugging and can easily be forgotten before deploying code.", ), + "combine-startswith-endswith": DocMetadata( + importance="Low", + guidance_explained="Simplifying expressions involving `startswith` or `endswith` calls is safe.", + ), } diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 6302ed14a..d5c68292d 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -42,7 +42,7 @@ from .subprocess_shell_false import SubprocessShellFalse from .remove_module_global import RemoveModuleGlobal from .remove_debug_breakpoint import RemoveDebugBreakpoint - +from .combine_startswith_endswith import CombineStartswithEndswith registry = CodemodCollection( origin="pixee", @@ -91,5 +91,6 @@ LiteralOrNewObjectIdentity, RemoveModuleGlobal, RemoveDebugBreakpoint, + CombineStartswithEndswith, ], ) diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py new file mode 100644 index 000000000..070497bc1 --- /dev/null +++ b/src/core_codemods/combine_startswith_endswith.py @@ -0,0 +1,63 @@ +import libcst as cst +from libcst import matchers as m +from codemodder.codemods.api import BaseCodemod, ReviewGuidance +from codemodder.codemods.utils_mixin import NameResolutionMixin + + +class CombineStartswithEndswith(BaseCodemod, NameResolutionMixin): + NAME = "combine-startswith-endswith" + SUMMARY = "Simplify Boolean Expressions Using `startswith` and `endswith`" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW + DESCRIPTION = "Use tuple of matches instead of boolean expression" + REFERENCES: list = [] + + def leave_BooleanOperation( + self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation + ) -> cst.CSTNode: + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return original_node + + if self.matches_startswith_endswith_or_pattern(original_node): + left_call = cst.ensure_type(updated_node.left, cst.Call) + right_call = cst.ensure_type(updated_node.right, cst.Call) + + self.report_change(original_node) + + new_arg = cst.Arg( + value=cst.Tuple( + elements=[ + cst.Element(value=left_call.args[0].value), + cst.Element(value=right_call.args[0].value), + ] + ) + ) + + return cst.Call(func=left_call.func, args=[new_arg]) + + return updated_node + + def matches_startswith_endswith_or_pattern( + self, node: cst.BooleanOperation + ) -> bool: + # Match the pattern: x.startswith("...") or x.startswith("...") + # and the same but with endswith + args = [m.Arg(value=m.SimpleString())] + startswith = m.Call( + func=m.Attribute(value=m.Name(), attr=m.Name("startswith")), + args=args, + ) + endswith = m.Call( + func=m.Attribute(value=m.Name(), attr=m.Name("endswith")), + args=args, + ) + startswith_or = m.BooleanOperation( + left=startswith, operator=m.Or(), right=startswith + ) + endswith_or = m.BooleanOperation(left=endswith, operator=m.Or(), right=endswith) + + return ( + m.matches(node, startswith_or | endswith_or) + and node.left.func.value.value == node.right.func.value.value + ) diff --git a/src/core_codemods/docs/pixee_python_combine-startswith-endswith.md b/src/core_codemods/docs/pixee_python_combine-startswith-endswith.md new file mode 100644 index 000000000..ca5e51fc9 --- /dev/null +++ b/src/core_codemods/docs/pixee_python_combine-startswith-endswith.md @@ -0,0 +1,12 @@ +Many developers are not necessarily aware that the `startswith` and `endswith` methods of `str` objects can accept a tuple of strings to match. This means that there is a lot of code that uses boolean expressions such as `x.startswith('foo') or x.startswith('bar')` instead of the simpler expression `x.startswith(('foo', 'bar'))`. + +This codemod simplifies the boolean expressions where possible which leads to cleaner and more concise code. + +The changes from this codemod look like this: + +```diff + x = 'foo' +- if x.startswith("foo") or x.startswith("bar"): ++ if x.startswith(("foo", "bar")): + ... +``` diff --git a/src/core_codemods/remove_debug_breakpoint.py b/src/core_codemods/remove_debug_breakpoint.py index 21c324fb6..104dcb07a 100644 --- a/src/core_codemods/remove_debug_breakpoint.py +++ b/src/core_codemods/remove_debug_breakpoint.py @@ -6,14 +6,19 @@ class RemoveDebugBreakpoint(BaseCodemod, NameResolutionMixin, AncestorPatternsMixin): NAME = "remove-debug-breakpoint" - SUMMARY = "Remove Breakpoint" + SUMMARY = "Remove Calls to `builtin` `breakpoint` and `pdb.set_trace" REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW - DESCRIPTION = "Remove calls to builtin `breakpoint` or `pdb.set_trace." + DESCRIPTION = "Remove breakpoint call" REFERENCES: list = [] def leave_Expr( self, original_node: cst.Expr, _ ) -> Union[cst.Expr, cst.RemovalSentinel]: + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return original_node + match call_node := original_node.value: case cst.Call(): if self.find_base_name( diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py new file mode 100644 index 000000000..19305f9d3 --- /dev/null +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -0,0 +1,40 @@ +import pytest +from tests.codemods.base_codemod_test import BaseCodemodTest +from core_codemods.combine_startswith_endswith import CombineStartswithEndswith + +each_func = pytest.mark.parametrize("func", ["startswith", "endswith"]) + + +class TestCombineStartswithEndswith(BaseCodemodTest): + codemod = CombineStartswithEndswith + + def test_name(self): + assert self.codemod.name() == "combine-startswith-endswith" + + @each_func + def test_combine(self, tmpdir, func): + input_code = f"""\ + x = "foo" + x.{func}("foo") or x.{func}("f") + """ + expected = f"""\ + x = "foo" + x.{func}(("foo", "f")) + """ + self.run_and_assert(tmpdir, input_code, expected) + assert len(self.file_context.codemod_changes) == 1 + + @pytest.mark.parametrize( + "code", + [ + "x.startswith('foo')", + "x.startswith(('f', 'foo'))", + "x.startswith('foo') and x.startswith('f')", + "x.startswith('foo') and x.startswith('f') or True", + "x.startswith('foo') or x.endswith('f')", + "x.startswith('foo') or y.startswith('f')", + ], + ) + def test_no_change(self, tmpdir, code): + self.run_and_assert(tmpdir, code, code) + assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/codemods/test_remove_debug_breakpoint.py b/tests/codemods/test_remove_debug_breakpoint.py index 5834cadd7..27e10b834 100644 --- a/tests/codemods/test_remove_debug_breakpoint.py +++ b/tests/codemods/test_remove_debug_breakpoint.py @@ -1,6 +1,5 @@ from core_codemods.remove_debug_breakpoint import RemoveDebugBreakpoint from tests.codemods.base_codemod_test import BaseCodemodTest -from textwrap import dedent class TestRemoveDebugBreakpoint(BaseCodemodTest): @@ -21,7 +20,7 @@ def something(): var = 1 something() """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 def test_builtin_breakpoint_multiple_statements(self, tmpdir): @@ -37,7 +36,7 @@ def something(): print(var); something() """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 def test_inline_pdb(self, tmpdir): @@ -52,7 +51,7 @@ def something(): var = 1 something() """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 def test_pdb_import(self, tmpdir): @@ -68,7 +67,7 @@ def something(): var = 1 something() """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 def test_pdb_from_import(self, tmpdir): @@ -84,5 +83,5 @@ def something(): var = 1 something() """ - self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + self.run_and_assert(tmpdir, input_code, expected) assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/samples/combine_startswith_endswith.py b/tests/samples/combine_startswith_endswith.py new file mode 100644 index 000000000..edf152a59 --- /dev/null +++ b/tests/samples/combine_startswith_endswith.py @@ -0,0 +1,3 @@ +x = 'foo' +if x.startswith("foo") or x.startswith("bar"): + print("Yes")