Skip to content

Commit

Permalink
Added support for statement suites in fix-mutable-params
Browse files Browse the repository at this point in the history
  • Loading branch information
andrecsilva authored and drdavella committed Feb 13, 2024
1 parent c81d5ce commit 6347944
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 36 deletions.
104 changes: 68 additions & 36 deletions src/core_codemods/fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class FixMutableParams(BaseCodemod):
REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW
DESCRIPTION = "Replace mutable parameter with `None`."
REFERENCES: list = []

_BUILTIN_TO_LITERAL = {
"list": cst.List(elements=[]),
"dict": cst.Dict(elements=[]),
Expand Down Expand Up @@ -98,46 +99,51 @@ def _gather_and_update_params(

return updated_params, new_var_decls, add_annotation

def _build_body_prefix(self, new_var_decls: list[cst.Param]):
def _build_body_prefix(self, new_var_decls: list[cst.Param]) -> list[cst.Assign]:
return [
cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[cst.AssignTarget(target=var_decl.name)],
value=cst.IfExp(
test=cst.Comparison(
left=var_decl.name,
comparisons=[
cst.ComparisonTarget(cst.Is(), cst.Name("None"))
],
),
# In the case of list() or dict(), this particular
# default value has been updated to use the literal
# instead. This does not affect the default
# argument in the function itself.
body=var_decl.default,
orelse=var_decl.name,
),
)
]
cst.Assign(
targets=[cst.AssignTarget(target=var_decl.name)],
value=cst.IfExp(
test=cst.Comparison(
left=var_decl.name,
comparisons=[cst.ComparisonTarget(cst.Is(), cst.Name("None"))],
),
# In the case of list() or dict(), this particular
# default value has been updated to use the literal
# instead. This does not affect the default
# argument in the function itself.
body=var_decl.default,
orelse=var_decl.name,
),
)
for var_decl in new_var_decls
]

def _build_new_body(self, new_var_decls, body):
def _build_new_body(
self, new_var_decls, body: cst.BaseSuite
) -> list[cst.BaseStatement] | list[cst.BaseSmallStatement]:
offset = 0
new_body = []

# Preserve placement of docstring
if body and m.matches(
body[0],
m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
if m.matches(
body.body[0],
m.Expr(value=m.SimpleString())
| m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
):
new_body.append(body[0])
new_body.append(body.body[0])
offset = 1

new_body.extend(self._build_body_prefix(new_var_decls))
new_body.extend(body[offset:])
match body:
case cst.SimpleStatementSuite():
new_body.extend(self._build_body_prefix(new_var_decls))
new_body.extend(body.body[offset:])
case cst.IndentedBlock():
new_body.extend(
[
cst.SimpleStatementLine(body=[stmt])
for stmt in self._build_body_prefix(new_var_decls)
]
)
new_body.extend(body.body[offset:])
return new_body

def _is_abstractmethod(self, node: cst.FunctionDef) -> bool:
Expand All @@ -148,6 +154,14 @@ def _is_abstractmethod(self, node: cst.FunctionDef) -> bool:

return False

def _is_overloaded(self, node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
match decorator.decorator:
case cst.Name("overload"):
return True

return False

def leave_FunctionDef(
self,
original_node: cst.FunctionDef,
Expand All @@ -156,23 +170,41 @@ def leave_FunctionDef(
"""Transforms function definitions with mutable default parameters"""
# TODO: add filter by include or exclude that works for nodes
# that that have different start/end numbers.

(
updated_params,
new_var_decls,
add_annotation,
) = self._gather_and_update_params(original_node, updated_node)
new_body = (
self._build_new_body(new_var_decls, updated_node.body.body)
if not self._is_abstractmethod(original_node)
else updated_node.body.body
)

if new_var_decls:
# If we're adding statements to the body, we know a change took place
self.add_change(original_node, self.CHANGE_DESCRIPTION)
if add_annotation:
self.add_needed_import("typing", "Optional")

# overloaded methods with empty bodies should only change signature
empty_statement = m.Expr(value=m.Ellipsis()) | m.Pass()
if self._is_overloaded(updated_node) and m.matches(
original_node.body,
m.SimpleStatementSuite(body=[empty_statement])
| m.IndentedBlock(body=[m.SimpleStatementLine(body=[empty_statement])]),
):
return updated_node.with_changes(
params=updated_node.params.with_changes(params=updated_params)
)

new_body = (
self._build_new_body(new_var_decls, updated_node.body)
if not self._is_abstractmethod(original_node)
else updated_node.body.body
)

return updated_node.with_changes(
params=updated_node.params.with_changes(params=updated_params),
body=updated_node.body.with_changes(body=new_body),
body=(
updated_node.body.with_changes(body=new_body)
if new_body
else updated_node.body
),
)
36 changes: 36 additions & 0 deletions tests/codemods/test_fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def foo(bar=None):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("mutable", ["[]", "{}", "set()"])
def test_fix_single_arg_suite(self, tmpdir, mutable):
input_code = f"""
def foo(bar={mutable}): print(bar)
"""
expected_output = f"""
def foo(bar=None): bar = {mutable} if bar is None else bar; print(bar)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("mutable", ["[]", "{}", "set()"])
def test_fix_single_arg_method(self, tmpdir, mutable):
input_code = f"""
Expand Down Expand Up @@ -90,6 +100,32 @@ def foo(bar=None, baz=None):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_overloaded(self, tmpdir):
input_code = """
from typing import overload
@overload
def foo(a : list[str] = []) -> str:
...
@overload
def foo(a : list[int] = []) -> int:
...
def foo(a : list[int] | list[str] = []) -> int|str:
return 0
"""
expected_output = """
from typing import Optional, overload
@overload
def foo(a : Optional[list[str]] = None) -> str:
...
@overload
def foo(a : Optional[list[int]] = None) -> int:
...
def foo(a : Optional[list[int] | list[str]] = None) -> int|str:
a = [] if a is None else a
return 0
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_multiple_args_mixed(self, tmpdir):
input_code = """
def foo(bar=[], x="hello", baz={"foo": 42}, biz=set(), boz=list(), buz={1, 2, 3}):
Expand Down

0 comments on commit 6347944

Please sign in to comment.