diff --git a/integration_tests/test_fix_task_instantiation.py b/integration_tests/test_fix_task_instantiation.py new file mode 100644 index 00000000..2acbaca9 --- /dev/null +++ b/integration_tests/test_fix_task_instantiation.py @@ -0,0 +1,39 @@ +from core_codemods.fix_async_task_instantiation import FixAsyncTaskInstantiation +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestFixAsyncTaskInstantiation(BaseIntegrationTest): + codemod = FixAsyncTaskInstantiation + code_path = "tests/samples/fix_async_task_instantiation.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, + [ + ( + 7, + """ task = asyncio.create_task(my_coroutine(), name="my task")\n""", + ), + ], + ) + + # fmt: off + expected_diff =( + """--- \n""" + """+++ \n""" + """@@ -5,7 +5,7 @@\n""" + """ print("Task completed")\n""" + """ \n""" + """ async def main():\n""" + """- task = asyncio.Task(my_coroutine(), name="my task")\n""" + """+ task = asyncio.create_task(my_coroutine(), name="my task")\n""" + """ await task\n""" + """ \n""" + """ asyncio.run(main())\n""" + ) + # fmt: on + + expected_line_change = "8" + change_description = FixAsyncTaskInstantiation.change_description + num_changed_files = 1 diff --git a/src/codemodder/codemods/utils.py b/src/codemodder/codemods/utils.py index 0368415e..3538121b 100644 --- a/src/codemodder/codemods/utils.py +++ b/src/codemodder/codemods/utils.py @@ -17,6 +17,9 @@ class BaseType(Enum): LIST = 2 STRING = 3 BYTES = 4 + NONE = 5 + TRUE = 6 + FALSE = 7 # pylint: disable-next=R0911 @@ -26,6 +29,12 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]: """ # The current implementation covers some common cases and is in no way complete match node: + case cst.Name(value="None"): + return BaseType.NONE + case cst.Name(value="True"): + return BaseType.TRUE + case cst.Name(value="False"): + return BaseType.FALSE case ( cst.Integer() | 
cst.Imaginary() diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index e2256edd..dc1c58d3 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -230,6 +230,10 @@ class DocMetadata: importance="Medium", guidance_explained="While string concatenation inside a sequence iterable is likely a mistake, there are instances when you may choose to use them..", ), + "fix-async-task-instantiation": DocMetadata( + importance="Low", + guidance_explained="Manual instantiation of `asyncio.Task` is discouraged. We believe this change is safe and will not cause any issues.", + ), } METADATA = CORE_METADATA | { diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 09f56835..ea4235fa 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -61,6 +61,7 @@ from .sonar.sonar_django_json_response_type import SonarDjangoJsonResponseType from .lazy_logging import LazyLogging from .str_concat_in_seq_literal import StrConcatInSeqLiteral +from .fix_async_task_instantiation import FixAsyncTaskInstantiation registry = CodemodCollection( origin="pixee", @@ -118,6 +119,7 @@ FixAssertTuple, LazyLogging, StrConcatInSeqLiteral, + FixAsyncTaskInstantiation, ], ) diff --git a/src/core_codemods/docs/pixee_python_fix-async-task-instantiation.md b/src/core_codemods/docs/pixee_python_fix-async-task-instantiation.md new file mode 100644 index 00000000..3d856e0d --- /dev/null +++ b/src/core_codemods/docs/pixee_python_fix-async-task-instantiation.md @@ -0,0 +1,9 @@ +The `asyncio` [documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) explicitly discourages manual instantiation of a `Task` instance and instead recommends calling `create_task`. This keeps your code in line with recommended best practices and promotes maintainability. 
+
+Our changes look like the following:
+```diff
+  import asyncio
+
+- task = asyncio.Task(my_coroutine(), name="my task")
++ task = asyncio.create_task(my_coroutine(), name="my task")
+```
diff --git a/src/core_codemods/fix_async_task_instantiation.py b/src/core_codemods/fix_async_task_instantiation.py
new file mode 100644
index 00000000..b7f5e486
--- /dev/null
+++ b/src/core_codemods/fix_async_task_instantiation.py
@@ -0,0 +1,188 @@
+import libcst as cst
+from libcst import MaybeSentinel
+from typing import Optional
+from core_codemods.api import Metadata, ReviewGuidance, SimpleCodemod, Reference
+from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
+from codemodder.codemods.utils import BaseType, infer_expression_type
+
+
+class FixAsyncTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin):
+    # Rewrites manual `asyncio.Task(...)` instantiation. Depending on the `loop`
+    # and `eager_start` keyword arguments the call becomes `asyncio.create_task`,
+    # `<loop>.create_task` or `asyncio.eager_task_factory`; calls that cannot be
+    # rewritten safely are left untouched.
+    metadata = Metadata(
+        name="fix-async-task-instantiation",
+        summary="Use High-Level `asyncio` API Functions to Create Tasks",
+        review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW,
+        references=[
+            Reference(
+                url="https://docs.python.org/3/library/asyncio-task.html#asyncio.Task"
+            ),
+        ],
+    )
+    change_description = "Replace instantiation of `asyncio.Task` with higher-level functions to create tasks."
+    _module_name = "asyncio"
+
+    # pylint: disable=too-many-return-statements
+    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
+        """Replace an `asyncio.Task(...)` call with the matching high-level call.
+
+        Returns `updated_node` unchanged when the call is filtered out, is not
+        `asyncio.Task`, or cannot be rewritten safely.
+        """
+        if not self.filter_by_path_includes_or_excludes(
+            self.node_position(original_node)
+        ):
+            return updated_node
+
+        if self.find_base_name(original_node) != "asyncio.Task":
+            return updated_node
+
+        if not original_node.args:
+            # `asyncio.Task()` has no coroutine to forward; nothing to rewrite.
+            return updated_node
+        coroutine_arg = original_node.args[0]
+        if coroutine_arg.keyword is not None and coroutine_arg.keyword.value != "coro":
+            # First argument is some other keyword (e.g. `Task(loop=l, coro=c)`);
+            # treating it as the coroutine would produce a broken call.
+            return updated_node
+        loop_arg, eager_start_arg, other_args = self._split_args(original_node.args[1:])
+
+        loop_type = (
+            infer_expression_type(self.resolve_expression(loop_arg.value))
+            if loop_arg
+            else None
+        )
+
+        eager_start_type = (
+            infer_expression_type(self.resolve_expression(eager_start_arg.value))
+            if eager_start_arg
+            else None
+        )
+
+        if eager_start_type == BaseType.TRUE:
+            if not loop_arg or self._is_invalid_loop_value(loop_type):
+                # asking for eager_start without a loop or incorrectly setting loop is bad.
+                # We won't do anything.
+                return updated_node
+
+            loop_arg = loop_arg.with_changes(keyword=None, equal=MaybeSentinel.DEFAULT)
+            return self.node_eager_task(
+                original_node,
+                updated_node,
+                replacement_args=[loop_arg, coroutine_arg] + other_args,
+            )
+
+        if loop_arg:
+            if loop_type == BaseType.NONE:
+                return self.node_create_task(
+                    original_node,
+                    updated_node,
+                    replacement_args=[coroutine_arg] + other_args,
+                )
+            if self._is_invalid_loop_value(loop_type):
+                # incorrectly assigned loop kwarg to something that is not a loop.
+                # We won't do anything.
+                return updated_node
+
+            return self.node_loop_create_task(
+                original_node, coroutine_arg, loop_arg, other_args
+            )
+        return self.node_create_task(
+            original_node, updated_node, replacement_args=[coroutine_arg] + other_args
+        )
+
+    def node_create_task(
+        self,
+        original_node: cst.Call,
+        updated_node: cst.Call,
+        replacement_args: list[cst.Arg],
+    ) -> cst.Call:
+        """Convert `asyncio.Task(...)` to `asyncio.create_task(...)`"""
+        self.report_change(original_node)
+        maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
+        if (maybe_name := maybe_name or self._module_name) == self._module_name:
+            self.add_needed_import(self._module_name)
+        self.remove_unused_import(original_node)
+
+        if len(replacement_args) == 1:
+            # A lone argument must not keep the comma that separated it from the
+            # arguments that were dropped.
+            replacement_args[0] = replacement_args[0].with_changes(
+                comma=MaybeSentinel.DEFAULT
+            )
+        return self.update_call_target(
+            updated_node, maybe_name, "create_task", replacement_args=replacement_args
+        )
+
+    def node_eager_task(
+        self,
+        original_node: cst.Call,
+        updated_node: cst.Call,
+        replacement_args: list[cst.Arg],
+    ) -> cst.Call:
+        """Convert `asyncio.Task(...)` to `asyncio.eager_task_factory(loop, coro...)`"""
+        self.report_change(original_node)
+        maybe_name = self.get_aliased_prefix_name(original_node, self._module_name)
+        if (maybe_name := maybe_name or self._module_name) == self._module_name:
+            self.add_needed_import(self._module_name)
+        self.remove_unused_import(original_node)
+        return self.update_call_target(
+            updated_node,
+            maybe_name,
+            "eager_task_factory",
+            replacement_args=replacement_args,
+        )
+
+    def node_loop_create_task(
+        self,
+        original_node: cst.Call,
+        coroutine_arg: cst.Arg,
+        loop_arg: cst.Arg,
+        other_args: list[cst.Arg],
+    ) -> cst.Call:
+        """Convert `asyncio.Task(..., loop=loop,...)` to `loop.create_task(...)`"""
+        self.report_change(original_node)
+        coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
+        loop_attr = loop_arg.value
+        new_call = cst.Call(
+            func=cst.Attribute(value=loop_attr, attr=cst.Name("create_task")),
+            args=[coroutine_arg] + other_args,
+        )
+        self.remove_unused_import(original_node)
+        return new_call
+
+    def _split_args(
+        self, args: list[cst.Arg]
+    ) -> tuple[Optional[cst.Arg], Optional[cst.Arg], list[cst.Arg]]:
+        """Separate the `loop` and `eager_start` kwargs from a list of args.
+
+        Returns `(loop_arg, eager_start_arg, other_args)`; every remaining
+        positional or keyword argument is kept, in order, in `other_args`.
+        """
+        loop_arg, eager_start_arg = None, None
+        other_args = []
+        for arg in args:
+            match arg:
+                case cst.Arg(keyword=cst.Name(value="loop")):
+                    loop_arg = arg
+                case cst.Arg(keyword=cst.Name(value="eager_start")):
+                    eager_start_arg = arg
+                case _:
+                    # keep every other positional or keyword argument
+                    other_args.append(arg)
+
+        return loop_arg, eager_start_arg, other_args
+
+    def _is_invalid_loop_value(self, loop_type: Optional[BaseType]) -> bool:
+        """Return True when the inferred type cannot be an event loop."""
+        return loop_type in (
+            BaseType.NONE,
+            BaseType.NUMBER,
+            BaseType.LIST,
+            BaseType.STRING,
+            BaseType.BYTES,
+            BaseType.TRUE,
+            BaseType.FALSE,
+        )
diff --git a/tests/codemods/test_async_fix_task_instantiation.py b/tests/codemods/test_async_fix_task_instantiation.py
new file mode 100644
index 00000000..54065f3f
--- /dev/null
+++ b/tests/codemods/test_async_fix_task_instantiation.py
@@ -0,0 +1,257 @@
+import pytest
+from core_codemods.fix_async_task_instantiation import FixAsyncTaskInstantiation
+from tests.codemods.base_codemod_test import BaseCodemodTest
+
+
+class TestFixAsyncTaskInstantiation(BaseCodemodTest):
+    codemod = FixAsyncTaskInstantiation
+
+    @pytest.mark.parametrize(
+        "input_code,expected_output",
+        [
+            (
+                """
+            import asyncio
+            asyncio.Task(coro(1, 2))
+            """,
+                """
+            import asyncio
+            asyncio.create_task(coro(1, 2))
+            """,
+            ),
+            (
+                """
+            import asyncio
+            async def coro(*args):
+                print(args)
+
+            my_coro = coro(1, 2)
+            asyncio.Task(my_coro)
+            """,
+                """
+            import asyncio
+            async def coro(*args):
+                print(args)
+
+            my_coro =
coro(1, 2) + asyncio.create_task(my_coro) + """, + ), + ( + """ + import asyncio + my_loop = asyncio.get_event_loop() + asyncio.Task(coro(1, 2), loop=my_loop) + """, + """ + import asyncio + my_loop = asyncio.get_event_loop() + my_loop.create_task(coro(1, 2)) + """, + ), + ], + ) + def test_import(self, tmpdir, input_code, expected_output): + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize( + "input_code,expected_output", + [ + ( + """ + from asyncio import Task + Task(coro(1, 2)) + """, + """ + import asyncio + + asyncio.create_task(coro(1, 2)) + """, + ), + ( + """ + from asyncio import Task, get_event_loop + my_loop = get_event_loop() + Task(coro(1, 2), loop=my_loop) + """, + """ + from asyncio import get_event_loop + my_loop = get_event_loop() + my_loop.create_task(coro(1, 2)) + """, + ), + ], + ) + def test_from_import(self, tmpdir, input_code, expected_output): + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize( + "input_code,expected_output", + [ + ( + """ + from asyncio import Task as taskInit + taskInit(coro(1, 2)) + """, + """ + import asyncio + + asyncio.create_task(coro(1, 2)) + """, + ), + ( + """ + from asyncio import get_event_loop, Task as taskInit + my_loop = get_event_loop() + taskInit(coro(1, 2), loop=my_loop) + """, + """ + from asyncio import get_event_loop + my_loop = get_event_loop() + my_loop.create_task(coro(1, 2)) + """, + ), + ], + ) + def test_import_alias(self, tmpdir, input_code, expected_output): + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize( + "input_code,expected_output", + [ + ( + """ + import asyncio + asyncio.Task(coro(1, 2), name='task') + """, + """ + import asyncio + asyncio.create_task(coro(1, 2), name='task') + """, + ), + ( + """ + import asyncio + my_loop = asyncio.get_event_loop() + asyncio.Task(coro(1, 2), name='task', loop=my_loop, context=None) + """, + """ + import asyncio + my_loop = 
asyncio.get_event_loop() + my_loop.create_task(coro(1, 2), name='task', context=None) + """, + ), + ( + """ + import asyncio + asyncio.Task(coro(1, 2), loop=None, eager_start=None) + """, + """ + import asyncio + asyncio.create_task(coro(1, 2)) + """, + ), + ( + """ + import asyncio + asyncio.Task(coro(1, 2), loop=None, eager_start=False) + """, + """ + import asyncio + asyncio.create_task(coro(1, 2)) + """, + ), + ( + """ + import asyncio + asyncio.Task(coro(1, 2), eager_start=False, name='task') + """, + """ + import asyncio + asyncio.create_task(coro(1, 2), name='task') + """, + ), + ( + """ + import asyncio + asyncio.Task(coro(1, 2), eager_start=True, name='task') + """, + """ + import asyncio + asyncio.Task(coro(1, 2), eager_start=True, name='task') + """, + ), + ( + """ + import asyncio + asyncio.Task(coro(1, 2), eager_start=True, name='task') + """, + """ + import asyncio + asyncio.Task(coro(1, 2), eager_start=True, name='task') + """, + ), + ( + """ + import asyncio + my_loop = asyncio.get_event_loop() + asyncio.Task(coro(1, 2), loop=my_loop, eager_start=True, name='task') + """, + """ + import asyncio + my_loop = asyncio.get_event_loop() + asyncio.eager_task_factory(my_loop, coro(1, 2), name='task') + """, + ), + ], + ) + def test_with_other_kwargs(self, tmpdir, input_code, expected_output): + self.run_and_assert(tmpdir, input_code, expected_output) + + @pytest.mark.parametrize("loop_value", ["None", "True", '"gibberish"', 10]) + def test_loop_kwarg_variations(self, tmpdir, loop_value): + input_code = ( + output_code + ) = f""" + import asyncio + asyncio.Task(coro(1, 2), loop={loop_value}) + """ + if loop_value == "None": + output_code = """ + import asyncio + asyncio.create_task(coro(1, 2)) + """ + self.run_and_assert(tmpdir, input_code, output_code) + + def test_asyncio_script(self, tmpdir): + input_code = """ + import asyncio + + async def my_coroutine(): + await asyncio.sleep(1) + print("Task completed") + + async def main(): + loop = 
asyncio.get_running_loop() + task = asyncio.Task(my_coroutine(), loop=loop) + await task + task_2 = asyncio.Task(my_coroutine()) + await task_2 + asyncio.run(main()) + """ + output_code = """ + import asyncio + + async def my_coroutine(): + await asyncio.sleep(1) + print("Task completed") + + async def main(): + loop = asyncio.get_running_loop() + task = loop.create_task(my_coroutine()) + await task + task_2 = asyncio.create_task(my_coroutine()) + await task_2 + asyncio.run(main()) + """ + self.run_and_assert(tmpdir, input_code, output_code, num_changes=2) diff --git a/tests/samples/fix_async_task_instantiation.py b/tests/samples/fix_async_task_instantiation.py new file mode 100644 index 00000000..ee7e5144 --- /dev/null +++ b/tests/samples/fix_async_task_instantiation.py @@ -0,0 +1,11 @@ +import asyncio + +async def my_coroutine(): + await asyncio.sleep(1) + print("Task completed") + +async def main(): + task = asyncio.Task(my_coroutine(), name="my task") + await task + +asyncio.run(main()) diff --git a/tests/test_basetype.py b/tests/test_basetype.py index 2096a31e..c4df21b4 100644 --- a/tests/test_basetype.py +++ b/tests/test_basetype.py @@ -1,4 +1,5 @@ import libcst as cst +import pytest from codemodder.codemods.utils import BaseType, infer_expression_type @@ -34,3 +35,15 @@ def test_if_numbers(self): def test_if_numbers2(self): e = cst.parse_expression("float(1) if True else len([1,2])") assert infer_expression_type(e) == BaseType.NUMBER + + @pytest.mark.parametrize("code", ["True", "False"]) + def test_bool(self, code): + e = cst.parse_expression(code) + if code == "True": + assert infer_expression_type(e) == BaseType.TRUE + else: + assert infer_expression_type(e) == BaseType.FALSE + + def test_none(self): + e = cst.parse_expression("None") + assert infer_expression_type(e) == BaseType.NONE