From 6347944a41657ee4488ec3f22d16c73362c3ef31 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 12 Feb 2024 10:56:58 -0300 Subject: [PATCH] Added support for statement suites in fix-mutable-params --- src/core_codemods/fix_mutable_params.py | 104 ++++++++++++++-------- tests/codemods/test_fix_mutable_params.py | 36 ++++++++ 2 files changed, 104 insertions(+), 36 deletions(-) diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index ae991131..15cdfe51 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -11,6 +11,7 @@ class FixMutableParams(BaseCodemod): REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW DESCRIPTION = "Replace mutable parameter with `None`." REFERENCES: list = [] + _BUILTIN_TO_LITERAL = { "list": cst.List(elements=[]), "dict": cst.Dict(elements=[]), @@ -98,46 +99,51 @@ def _gather_and_update_params( return updated_params, new_var_decls, add_annotation - def _build_body_prefix(self, new_var_decls: list[cst.Param]): + def _build_body_prefix(self, new_var_decls: list[cst.Param]) -> list[cst.Assign]: return [ - cst.SimpleStatementLine( - body=[ - cst.Assign( - targets=[cst.AssignTarget(target=var_decl.name)], - value=cst.IfExp( - test=cst.Comparison( - left=var_decl.name, - comparisons=[ - cst.ComparisonTarget(cst.Is(), cst.Name("None")) - ], - ), - # In the case of list() or dict(), this particular - # default value has been updated to use the literal - # instead. This does not affect the default - # argument in the function itself. - body=var_decl.default, - orelse=var_decl.name, - ), - ) - ] + cst.Assign( + targets=[cst.AssignTarget(target=var_decl.name)], + value=cst.IfExp( + test=cst.Comparison( + left=var_decl.name, + comparisons=[cst.ComparisonTarget(cst.Is(), cst.Name("None"))], + ), + # In the case of list() or dict(), this particular + # default value has been updated to use the literal + # instead. This does not affect the default + # argument in the function itself. + body=var_decl.default, + orelse=var_decl.name, + ), ) for var_decl in new_var_decls ] - def _build_new_body(self, new_var_decls, body): + def _build_new_body( + self, new_var_decls, body: cst.BaseSuite + ) -> list[cst.BaseStatement] | list[cst.BaseSmallStatement]: offset = 0 new_body = [] - # Preserve placement of docstring - if body and m.matches( - body[0], - m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), + if m.matches( + body.body[0], + m.Expr(value=m.SimpleString()) + | m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), ): - new_body.append(body[0]) + new_body.append(body.body[0]) offset = 1 - - new_body.extend(self._build_body_prefix(new_var_decls)) - new_body.extend(body[offset:]) + match body: + case cst.SimpleStatementSuite(): + new_body.extend(self._build_body_prefix(new_var_decls)) + new_body.extend(body.body[offset:]) + case cst.IndentedBlock(): + new_body.extend( + [ + cst.SimpleStatementLine(body=[stmt]) + for stmt in self._build_body_prefix(new_var_decls) + ] + ) + new_body.extend(body.body[offset:]) return new_body def _is_abstractmethod(self, node: cst.FunctionDef) -> bool: @@ -148,6 +154,14 @@ def _is_abstractmethod(self, node: cst.FunctionDef) -> bool: return False + def _is_overloaded(self, node: cst.FunctionDef) -> bool: + for decorator in node.decorators: + match decorator.decorator: + case cst.Name("overload"): + return True + + return False + def leave_FunctionDef( self, original_node: cst.FunctionDef, @@ -156,23 +170,41 @@ def leave_FunctionDef( """Transforms function definitions with mutable default parameters""" # TODO: add filter by include or exclude that works for nodes # that that have different start/end numbers. + ( updated_params, new_var_decls, add_annotation, ) = self._gather_and_update_params(original_node, updated_node) - new_body = ( - self._build_new_body(new_var_decls, updated_node.body.body) - if not self._is_abstractmethod(original_node) - else updated_node.body.body - ) + if new_var_decls: # If we're adding statements to the body, we know a change took place self.add_change(original_node, self.CHANGE_DESCRIPTION) if add_annotation: self.add_needed_import("typing", "Optional") + # overloaded methods with empty bodies should only change signature + empty_statement = m.Expr(value=m.Ellipsis()) | m.Pass() + if self._is_overloaded(updated_node) and m.matches( + original_node.body, + m.SimpleStatementSuite(body=[empty_statement]) + | m.IndentedBlock(body=[m.SimpleStatementLine(body=[empty_statement])]), + ): + return updated_node.with_changes( + params=updated_node.params.with_changes(params=updated_params) + ) + + new_body = ( + self._build_new_body(new_var_decls, updated_node.body) + if not self._is_abstractmethod(original_node) + else updated_node.body.body + ) + return updated_node.with_changes( params=updated_node.params.with_changes(params=updated_params), - body=updated_node.body.with_changes(body=new_body), + body=( + updated_node.body.with_changes(body=new_body) + if new_body + else updated_node.body + ), ) diff --git a/tests/codemods/test_fix_mutable_params.py b/tests/codemods/test_fix_mutable_params.py index 8145bd3f..6eb7c9ca 100644 --- a/tests/codemods/test_fix_mutable_params.py +++ b/tests/codemods/test_fix_mutable_params.py @@ -21,6 +21,16 @@ def foo(bar=None): """ self.run_and_assert(tmpdir, input_code, expected_output) + @pytest.mark.parametrize("mutable", ["[]", "{}", "set()"]) + def test_fix_single_arg_suite(self, tmpdir, mutable): + input_code = f""" + def foo(bar={mutable}): print(bar) + """ + expected_output = f""" + def foo(bar=None): bar = {mutable} if bar is None else bar; print(bar) + """ + self.run_and_assert(tmpdir, input_code, expected_output) + @pytest.mark.parametrize("mutable", ["[]", "{}", "set()"]) def test_fix_single_arg_method(self, tmpdir, mutable): input_code = f""" @@ -90,6 +100,32 @@ def foo(bar=None, baz=None): """ self.run_and_assert(tmpdir, input_code, expected_output) + def test_fix_overloaded(self, tmpdir): + input_code = """ + from typing import overload + @overload + def foo(a : list[str] = []) -> str: + ... + @overload + def foo(a : list[int] = []) -> int: + ... + def foo(a : list[int] | list[str] = []) -> int|str: + return 0 + """ + expected_output = """ + from typing import Optional, overload + @overload + def foo(a : Optional[list[str]] = None) -> str: + ... + @overload + def foo(a : Optional[list[int]] = None) -> int: + ... + def foo(a : Optional[list[int] | list[str]] = None) -> int|str: + a = [] if a is None else a + return 0 + """ + self.run_and_assert(tmpdir, input_code, expected_output) + def test_fix_multiple_args_mixed(self, tmpdir): input_code = """ def foo(bar=[], x="hello", baz={"foo": 42}, biz=set(), boz=list(), buz={1, 2, 3}):