From 3f328a6f4987ce37cafc1fbf5c4b3f46a276a50d Mon Sep 17 00:00:00 2001 From: Lucas Faudman <52257695+LucasFaudman@users.noreply.github.com> Date: Thu, 9 May 2024 05:38:32 -0700 Subject: [PATCH] Add new CodeMod CombineIsinstanceIssubclass. isistance(x, str) or isnstance(x, (bytes, list)) -> isinstance(x, (str, bytes, list)) (#494) * Add new CodeMod CombineIsinstanceIssubclass. isistance(x, str) or isinstance(x, (bytes, list)) -> isinstance(x, (str, bytes, list)) * refactor in same way as clavedeluna requests for combine_startswith_endswith, add docs Metadata & .md file, register codemod in __init__ * abstract logic from CombineStartswithEndswith and CombineIsinstanceIssubclass into CombineCallsBaseCodemod and make each a subclass, remove unused extract_boolean_operands from utils.py, add more params to test_no_change and test_mixed_boolean_operation for each * remove startswith/endswith comments per @clavedeluna request --------- Co-authored-by: Lucas Faudman --- .../test_combine_isinstance_issubclass.py | 16 ++ src/codemodder/codemods/utils.py | 19 +- src/codemodder/scripts/generate_docs.py | 4 + src/core_codemods/__init__.py | 2 + src/core_codemods/combine_calls_base.py | 132 ++++++++++++ .../combine_isinstance_issubclass.py | 32 +++ .../combine_startswith_endswith.py | 116 +++------- ...ee_python_combine-isinstance-issubclass.md | 12 ++ .../test_combine_isinstance_issubclass.py | 202 ++++++++++++++++++ .../test_combine_startswith_endswith.py | 66 ++++++ 10 files changed, 492 insertions(+), 109 deletions(-) create mode 100644 integration_tests/test_combine_isinstance_issubclass.py create mode 100644 src/core_codemods/combine_calls_base.py create mode 100644 src/core_codemods/combine_isinstance_issubclass.py create mode 100644 src/core_codemods/docs/pixee_python_combine-isinstance-issubclass.md create mode 100644 tests/codemods/test_combine_isinstance_issubclass.py diff --git a/integration_tests/test_combine_isinstance_issubclass.py b/integration_tests/test_combine_isinstance_issubclass.py new file mode 100644 index 00000000..7ce94b06 --- /dev/null +++ b/integration_tests/test_combine_isinstance_issubclass.py @@ -0,0 +1,16 @@ +from codemodder.codemods.test import BaseIntegrationTest +from core_codemods.combine_isinstance_issubclass import CombineIsinstanceIssubclass + + +class TestCombineStartswithEndswith(BaseIntegrationTest): + codemod = CombineIsinstanceIssubclass + original_code = """ + x = 'foo' + if isinstance(x, str) or isinstance(x, bytes): + print("Yes") + """ + replacement_lines = [(2, "if isinstance(x, (str, bytes)):\n")] + + expected_diff = "--- \n+++ \n@@ -1,3 +1,3 @@\n x = 'foo'\n-if isinstance(x, str) or isinstance(x, bytes):\n+if isinstance(x, (str, bytes)):\n print(\"Yes\")\n" + expected_line_change = "2" + change_description = CombineIsinstanceIssubclass.change_description diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 67b3e015..862f362e 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Any, Iterator, Optional, Type, TypeAlias, TypeVar +from typing import Any, Optional, TypeAlias import libcst as cst from libcst import MetadataDependent, matchers @@ -216,20 +216,3 @@ def is_zero(node: cst.CSTNode) -> bool: case cst.Call(func=cst.Name("int")) | cst.Call(func=cst.Name("float")): return is_zero(node.args[0].value) return False - - -_CSTNode = TypeVar("_CSTNode", bound=cst.CSTNode) - - -def extract_boolean_operands( - node: cst.BooleanOperation, ensure_type: Type[_CSTNode] = cst.CSTNode -) -> Iterator[_CSTNode]: - """ - Recursively extract operands from a cst.BooleanOperation node from left to right as an iterator of nodes. - """ - if isinstance(node.left, cst.BooleanOperation): - yield from extract_boolean_operands(node.left, ensure_type) - else: - yield cst.ensure_type(node.left, ensure_type) - - yield cst.ensure_type(node.right, ensure_type) diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index 0b749dc3..0124f20e 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -199,6 +199,10 @@ class DocMetadata: importance="Low", guidance_explained="Simplifying expressions involving `startswith` or `endswith` calls is safe.", ), + "combine-isinstance-issubclass": DocMetadata( + importance="Low", + guidance_explained="Simplifying expressions involving `isinstance` or `issubclass` calls is safe.", + ), "fix-deprecated-logging-warn": DocMetadata( importance="Low", guidance_explained="This change fixes deprecated uses and is safe.", diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 6615f80f..a2e2e10f 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -2,6 +2,7 @@ from .add_requests_timeouts import AddRequestsTimeouts from .break_or_continue_out_of_loop import BreakOrContinueOutOfLoop +from .combine_isinstance_issubclass import CombineIsinstanceIssubclass from .combine_startswith_endswith import CombineStartswithEndswith from .defectdojo.semgrep.avoid_insecure_deserialization import ( AvoidInsecureDeserialization, @@ -135,6 +136,7 @@ RemoveModuleGlobal, RemoveDebugBreakpoint, CombineStartswithEndswith, + CombineIsinstanceIssubclass, FixDeprecatedLoggingWarn, FlaskEnableCSRFProtection, ReplaceFlaskSendFile, diff --git a/src/core_codemods/combine_calls_base.py b/src/core_codemods/combine_calls_base.py new file mode 100644 index 00000000..bb8376ea --- /dev/null +++ b/src/core_codemods/combine_calls_base.py @@ -0,0 +1,132 @@ +import libcst as cst +from libcst import matchers as m + +from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api import SimpleCodemod + + +class CombineCallsBaseCodemod(SimpleCodemod, NameResolutionMixin): + combinable_funcs: list[str] = [] + dedupilcation_attr: str = "value" + args_to_combine: list[int] = [0] + args_to_keep_as_is: list[int] = [] + + def leave_BooleanOperation( + self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation + ) -> cst.CSTNode: + if not self.filter_by_path_includes_or_excludes( + self.node_position(original_node) + ): + return updated_node + + for call_matcher in map(self.make_call_matcher, self.combinable_funcs): + if self.matches_call_or_call(updated_node, call_matcher): + self.report_change(original_node) + return self.combine_calls(updated_node.left, updated_node.right) + + if self.matches_call_or_boolop(updated_node, call_matcher): + self.report_change(original_node) + return self.combine_call_or_boolop_fold_right(updated_node) + + if self.matches_boolop_or_call(updated_node, call_matcher): + self.report_change(original_node) + return self.combine_boolop_or_call_fold_left(updated_node) + + return updated_node + + def make_call_matcher(self, func_name: str) -> m.Call: + raise NotImplementedError("Subclasses must implement this method") + + def check_calls_same_instance( + self, left_call: cst.Call, right_call: cst.Call + ) -> bool: + raise NotImplementedError("Subclasses must implement this method") + + def matches_call_or_call( + self, node: cst.BooleanOperation, call_matcher: m.Call + ) -> bool: + call_or_call = m.BooleanOperation( + left=call_matcher, operator=m.Or(), right=call_matcher + ) + # True if the node matches the pattern and the calls are the same instance + return m.matches(node, call_or_call) and self.check_calls_same_instance( + node.left, node.right + ) + + def matches_call_or_boolop( + self, node: cst.BooleanOperation, call_matcher: m.Call + ) -> bool: + call_or_boolop = m.BooleanOperation( + left=call_matcher, + operator=m.Or(), + right=m.BooleanOperation(left=call_matcher), + ) + # True if the node matches the pattern and the calls are the same instance + return m.matches(node, call_or_boolop) and self.check_calls_same_instance( + node.left, node.right.left + ) + + def matches_boolop_or_call( + self, node: cst.BooleanOperation, call_matcher: m.Call + ) -> bool: + boolop_or_call = m.BooleanOperation( + left=m.BooleanOperation(right=call_matcher), + operator=m.Or(), + right=call_matcher, + ) + # True if the node matches the pattern and the calls are the same instance + return m.matches(node, boolop_or_call) and self.check_calls_same_instance( + node.left.right, node.right + ) + + def combine_calls(self, *calls: cst.Call) -> cst.Call: + first_call = calls[0] + new_args = [] + for arg_index in sorted(self.args_to_keep_as_is + self.args_to_combine): + if arg_index in self.args_to_combine: + new_args.append(self.combine_args(*calls, arg_index=arg_index)) + else: + new_args.append(first_call.args[arg_index]) + + return cst.Call(func=first_call.func, args=new_args) + + def combine_args(self, *calls: cst.Call, arg_index: int) -> cst.Arg: + elements = [] + seen_values = set() + for call in calls: + arg_value = call.args[arg_index].value + arg_elements = ( + arg_value.elements + if isinstance(arg_value, cst.Tuple) + else (cst.Element(value=arg_value),) + ) + + for element in arg_elements: + if ( + value := getattr(element.value, self.dedupilcation_attr, None) + ) in seen_values: + # If an element has a non-None value that has already been seen, continue to avoid duplicates + continue + if value is not None: + seen_values.add(value) + elements.append(element) + + return cst.Arg(value=cst.Tuple(elements=elements)) + + def combine_call_or_boolop_fold_right( + self, node: cst.BooleanOperation + ) -> cst.BooleanOperation: + new_left = self.combine_calls(node.left, node.right.left) + new_right = node.right.right + return cst.BooleanOperation( + left=new_left, operator=node.right.operator, right=new_right + ) + + def combine_boolop_or_call_fold_left( + self, node: cst.BooleanOperation + ) -> cst.BooleanOperation: + new_left = node.left.left + new_right = self.combine_calls(node.left.right, node.right) + return cst.BooleanOperation( + left=new_left, operator=node.left.operator, right=new_right + ) diff --git a/src/core_codemods/combine_isinstance_issubclass.py b/src/core_codemods/combine_isinstance_issubclass.py new file mode 100644 index 00000000..3885fe06 --- /dev/null +++ b/src/core_codemods/combine_isinstance_issubclass.py @@ -0,0 +1,32 @@ +import libcst as cst +from libcst import matchers as m + +from core_codemods.api import Metadata, ReviewGuidance + +from .combine_calls_base import CombineCallsBaseCodemod + + +class CombineIsinstanceIssubclass(CombineCallsBaseCodemod): + metadata = Metadata( + name="combine-isinstance-issubclass", + summary="Simplify Boolean Expressions Using `isinstance` and `issubclass`", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ) + change_description = "Use tuple of matches instead of boolean expression with `isinstance` or `issubclass`" + + combinable_funcs = ["isinstance", "issubclass"] + dedupilcation_attr = "value" + args_to_combine = [1] + args_to_keep_as_is = [0] + + def make_call_matcher(self, func_name: str) -> m.Call: + return m.Call( + func=m.Name(func_name), + args=[m.Arg(value=m.Name()), m.Arg(value=m.Name() | m.Tuple())], + ) + + def check_calls_same_instance( + self, left_call: cst.Call, right_call: cst.Call + ) -> bool: + return left_call.args[0].value.value == right_call.args[0].value.value diff --git a/src/core_codemods/combine_startswith_endswith.py b/src/core_codemods/combine_startswith_endswith.py index 4f49f507..52e24471 100644 --- a/src/core_codemods/combine_startswith_endswith.py +++ b/src/core_codemods/combine_startswith_endswith.py @@ -1,12 +1,12 @@ import libcst as cst from libcst import matchers as m -from codemodder.codemods.utils import extract_boolean_operands -from codemodder.codemods.utils_mixin import NameResolutionMixin -from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod +from core_codemods.api import Metadata, ReviewGuidance +from .combine_calls_base import CombineCallsBaseCodemod -class CombineStartswithEndswith(SimpleCodemod, NameResolutionMixin): + +class CombineStartswithEndswith(CombineCallsBaseCodemod): metadata = Metadata( name="combine-startswith-endswith", summary="Simplify Boolean Expressions Using `startswith` and `endswith`", @@ -15,92 +15,26 @@ class CombineStartswithEndswith(SimpleCodemod, NameResolutionMixin): ) change_description = "Use tuple of matches instead of boolean expression" - def leave_BooleanOperation( - self, original_node: cst.BooleanOperation, updated_node: cst.BooleanOperation - ) -> cst.CSTNode: - if not self.filter_by_path_includes_or_excludes( - self.node_position(original_node) - ): - return updated_node - - if self.matches_startswith_endswith_or_pattern(original_node): - self.report_change(original_node) - return self.make_new_call_from_boolean_operation(updated_node) - - return updated_node - - def matches_startswith_endswith_or_pattern( - self, node: cst.BooleanOperation - ) -> bool: - # Match the pattern: x.startswith("...") or x.startswith("...") or x.startswith("...") or ... - # and the same but with endswith - args = [ - m.Arg( - value=m.Tuple() - | m.SimpleString() - | m.ConcatenatedString() - | m.FormattedString() - | m.Name() - ) - ] - startswith = m.Call( - func=m.Attribute(value=m.Name(), attr=m.Name("startswith")), - args=args, - ) - endswith = m.Call( - func=m.Attribute(value=m.Name(), attr=m.Name("endswith")), - args=args, - ) - startswith_or = m.BooleanOperation( - left=startswith, operator=m.Or(), right=startswith + combinable_funcs = ["startswith", "endswith"] + dedupilcation_attr = "evaluated_value" + args_to_combine = [0] + args_to_keep_as_is = [] + + def make_call_matcher(self, func_name: str) -> m.Call: + return m.Call( + func=m.Attribute(value=m.Name(), attr=m.Name(func_name)), + args=[ + m.Arg( + value=m.Tuple() + | m.SimpleString() + | m.ConcatenatedString() + | m.FormattedString() + | m.Name() + ) + ], ) - endswith_or = m.BooleanOperation(left=endswith, operator=m.Or(), right=endswith) - # Check for simple case: x.startswith("...") or x.startswith("...") - if ( - m.matches(node, startswith_or | endswith_or) - and node.left.func.value.value == node.right.func.value.value - ): - return True - - # Check for chained case: x.startswith("...") or x.startswith("...") or x.startswith("...") or ... - if m.matches( - node, - m.BooleanOperation( - left=m.BooleanOperation(operator=m.Or()), - operator=m.Or(), - right=startswith | endswith, - ), - ): - return all( - call.func.value.value == node.right.func.value.value # Same function - for call in extract_boolean_operands(node, ensure_type=cst.Call) - ) - - return False - - def make_new_call_from_boolean_operation( - self, updated_node: cst.BooleanOperation - ) -> cst.Call: - elements = [] - seen_evaluated_values = set() - for call in extract_boolean_operands(updated_node, ensure_type=cst.Call): - arg_value = call.args[0].value - arg_elements = ( - arg_value.elements - if isinstance(arg_value, cst.Tuple) - else (cst.Element(value=arg_value),) - ) - - for element in arg_elements: - if ( - evaluated_value := getattr(element.value, "evaluated_value", None) - ) in seen_evaluated_values: - # If an element has a non-None evaluated value that has already been seen, continue to avoid duplicates - continue - if evaluated_value is not None: - seen_evaluated_values.add(evaluated_value) - elements.append(element) - - new_arg = cst.Arg(value=cst.Tuple(elements=elements)) - return cst.Call(func=call.func, args=[new_arg]) + def check_calls_same_instance( + self, left_call: cst.Call, right_call: cst.Call + ) -> bool: + return left_call.func.value.value == right_call.func.value.value diff --git a/src/core_codemods/docs/pixee_python_combine-isinstance-issubclass.md b/src/core_codemods/docs/pixee_python_combine-isinstance-issubclass.md new file mode 100644 index 00000000..d94504ae --- /dev/null +++ b/src/core_codemods/docs/pixee_python_combine-isinstance-issubclass.md @@ -0,0 +1,12 @@ +Many developers are not necessarily aware that the `isinstance` and `issubclass` builtin methods can accept a tuple of classes to match. This means that there is a lot of code that uses boolean expressions such as `isinstance(x, str) or isinstance(x, bytes)` instead of the simpler expression `isinstance(x, (str, bytes))`. + +This codemod simplifies the boolean expressions where possible which leads to cleaner and more concise code. + +The changes from this codemod look like this: + +```diff + x = 'foo' +- if isinstance(x, str) or isinstance(x, bytes): ++ if isinstance(x, (str, bytes)): + ... +``` diff --git a/tests/codemods/test_combine_isinstance_issubclass.py b/tests/codemods/test_combine_isinstance_issubclass.py new file mode 100644 index 00000000..3a158a9c --- /dev/null +++ b/tests/codemods/test_combine_isinstance_issubclass.py @@ -0,0 +1,202 @@ +import pytest + +from codemodder.codemods.test import BaseCodemodTest +from core_codemods.combine_isinstance_issubclass import CombineIsinstanceIssubclass + +each_func = pytest.mark.parametrize("func", ["isinstance", "issubclass"]) + + +class TestCombineIsinstanceIssubclass(BaseCodemodTest): + codemod = CombineIsinstanceIssubclass + + def test_name(self): + assert self.codemod.name == "combine-isinstance-issubclass" + + @each_func + def test_combine(self, tmpdir, func): + input_code = f""" + {func}(x, str) or {func}(x, bytes) + """ + expected = f""" + {func}(x, (str, bytes)) + """ + self.run_and_assert(tmpdir, input_code, expected) + + @pytest.mark.parametrize( + "code", + [ + "isinstance(x, str)", + "isinstance(x, (str, bytes))", + "isinstance(x, str) and isinstance(x, bytes)", + "isinstance(x, str) and isinstance(x, str) or True", + "isinstance(x, str) or issubclass(x, str)", + "isinstance(x, str) or isinstance(y, str)", + "isinstance(x, str) or isinstance(y, bytes) or isinstance(x, bytes)", + "isinstance(x, str) or isinstance(y, bytes) or isinstance(x, bytes) or isinstance(y, bytes)", + "isinstance(x, str) or issubclass(x, str) or isinstance(x, bytes) or issubclass(x, bytes) or isinstance(x, str)", + "isinstance(x, str) and isinstance(x, bytes) or isinstance(y, bytes)", + ], + ) + def test_no_change(self, tmpdir, code): + self.run_and_assert(tmpdir, code, code) + + def test_exclude_line(self, tmpdir): + input_code = ( + expected + ) = """ + x = "foo" + isinstance(x, str) or isinstance(x, bytes) + """ + lines_to_exclude = [3] + self.run_and_assert( + tmpdir, + input_code, + expected, + lines_to_exclude=lines_to_exclude, + ) + + def _format_func_run_test(self, tmpdir, func, input_code, expected, num_changes=1): + self.run_and_assert( + tmpdir, + input_code.replace("{func}", func), + expected.replace("{func}", func), + num_changes, + ) + + @each_func + @pytest.mark.parametrize( + "input_code, expected", + [ + # Tuple on the left + ( + "{func}(x, (str, bytes)) or {func}(x, bytearray)", + "{func}(x, (str, bytes, bytearray))", + ), + # Tuple on the right + ( + "{func}(x, str) or {func}(x, (bytes, bytearray))", + "{func}(x, (str, bytes, bytearray))", + ), + # Tuple on both sides no duplicates + ( + "{func}(x, (str, bytes)) or {func}(x, (bytearray, memoryview))", + "{func}(x, (str, bytes, bytearray, memoryview))", + ), + # Tuple on both sides with duplicates + ( + "{func}(x, (str, bytes)) or {func}(x, (str, bytes, bytearray))", + "{func}(x, (str, bytes, bytearray))", + ), + ], + ) + def test_combine_tuples(self, tmpdir, func, input_code, expected): + self._format_func_run_test(tmpdir, func, input_code, expected) + + @each_func + @pytest.mark.parametrize( + "input_code, expected", + [ + # 3 cst.Names + ( + "{func}(x, str) or {func}(x, bytes) or {func}(x, bytearray)", + "{func}(x, (str, bytes, bytearray))", + ), + # 4 cst.Names + ( + "{func}(x, str) or {func}(x, bytes) or {func}(x, bytearray) or {func}(x, some_type)", + "{func}(x, (str, bytes, bytearray, some_type))", + ), + # 5 cst.Names + ( + "{func}(x, str) or {func}(x, bytes) or {func}(x, bytearray) or {func}(x, some_type) or {func}(x, another_type)", + "{func}(x, (str, bytes, bytearray, some_type, another_type))", + ), + # 2 cst.Names and 1 cst.Tuple + ( + "{func}(x, str) or {func}(x, bytes) or {func}(x, (bytearray, memoryview))", + "{func}(x, (str, bytes, bytearray, memoryview))", + ), + # 2 cst.Name and 2 cst.Tuples + ( + "{func}(x, str) or {func}(x, (bytes, bytearray)) or {func}(x, (memoryview, bytearray)) or {func}(x, list)", + "{func}(x, (str, bytes, bytearray, memoryview, list))", + ), + # 3 cst.Tuples + ( + "{func}(x, (str, bytes)) or {func}(x, (bytes, bytearray)) or {func}(x, (bytearray, memoryview))", + "{func}(x, (str, bytes, bytearray, memoryview))", + ), + # 4 cst.Tuples + ( + "{func}(x, (str, bytes)) or {func}(x, (bytes, bytearray)) or {func}(x, (bytearray, memoryview)) or {func}(x, (memoryview, str))", + "{func}(x, (str, bytes, bytearray, memoryview))", + ), + ], + ) + def test_more_than_two_calls(self, tmpdir, func, input_code, expected): + self._format_func_run_test( + tmpdir, func, input_code, expected, input_code.count(" or ") + ) + + @each_func + @pytest.mark.parametrize( + "input_code, expected", + [ + # same name and/or on left + ( + "x and {func}(x, str) or {func}(x, bytes)", + "x and {func}(x, (str, bytes))", + ), + ( + "x or {func}(x, str) or {func}(x, bytes)", + "x or {func}(x, (str, bytes))", + ), + # same name and/or on right + ( + "{func}(x, str) or {func}(x, bytes) and x", + "{func}(x, (str, bytes)) and x", + ), + ( + "{func}(x, str) or {func}(x, bytes) or x", + "{func}(x, (str, bytes)) or x", + ), + # other name and/or on left + ( + "y and {func}(x, str) or {func}(x, bytes)", + "y and {func}(x, (str, bytes))", + ), + ( + "y or {func}(x, str) or {func}(x, bytes)", + "y or {func}(x, (str, bytes))", + ), + # other name and/or on right + ( + "{func}(x, str) or {func}(x, bytes) and y", + "{func}(x, (str, bytes)) and y", + ), + ( + "{func}(x, str) or {func}(x, bytes) or y", + "{func}(x, (str, bytes)) or y", + ), + # same name and/or on left, other name and/or on right + ( + "x or {func}(x, str) or {func}(x, bytes) or y", + "x or {func}(x, (str, bytes)) or y", + ), + ( + "x and {func}(x, str) or {func}(x, bytes) or y", + "x and {func}(x, (str, bytes)) or y", + ), + # other name and/or on left, same name and/or on right + ( + "y or {func}(x, str) or {func}(x, bytes) or x", + "y or {func}(x, (str, bytes)) or x", + ), + ( + "y and {func}(x, str) or {func}(x, bytes) or x", + "y and {func}(x, (str, bytes)) or x", + ), + ], + ) + def test_mixed_boolean_operation(self, tmpdir, func, input_code, expected): + self._format_func_run_test(tmpdir, func, input_code, expected) diff --git a/tests/codemods/test_combine_startswith_endswith.py b/tests/codemods/test_combine_startswith_endswith.py index 944d821b..f3ed2808 100644 --- a/tests/codemods/test_combine_startswith_endswith.py +++ b/tests/codemods/test_combine_startswith_endswith.py @@ -34,6 +34,9 @@ def test_combine(self, tmpdir, func): "x.startswith('foo') or x.endswith('f')", "x.startswith('foo') or y.startswith('f')", "x.startswith('foo') or y.startswith('f') or x.startswith('f')", + "x.startswith('foo') or y.startswith('f') or x.startswith('f') or y.startswith('f')", + "x.startswith('foo') or x.endswith('foo') or x.startswith('bar') or x.endswith('bar') or x.startswith('foo')", + "x.startswith('foo') and x.startswith('f') or y.startswith('bar')", ], ) def test_no_change(self, tmpdir, code): @@ -220,3 +223,66 @@ def test_more_than_two_calls(self, tmpdir, func, input_code, expected): self._format_func_run_test( tmpdir, func, input_code, expected, num_changes=input_code.count(" or ") ) + + @each_func + @pytest.mark.parametrize( + "input_code, expected", + [ + # same name and/or on left + ( + "x and x.{func}('foo') or x.{func}('bar')", + "x and x.{func}(('foo', 'bar'))", + ), + ( + "x or x.{func}('foo') or x.{func}('bar')", + "x or x.{func}(('foo', 'bar'))", + ), + # same name and/or on right + ( + "x.{func}('foo') or x.{func}('bar') and x", + "x.{func}(('foo', 'bar')) and x", + ), + ( + "x.{func}('foo') or x.{func}('bar') or x", + "x.{func}(('foo', 'bar')) or x", + ), + # other name and/or on left + ( + "y and x.{func}('foo') or x.{func}('bar')", + "y and x.{func}(('foo', 'bar'))", + ), + ( + "y or x.{func}('foo') or x.{func}('bar')", + "y or x.{func}(('foo', 'bar'))", + ), + # other name and/or on right + ( + "x.{func}('foo') or x.{func}('bar') and y", + "x.{func}(('foo', 'bar')) and y", + ), + ( + "x.{func}('foo') or x.{func}('bar') or y", + "x.{func}(('foo', 'bar')) or y", + ), + # same name and/or on left, other name and/or on right + ( + "x or x.{func}('foo') or x.{func}('bar') or y", + "x or x.{func}(('foo', 'bar')) or y", + ), + ( + "x and x.{func}('foo') or x.{func}('bar') or y", + "x and x.{func}(('foo', 'bar')) or y", + ), + # other name and/or on left, same name and/or on right + ( + "y or x.{func}('foo') or x.{func}('bar') or x", + "y or x.{func}(('foo', 'bar')) or x", + ), + ( + "y and x.{func}('foo') or x.{func}('bar') or x", + "y and x.{func}(('foo', 'bar')) or x", + ), + ], + ) + def test_mixed_boolean_operation(self, tmpdir, func, input_code, expected): + self._format_func_run_test(tmpdir, func, input_code, expected)