Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for statement suites and overload in fix-mutable-params #255

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions integration_tests/test_fix_mutable_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from core_codemods.fix_mutable_params import FixMutableParams
from core_codemods.fix_mutable_params import (
FixMutableParams,
FixMutableParamsTransformer,
)
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
Expand Down Expand Up @@ -30,4 +33,4 @@ def baz(x=None, y=None):
expected_diff = '--- \n+++ \n@@ -1,4 +1,5 @@\n-def foo(x, y=[]):\n+def foo(x, y=None):\n+ y = [] if y is None else y\n y.append(x)\n print(y)\n \n@@ -7,6 +8,8 @@\n print(x)\n \n \n-def baz(x={"foo": 42}, y=set()):\n+def baz(x=None, y=None):\n+ x = {"foo": 42} if x is None else x\n+ y = set() if y is None else y\n print(x)\n print(y)\n'
expected_line_change = 1
num_changes = 2
change_description = FixMutableParams.change_description
change_description = FixMutableParamsTransformer.change_description
125 changes: 82 additions & 43 deletions src/core_codemods/fix_mutable_params.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import libcst as cst
from libcst import matchers as m
from codemodder.codemods.libcst_transformer import LibcstTransformerPipeline
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod
from core_codemods.api.core_codemod import CoreCodemod


class FixMutableParams(SimpleCodemod):
metadata = Metadata(
name="fix-mutable-params",
summary="Replace Mutable Default Parameters",
review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW,
references=[],
)
class FixMutableParamsTransformer(SimpleCodemod):
change_description = "Replace mutable parameter with `None`."

_BUILTIN_TO_LITERAL = {
Expand Down Expand Up @@ -101,46 +97,51 @@ def _gather_and_update_params(

return updated_params, new_var_decls, add_annotation

def _build_body_prefix(self, new_var_decls: list[cst.Param]):
def _build_body_prefix(self, new_var_decls: list[cst.Param]) -> list[cst.Assign]:
return [
cst.SimpleStatementLine(
body=[
cst.Assign(
targets=[cst.AssignTarget(target=var_decl.name)],
value=cst.IfExp(
test=cst.Comparison(
left=var_decl.name,
comparisons=[
cst.ComparisonTarget(cst.Is(), cst.Name("None"))
],
),
# In the case of list() or dict(), this particular
# default value has been updated to use the literal
# instead. This does not affect the default
# argument in the function itself.
body=var_decl.default,
orelse=var_decl.name,
),
)
]
cst.Assign(
targets=[cst.AssignTarget(target=var_decl.name)],
value=cst.IfExp(
test=cst.Comparison(
left=var_decl.name,
comparisons=[cst.ComparisonTarget(cst.Is(), cst.Name("None"))],
),
# In the case of list() or dict(), this particular
# default value has been updated to use the literal
# instead. This does not affect the default
# argument in the function itself.
body=var_decl.default,
orelse=var_decl.name,
),
)
for var_decl in new_var_decls
]

def _build_new_body(self, new_var_decls, body):
def _build_new_body(
self, new_var_decls, body: cst.BaseSuite
) -> list[cst.BaseStatement] | list[cst.BaseSmallStatement]:
offset = 0
new_body = []

# Preserve placement of docstring
if body and m.matches(
body[0],
m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
if m.matches(
body.body[0],
m.Expr(value=m.SimpleString())
| m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
):
new_body.append(body[0])
new_body.append(body.body[0])
offset = 1

new_body.extend(self._build_body_prefix(new_var_decls))
new_body.extend(body[offset:])
match body:
case cst.SimpleStatementSuite():
new_body.extend(self._build_body_prefix(new_var_decls))
new_body.extend(body.body[offset:])
case cst.IndentedBlock():
new_body.extend(
[
cst.SimpleStatementLine(body=[stmt])
for stmt in self._build_body_prefix(new_var_decls)
]
)
new_body.extend(body.body[offset:])
return new_body

def _is_abstractmethod(self, node: cst.FunctionDef) -> bool:
Expand All @@ -151,6 +152,14 @@ def _is_abstractmethod(self, node: cst.FunctionDef) -> bool:

return False

def _is_overloaded(self, node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
match decorator.decorator:
case cst.Name("overload"):
return True

return False

def leave_FunctionDef(
self,
original_node: cst.FunctionDef,
Expand All @@ -159,23 +168,53 @@ def leave_FunctionDef(
"""Transforms function definitions with mutable default parameters"""
# TODO: add filter by include or exclude that works for nodes
# that that have different start/end numbers.

(
updated_params,
new_var_decls,
add_annotation,
) = self._gather_and_update_params(original_node, updated_node)
new_body = (
self._build_new_body(new_var_decls, updated_node.body.body)
if not self._is_abstractmethod(original_node)
else updated_node.body.body
)

if new_var_decls:
# If we're adding statements to the body, we know a change took place
self.add_change(original_node, self.change_description)
if add_annotation:
self.add_needed_import("typing", "Optional")

# overloaded methods with empty bodies should only change signature
empty_statement = m.Expr(value=m.Ellipsis()) | m.Pass()
if self._is_overloaded(updated_node) and m.matches(
original_node.body,
m.SimpleStatementSuite(body=[empty_statement])
| m.IndentedBlock(body=[m.SimpleStatementLine(body=[empty_statement])]),
):
return updated_node.with_changes(
params=updated_node.params.with_changes(params=updated_params)
)

new_body = (
self._build_new_body(new_var_decls, updated_node.body)
if not self._is_abstractmethod(original_node)
else updated_node.body.body
)

return updated_node.with_changes(
params=updated_node.params.with_changes(params=updated_params),
body=updated_node.body.with_changes(body=new_body),
body=(
updated_node.body.with_changes(body=new_body)
if new_body
else updated_node.body
),
)


FixMutableParams = CoreCodemod(
metadata=Metadata(
name="fix-mutable-params",
summary="Replace Mutable Default Parameters",
review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW,
references=[],
),
transformer=LibcstTransformerPipeline(FixMutableParamsTransformer),
detector=None,
)
36 changes: 36 additions & 0 deletions tests/codemods/test_fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def foo(bar=None):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("mutable", ["[]", "{}", "set()"])
def test_fix_single_arg_suite(self, tmpdir, mutable):
input_code = f"""
def foo(bar={mutable}): print(bar)
"""
expected_output = f"""
def foo(bar=None): bar = {mutable} if bar is None else bar; print(bar)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("mutable", ["[]", "{}", "set()"])
def test_fix_single_arg_method(self, tmpdir, mutable):
input_code = f"""
Expand Down Expand Up @@ -90,6 +100,32 @@ def foo(bar=None, baz=None):
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_overloaded(self, tmpdir):
input_code = """
from typing import overload
@overload
def foo(a : list[str] = []) -> str:
...
@overload
def foo(a : list[int] = []) -> int:
...
def foo(a : list[int] | list[str] = []) -> int|str:
return 0
"""
expected_output = """
from typing import Optional, overload
@overload
def foo(a : Optional[list[str]] = None) -> str:
...
@overload
def foo(a : Optional[list[int]] = None) -> int:
...
def foo(a : Optional[list[int] | list[str]] = None) -> int|str:
a = [] if a is None else a
return 0
"""
self.run_and_assert(tmpdir, input_code, expected_output, num_changes=3)

def test_fix_multiple_args_mixed(self, tmpdir):
input_code = """
def foo(bar=[], x="hello", baz={"foo": 42}, biz=set(), boz=list(), buz={1, 2, 3}):
Expand Down
Loading