From 02a463c3b1421593cdc49230ad876d8b99eb16a6 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 12 Feb 2024 10:56:58 -0300 Subject: [PATCH] Added support for statement suites in fix-mutable-params --- integration_tests/test_fix_mutable_params.py | 7 +- src/core_codemods/fix_mutable_params.py | 125 ++++++++++++------- tests/codemods/test_fix_mutable_params.py | 36 ++++++ 3 files changed, 123 insertions(+), 45 deletions(-) diff --git a/integration_tests/test_fix_mutable_params.py b/integration_tests/test_fix_mutable_params.py index b84470b0..d9e50038 100644 --- a/integration_tests/test_fix_mutable_params.py +++ b/integration_tests/test_fix_mutable_params.py @@ -1,4 +1,7 @@ -from core_codemods.fix_mutable_params import FixMutableParams +from core_codemods.fix_mutable_params import ( + FixMutableParams, + FixMutableParamsTransformer, +) from integration_tests.base_test import ( BaseIntegrationTest, original_and_expected_from_code_path, @@ -30,4 +33,4 @@ def baz(x=None, y=None): expected_diff = '--- \n+++ \n@@ -1,4 +1,5 @@\n-def foo(x, y=[]):\n+def foo(x, y=None):\n+ y = [] if y is None else y\n y.append(x)\n print(y)\n \n@@ -7,6 +8,8 @@\n print(x)\n \n \n-def baz(x={"foo": 42}, y=set()):\n+def baz(x=None, y=None):\n+ x = {"foo": 42} if x is None else x\n+ y = set() if y is None else y\n print(x)\n print(y)\n' expected_line_change = 1 num_changes = 2 - change_description = FixMutableParams.change_description + change_description = FixMutableParamsTransformer.change_description diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index 838b4b40..66e0afc3 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -1,15 +1,11 @@ import libcst as cst from libcst import matchers as m +from codemodder.codemods.libcst_transformer import LibcstTransformerPipeline from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod +from core_codemods.api.core_codemod import CoreCodemod -class FixMutableParams(SimpleCodemod): - metadata = Metadata( - name="fix-mutable-params", - summary="Replace Mutable Default Parameters", - review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, - references=[], - ) +class FixMutableParamsTransformer(SimpleCodemod): change_description = "Replace mutable parameter with `None`." _BUILTIN_TO_LITERAL = { @@ -101,46 +97,51 @@ def _gather_and_update_params( return updated_params, new_var_decls, add_annotation - def _build_body_prefix(self, new_var_decls: list[cst.Param]): + def _build_body_prefix(self, new_var_decls: list[cst.Param]) -> list[cst.Assign]: return [ - cst.SimpleStatementLine( - body=[ - cst.Assign( - targets=[cst.AssignTarget(target=var_decl.name)], - value=cst.IfExp( - test=cst.Comparison( - left=var_decl.name, - comparisons=[ - cst.ComparisonTarget(cst.Is(), cst.Name("None")) - ], - ), - # In the case of list() or dict(), this particular - # default value has been updated to use the literal - # instead. This does not affect the default - # argument in the function itself. - body=var_decl.default, - orelse=var_decl.name, - ), - ) - ] + cst.Assign( + targets=[cst.AssignTarget(target=var_decl.name)], + value=cst.IfExp( + test=cst.Comparison( + left=var_decl.name, + comparisons=[cst.ComparisonTarget(cst.Is(), cst.Name("None"))], + ), + # In the case of list() or dict(), this particular + # default value has been updated to use the literal + # instead. This does not affect the default + # argument in the function itself. + body=var_decl.default, + orelse=var_decl.name, + ), ) for var_decl in new_var_decls ] - def _build_new_body(self, new_var_decls, body): + def _build_new_body( + self, new_var_decls, body: cst.BaseSuite + ) -> list[cst.BaseStatement] | list[cst.BaseSmallStatement]: offset = 0 new_body = [] - # Preserve placement of docstring - if body and m.matches( - body[0], - m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), + if m.matches( + body.body[0], + m.Expr(value=m.SimpleString()) + | m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), ): - new_body.append(body[0]) + new_body.append(body.body[0]) offset = 1 - - new_body.extend(self._build_body_prefix(new_var_decls)) - new_body.extend(body[offset:]) + match body: + case cst.SimpleStatementSuite(): + new_body.extend(self._build_body_prefix(new_var_decls)) + new_body.extend(body.body[offset:]) + case cst.IndentedBlock(): + new_body.extend( + [ + cst.SimpleStatementLine(body=[stmt]) + for stmt in self._build_body_prefix(new_var_decls) + ] + ) + new_body.extend(body.body[offset:]) return new_body def _is_abstractmethod(self, node: cst.FunctionDef) -> bool: @@ -151,6 +152,14 @@ def _is_abstractmethod(self, node: cst.FunctionDef) -> bool: return False + def _is_overloaded(self, node: cst.FunctionDef) -> bool: + for decorator in node.decorators: + match decorator.decorator: + case cst.Name("overload"): + return True + + return False + def leave_FunctionDef( self, original_node: cst.FunctionDef, @@ -159,23 +168,53 @@ def leave_FunctionDef( """Transforms function definitions with mutable default parameters""" # TODO: add filter by include or exclude that works for nodes # that that have different start/end numbers. + ( updated_params, new_var_decls, add_annotation, ) = self._gather_and_update_params(original_node, updated_node) - new_body = ( - self._build_new_body(new_var_decls, updated_node.body.body) - if not self._is_abstractmethod(original_node) - else updated_node.body.body - ) + if new_var_decls: # If we're adding statements to the body, we know a change took place self.add_change(original_node, self.change_description) if add_annotation: self.add_needed_import("typing", "Optional") + # overloaded methods with empty bodies should only change signature + empty_statement = m.Expr(value=m.Ellipsis()) | m.Pass() + if self._is_overloaded(updated_node) and m.matches( + original_node.body, + m.SimpleStatementSuite(body=[empty_statement]) + | m.IndentedBlock(body=[m.SimpleStatementLine(body=[empty_statement])]), + ): + return updated_node.with_changes( + params=updated_node.params.with_changes(params=updated_params) + ) + + new_body = ( + self._build_new_body(new_var_decls, updated_node.body) + if not self._is_abstractmethod(original_node) + else updated_node.body.body + ) + return updated_node.with_changes( params=updated_node.params.with_changes(params=updated_params), - body=updated_node.body.with_changes(body=new_body), + body=( + updated_node.body.with_changes(body=new_body) + if new_body + else updated_node.body + ), ) + + +FixMutableParams = CoreCodemod( + metadata=Metadata( + name="fix-mutable-params", + summary="Replace Mutable Default Parameters", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[], + ), + transformer=LibcstTransformerPipeline(FixMutableParamsTransformer), + detector=None, +) diff --git a/tests/codemods/test_fix_mutable_params.py b/tests/codemods/test_fix_mutable_params.py index 8145bd3f..1e1c20c1 100644 --- a/tests/codemods/test_fix_mutable_params.py +++ b/tests/codemods/test_fix_mutable_params.py @@ -21,6 +21,16 @@ def foo(bar=None): """ self.run_and_assert(tmpdir, input_code, expected_output) + @pytest.mark.parametrize("mutable", ["[]", "{}", "set()"]) + def test_fix_single_arg_suite(self, tmpdir, mutable): + input_code = f""" + def foo(bar={mutable}): print(bar) + """ + expected_output = f""" + def foo(bar=None): bar = {mutable} if bar is None else bar; print(bar) + """ + self.run_and_assert(tmpdir, input_code, expected_output) + @pytest.mark.parametrize("mutable", ["[]", "{}", "set()"]) def test_fix_single_arg_method(self, tmpdir, mutable): input_code = f""" @@ -90,6 +100,32 @@ def foo(bar=None, baz=None): """ self.run_and_assert(tmpdir, input_code, expected_output) + def test_fix_overloaded(self, tmpdir): + input_code = """ + from typing import overload + @overload + def foo(a : list[str] = []) -> str: + ... + @overload + def foo(a : list[int] = []) -> int: + ... + def foo(a : list[int] | list[str] = []) -> int|str: + return 0 + """ + expected_output = """ + from typing import Optional, overload + @overload + def foo(a : Optional[list[str]] = None) -> str: + ... + @overload + def foo(a : Optional[list[int]] = None) -> int: + ... + def foo(a : Optional[list[int] | list[str]] = None) -> int|str: + a = [] if a is None else a + return 0 + """ + self.run_and_assert(tmpdir, input_code, expected_output, num_changes=3) + def test_fix_multiple_args_mixed(self, tmpdir): input_code = """ def foo(bar=[], x="hello", baz={"foo": 42}, biz=set(), boz=list(), buz={1, 2, 3}):