Skip to content

Commit

Permalink
Add support/tests for updating type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Sep 27, 2023
1 parent 46a399d commit 345ba40
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/codemodder/codemods/fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,28 @@ def __init__(self, *args, **kwargs):
# Looking for list() or dict()
self._matches_builtin = m.Call(func=m.Name("list") | m.Name("dict"))

def _create_annotation(self, orig: cst.Param, updated: cst.Param):
return (
updated.annotation.with_changes(
annotation=cst.Subscript(
value=cst.Name("Optional"),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=orig.annotation.annotation)
)
],
)
)
if updated.annotation is not None
else None
)

def _gather_and_update_params(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
):
updated_params = []
new_var_decls = []
add_annotation = False

# Iterate over all original/update parameters in parallel
for orig, updated in zip(
Expand All @@ -59,11 +76,20 @@ def _gather_and_update_params(
)
needs_update = True

annotation = (
self._create_annotation(orig, updated) if needs_update else None
)
add_annotation = add_annotation or annotation is not None
updated_params.append(
updated.with_changes(default=cst.Name("None")) if needs_update else orig
updated.with_changes(
default=cst.Name("None"),
annotation=annotation,
)
if needs_update
else updated,
)

return updated_params, new_var_decls
return updated_params, new_var_decls, add_annotation

def _build_body_prefix(self, new_var_decls: list[cst.Param]):
return [
Expand Down Expand Up @@ -97,13 +123,17 @@ def leave_FunctionDef(
updated_node: cst.FunctionDef,
):
"""Transforms function definitions with mutable default parameters"""
updated_params, new_var_decls = self._gather_and_update_params(
original_node, updated_node
)
(
updated_params,
new_var_decls,
add_annotation,
) = self._gather_and_update_params(original_node, updated_node)
# Add any new variable declarations to the top of the function body
if body_prefix := self._build_body_prefix(new_var_decls):
# If we're adding statements to the body, we know a change took place
self.add_change(original_node, self.CHANGE_DESCRIPTION)
if add_annotation:
self.add_needed_import("typing", "Optional")

new_body = tuple(body_prefix) + updated_node.body.body
return updated_node.with_changes(
Expand Down
72 changes: 72 additions & 0 deletions tests/codemods/test_fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,77 @@ def foo(bar="hello", baz=None):
baz = {} if baz is None else baz
print(bar)
print(baz)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize(
"mutable,annotation", [("[]", "List"), ("{}", "Dict"), ("set()", "Set")]
)
def test_fix_with_type_annotation(self, tmpdir, mutable, annotation):
input_code = f"""
from typing import {annotation}
def foo(bar: {annotation}[int] = {mutable}):
print(bar)
"""
expected_output = f"""
from typing import Optional, {annotation}
def foo(bar: Optional[{annotation}[int]] = None):
bar = {mutable} if bar is None else bar
print(bar)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize(
"mutable,annotation", [("[]", "list"), ("{}", "dict"), ("set()", "set")]
)
def test_fix_with_type_annotation_new_import(self, tmpdir, mutable, annotation):
input_code = f"""
def foo(bar: {annotation}[int] = {mutable}):
print(bar)
"""
expected_output = f"""
from typing import Optional
def foo(bar: Optional[{annotation}[int]] = None):
bar = {mutable} if bar is None else bar
print(bar)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_one_type_annotation(self, tmpdir):
input_code = """
from typing import List
def foo(x = [], y: List[int] = [], z = {}):
print(x, y, z)
"""
expected_output = """
from typing import Optional, List
def foo(x = None, y: Optional[List[int]] = None, z = None):
x = [] if x is None else x
y = [] if y is None else y
z = {} if z is None else z
print(x, y, z)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_fix_multiple_type_annotations(self, tmpdir):
input_code = """
from typing import Dict, List
def foo(x = [], y: List[int] = [], z: Dict[str, int] = {}):
print(x, y, z)
"""
expected_output = """
from typing import Optional, Dict, List
def foo(x = None, y: Optional[List[int]] = None, z: Optional[Dict[str, int]] = None):
x = [] if x is None else x
y = [] if y is None else y
z = {} if z is None else z
print(x, y, z)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

0 comments on commit 345ba40

Please sign in to comment.