Skip to content

Commit

Permalink
fix task init codemod can handle different loop types
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Feb 9, 2024
1 parent c84bc3f commit 97c5e00
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class BaseType(Enum):
LIST = 2
STRING = 3
BYTES = 4
NONE = 5
BOOL = 6


# pylint: disable-next=R0911
Expand All @@ -26,6 +28,10 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Name(value="None"):
return BaseType.NONE
case cst.Name(value="True") | cst.Name(value="False"):
return BaseType.BOOL
case (
cst.Integer()
| cst.Imaginary()
Expand Down
26 changes: 22 additions & 4 deletions src/core_codemods/fix_task_instantiation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import libcst as cst
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference
from codemodder.codemods.utils_mixin import NameResolutionMixin, AncestorPatternsMixin
from typing import Optional
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from codemodder.codemods.utils import BaseType, infer_expression_type


class FixTaskInstantiation(SimpleCodemod, NameResolutionMixin, AncestorPatternsMixin):
class FixTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin):
metadata = Metadata(
name="fix-task-instantiation",
summary="TODOReplace Comparisons to Empty Sequence with Implicit Boolean Comparison",
Expand All @@ -27,9 +28,24 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
return updated_node

if self.find_base_name(original_node) == "asyncio.Task":
self.report_change(original_node)
loop_arg, other_args = self._find_loop_arg(original_node)
if loop_arg:
loop_type = infer_expression_type(
self.resolve_expression(loop_arg.value)
)
if loop_type == BaseType.NONE:
return self.node_create_task(original_node, updated_node)
elif loop_type in (
BaseType.NUMBER,
BaseType.LIST,
BaseType.STRING,
BaseType.BYTES,
BaseType.BOOL,
):
# User incorrectly assigned loop to something that is not a loop.
# We won't do anything.
return updated_node

coroutine_arg = original_node.args[0]
return self.node_loop_create_task(
original_node, coroutine_arg, loop_arg, other_args
Expand All @@ -40,6 +56,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
def node_create_task(
self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call:
self.report_change(original_node)
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
Expand All @@ -54,6 +71,7 @@ def node_loop_create_task(
other_args: list[cst.Arg],
) -> cst.Call:
"""todo: document"""
self.report_change(original_node)
coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
loop_attr = loop_arg.value
new_call = cst.Call(
Expand Down
17 changes: 16 additions & 1 deletion tests/codemods/test_fix_task_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,24 @@ def test_import_alias(self, tmpdir, input_code, expected_output):
),
],
)
def test_with_kwargs(self, tmpdir, input_code, expected_output):
def test_with_other_kwargs(self, tmpdir, input_code, expected_output):
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("loop_value", ["None", "True", '"gibberish"', 10])
def test_loop_kwarg_variations(self, tmpdir, loop_value):
input_code = (
output_code
) = f"""
import asyncio
asyncio.Task(coro(1, 2), loop={loop_value})
"""
if loop_value == "None":
output_code = """
import asyncio
asyncio.create_task(coro(1, 2), loop=None)
"""
self.run_and_assert(tmpdir, input_code, output_code)

def test_asyncio_script(self, tmpdir):
input_code = """
import asyncio
Expand Down
10 changes: 10 additions & 0 deletions tests/test_basetype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import libcst as cst
import pytest
from codemodder.codemods.utils import BaseType, infer_expression_type


Expand Down Expand Up @@ -34,3 +35,12 @@ def test_if_numbers(self):
def test_if_numbers2(self):
e = cst.parse_expression("float(1) if True else len([1,2])")
assert infer_expression_type(e) == BaseType.NUMBER

@pytest.mark.parametrize("code", ["True", "False"])
def test_bool(self, code):
e = cst.parse_expression(code)
assert infer_expression_type(e) == BaseType.BOOL

def test_none(self):
e = cst.parse_expression("None")
assert infer_expression_type(e) == BaseType.NONE

0 comments on commit 97c5e00

Please sign in to comment.