From 8ac9fc9b2e3ee95472f4306e834c5e4975b554de Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Mon, 2 Oct 2023 14:59:31 -0400 Subject: [PATCH] Respect placement of docstrings in fix-mutable-params --- src/core_codemods/fix_mutable_params.py | 21 ++++++++++++++--- tests/codemods/test_fix_mutable_params.py | 28 +++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/core_codemods/fix_mutable_params.py b/src/core_codemods/fix_mutable_params.py index 8c3eab10..bd5cd5cd 100644 --- a/src/core_codemods/fix_mutable_params.py +++ b/src/core_codemods/fix_mutable_params.py @@ -117,6 +117,22 @@ def _build_body_prefix(self, new_var_decls: list[cst.Param]): for var_decl in new_var_decls ] + def _build_new_body(self, new_var_decls, body): + offset = 0 + new_body = [] + + # Preserve placement of docstring + if body and m.matches( + body[0], + m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]), + ): + new_body.append(body[0]) + offset = 1 + + new_body.extend(self._build_body_prefix(new_var_decls)) + new_body.extend(body[offset:]) + return new_body + def leave_FunctionDef( self, original_node: cst.FunctionDef, @@ -128,14 +144,13 @@ def leave_FunctionDef( new_var_decls, add_annotation, ) = self._gather_and_update_params(original_node, updated_node) - # Add any new variable declarations to the top of the function body - if body_prefix := self._build_body_prefix(new_var_decls): + new_body = self._build_new_body(new_var_decls, updated_node.body.body) + if new_var_decls: # If we're adding statements to the body, we know a change took place self.add_change(original_node, self.CHANGE_DESCRIPTION) if add_annotation: self.add_needed_import("typing", "Optional") - new_body = tuple(body_prefix) + updated_node.body.body return updated_node.with_changes( params=updated_node.params.with_changes(params=updated_params), body=updated_node.body.with_changes(body=new_body), diff --git a/tests/codemods/test_fix_mutable_params.py b/tests/codemods/test_fix_mutable_params.py index 6ceab71d..f8c6334f 100644 --- a/tests/codemods/test_fix_mutable_params.py +++ b/tests/codemods/test_fix_mutable_params.py @@ -191,5 +191,33 @@ def foo(x = None, y: Optional[List[int]] = None, z: Optional[Dict[str, int]] = N y = [] if y is None else y z = {} if z is None else z print(x, y, z) +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_fix_respect_docstring(self, tmpdir): + input_code = ''' +def func(foo=[]): + """Here is a docstring""" + print(foo) +''' + expected_output = ''' +def func(foo=None): + """Here is a docstring""" + foo = [] if foo is None else foo + print(foo) +''' + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_fix_respect_leading_comment(self, tmpdir): + input_code = """ +def func(foo=[]): + # Here is a comment + print(foo) +""" + expected_output = """ +def func(foo=None): + foo = [] if foo is None else foo + # Here is a comment + print(foo) """ self.run_and_assert(tmpdir, input_code, expected_output)