Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codemod to fix asyncio.Task #248

Merged
merged 8 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions integration_tests/test_fix_task_instantiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from core_codemods.fix_async_task_instantiation import FixAsyncTaskInstantiation
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestFixAsyncTaskInstantiation(BaseIntegrationTest):
codemod = FixAsyncTaskInstantiation
code_path = "tests/samples/fix_async_task_instantiation.py"
original_code, expected_new_code = original_and_expected_from_code_path(
code_path,
[
(
7,
""" task = asyncio.create_task(my_coroutine(), name="my task")\n""",
),
],
)

# fmt: off
expected_diff =(
"""--- \n"""
"""+++ \n"""
"""@@ -5,7 +5,7 @@\n"""
""" print("Task completed")\n"""
""" \n"""
""" async def main():\n"""
"""- task = asyncio.Task(my_coroutine(), name="my task")\n"""
"""+ task = asyncio.create_task(my_coroutine(), name="my task")\n"""
""" await task\n"""
""" \n"""
""" asyncio.run(main())\n"""
)
# fmt: on

expected_line_change = "8"
change_description = FixAsyncTaskInstantiation.change_description
num_changed_files = 1
9 changes: 9 additions & 0 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class BaseType(Enum):
LIST = 2
STRING = 3
BYTES = 4
NONE = 5
TRUE = 6
FALSE = 7


# pylint: disable-next=R0911
Expand All @@ -26,6 +29,12 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Name(value="None"):
return BaseType.NONE
case cst.Name(value="True"):
return BaseType.TRUE
case cst.Name(value="False"):
return BaseType.FALSE
case (
cst.Integer()
| cst.Imaginary()
Expand Down
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ class DocMetadata:
importance="Medium",
guidance_explained="While string concatenation inside a sequence iterable is likely a mistake, there are instances when you may choose to use them..",
),
"fix-async-task-instantiation": DocMetadata(
importance="Low",
guidance_explained="Manual instantiation of `asyncio.Task` is discouraged. We believe this change is safe and will not cause any issues.",
),
}

METADATA = CORE_METADATA | {
Expand Down
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .sonar.sonar_django_json_response_type import SonarDjangoJsonResponseType
from .lazy_logging import LazyLogging
from .str_concat_in_seq_literal import StrConcatInSeqLiteral
from .fix_async_task_instantiation import FixAsyncTaskInstantiation

registry = CodemodCollection(
origin="pixee",
Expand Down Expand Up @@ -118,6 +119,7 @@
FixAssertTuple,
LazyLogging,
StrConcatInSeqLiteral,
FixAsyncTaskInstantiation,
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
The `asyncio` [documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) explicitly discourages manual instantiation of a `Task` instance and instead recommends calling `create_task`. This keeps your code in line with recommended best practices and promotes maintainability.

Our changes look like the following:
```diff
import asyncio

- task = asyncio.Task(my_coroutine(), name="my task")
+ task = asyncio.create_task(my_coroutine(), name="my task")
```
169 changes: 169 additions & 0 deletions src/core_codemods/fix_async_task_instantiation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import libcst as cst
from libcst import MaybeSentinel
from typing import Optional
from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from codemodder.codemods.utils import BaseType, infer_expression_type


class FixAsyncTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin):
metadata = Metadata(
name="fix-async-task-instantiation",
summary="Use High-Level `asyncio` API Functions to Create Tasks",
review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
references=[
Reference(
url="https://docs.python.org/3/library/asyncio-task.html#asyncio.Task"
),
],
)
change_description = "Replace instantiation of `asyncio.Task` with higher-level functions to create tasks."
_module_name = "asyncio"

# pylint: disable=too-many-return-statements
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
):
return updated_node

if self.find_base_name(original_node) != "asyncio.Task":
return updated_node
coroutine_arg = original_node.args[0]
loop_arg, eager_start_arg, other_args = self._split_args(original_node.args[1:])

loop_type = (
infer_expression_type(self.resolve_expression(loop_arg.value))
if loop_arg
else None
)

eager_start_type = (
infer_expression_type(self.resolve_expression(eager_start_arg.value))
if eager_start_arg
else None
)

if eager_start_type == BaseType.TRUE:
if not loop_arg or self._is_invalid_loop_value(loop_type):
# asking for eager_start without a loop or incorrectly setting loop is bad.
# We won't do anything.
return updated_node

loop_arg = loop_arg.with_changes(keyword=None, equal=MaybeSentinel.DEFAULT)
return self.node_eager_task(
original_node,
updated_node,
replacement_args=[loop_arg, coroutine_arg] + other_args,
)

if loop_arg:
if loop_type == BaseType.NONE:
return self.node_create_task(
original_node,
updated_node,
replacement_args=[coroutine_arg] + other_args,
)
if self._is_invalid_loop_value(loop_type):
# incorrectly assigned loop kwarg to something that is not a loop.
# We won't do anything.
return updated_node

return self.node_loop_create_task(
original_node, coroutine_arg, loop_arg, other_args
)
return self.node_create_task(
original_node, updated_node, replacement_args=[coroutine_arg] + other_args
)

def node_create_task(
self,
original_node: cst.Call,
updated_node: cst.Call,
replacement_args=list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(...)` to `asyncio.create_task(...)`"""
self.report_change(original_node)
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)

if len(replacement_args) == 1:
replacement_args[0] = replacement_args[0].with_changes(
comma=MaybeSentinel.DEFAULT
)
return self.update_call_target(
updated_node, maybe_name, "create_task", replacement_args=replacement_args
)

def node_eager_task(
self,
original_node: cst.Call,
updated_node: cst.Call,
replacement_args=list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(...)` to `asyncio.eager_task_factory(loop, coro...)`"""
self.report_change(original_node)
maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
if (maybe_name := maybe_name or self._module_name) == self._module_name:
self.add_needed_import(self._module_name)
self.remove_unused_import(original_node)
return self.update_call_target(
updated_node,
maybe_name,
"eager_task_factory",
replacement_args=replacement_args,
)

def node_loop_create_task(
self,
original_node: cst.Call,
coroutine_arg: cst.Arg,
loop_arg: cst.Arg,
other_args: list[cst.Arg],
) -> cst.Call:
"""Convert `asyncio.Task(..., loop=loop,...)` to `loop.create_task(...)`"""
self.report_change(original_node)
coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
loop_attr = loop_arg.value
new_call = cst.Call(
func=cst.Attribute(value=loop_attr, attr=cst.Name("create_task")),
args=[coroutine_arg] + other_args,
)
self.remove_unused_import(original_node)
return new_call

def _split_args(
self, args: list[cst.Arg]
) -> tuple[Optional[cst.Arg], Optional[cst.Arg], list[cst.Arg]]:
"""Find the loop kwarg and the eager_start kwarg from a list of args.
Return any args or non-None kwargs.
"""
loop_arg, eager_start_arg = None, None
other_args = []
for arg in args:
match arg:
case cst.Arg(keyword=cst.Name(value="loop")):
loop_arg = arg
case cst.Arg(keyword=cst.Name(value="eager_start")):
eager_start_arg = arg
case cst.Arg(keyword=cst.Name() as k) if k.value != "None":
# keep kwarg that are not set to None
other_args.append(arg)
case cst.Arg(keyword=None):
# keep post args
other_args.append(arg)

return loop_arg, eager_start_arg, other_args

def _is_invalid_loop_value(self, loop_type):
return loop_type in (
BaseType.NONE,
BaseType.NUMBER,
BaseType.LIST,
BaseType.STRING,
BaseType.BYTES,
BaseType.TRUE,
BaseType.FALSE,
)
Loading
Loading