Skip to content

Commit

Permalink
Respect placement of docstrings in fix-mutable-params
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Oct 3, 2023
1 parent dfdbba0 commit 8ac9fc9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/core_codemods/fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ def _build_body_prefix(self, new_var_decls: list[cst.Param]):
for var_decl in new_var_decls
]

def _build_new_body(self, new_var_decls, body):
offset = 0
new_body = []

# Preserve placement of docstring
if body and m.matches(
body[0],
m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
):
new_body.append(body[0])
offset = 1

new_body.extend(self._build_body_prefix(new_var_decls))
new_body.extend(body[offset:])
return new_body

def leave_FunctionDef(
self,
original_node: cst.FunctionDef,
Expand All @@ -128,14 +144,13 @@ def leave_FunctionDef(
new_var_decls,
add_annotation,
) = self._gather_and_update_params(original_node, updated_node)
# Add any new variable declarations to the top of the function body
if body_prefix := self._build_body_prefix(new_var_decls):
new_body = self._build_new_body(new_var_decls, updated_node.body.body)
if new_var_decls:
# If we're adding statements to the body, we know a change took place
self.add_change(original_node, self.CHANGE_DESCRIPTION)
if add_annotation:
self.add_needed_import("typing", "Optional")

new_body = tuple(body_prefix) + updated_node.body.body
return updated_node.with_changes(
params=updated_node.params.with_changes(params=updated_params),
body=updated_node.body.with_changes(body=new_body),
Expand Down
28 changes: 28 additions & 0 deletions tests/codemods/test_fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,33 @@ def foo(x = None, y: Optional[List[int]] = None, z: Optional[Dict[str, int]] = N
y = [] if y is None else y
z = {} if z is None else z
print(x, y, z)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_respect_docstring(self, tmpdir):
input_code = '''
def func(foo=[]):
"""Here is a docstring"""
print(foo)
'''
expected_output = '''
def func(foo=None):
"""Here is a docstring"""
foo = [] if foo is None else foo
print(foo)
'''
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_respect_leading_comment(self, tmpdir):
input_code = """
def func(foo=[]):
# Here is a comment
print(foo)
"""
expected_output = """
def func(foo=None):
foo = [] if foo is None else foo
# Here is a comment
print(foo)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

0 comments on commit 8ac9fc9

Please sign in to comment.