diff --git a/integration_tests/test_fix_task_instantiation.py b/integration_tests/test_fix_task_instantiation.py new file mode 100644 index 000000000..b44867593 --- /dev/null +++ b/integration_tests/test_fix_task_instantiation.py @@ -0,0 +1,39 @@ +from core_codemods.fix_task_instantiation import FixTaskInstantiation +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestFixTaskInstantiation(BaseIntegrationTest): + codemod = FixTaskInstantiation + code_path = "tests/samples/fix_task_instantiation.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, + [ + ( + 7, + """ task = asyncio.create_task(my_coroutine(), name="my task")\n""", + ), + ], + ) + + # fmt: off + expected_diff =( + """--- \n""" + """+++ \n""" + """@@ -5,7 +5,7 @@\n""" + """ print("Task completed")\n""" + """ \n""" + """ async def main():\n""" + """- task = asyncio.Task(my_coroutine(), name="my task")\n""" + """+ task = asyncio.create_task(my_coroutine(), name="my task")\n""" + """ await task\n""" + """ \n""" + """ asyncio.run(main())\n""" + ) + # fmt: on + + expected_line_change = "8" + change_description = FixTaskInstantiation.change_description + num_changed_files = 1 diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index 20e3e9c8a..0983009a9 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -226,6 +226,10 @@ class DocMetadata: importance="Medium", guidance_explained="We believe this change is safe and will not cause any issues.", ), + "fix-task-instantiation": DocMetadata( + importance="Low", + guidance_explained="Manual instantiation of `asyncio.Task` is discouraged. We believe this change is safe and will not cause any issues.", + ), } METADATA = CORE_METADATA | { diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 2da21231f..5ed042966 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -60,6 +60,7 @@ from .sonar.sonar_flask_json_response_type import SonarFlaskJsonResponseType from .sonar.sonar_django_json_response_type import SonarDjangoJsonResponseType from .lazy_logging import LazyLogging +from .fix_task_instantiation import FixTaskInstantiation registry = CodemodCollection( origin="pixee", @@ -116,6 +117,7 @@ RemoveAssertionInPytestRaises, FixAssertTuple, LazyLogging, + FixTaskInstantiation, ], ) diff --git a/src/core_codemods/docs/pixee_python_fix-task-instantiation.md b/src/core_codemods/docs/pixee_python_fix-task-instantiation.md new file mode 100644 index 000000000..73b5c22ae --- /dev/null +++ b/src/core_codemods/docs/pixee_python_fix-task-instantiation.md @@ -0,0 +1,9 @@ +The `asyncio` [documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task) explicitly discourages manual instantiation of a `Task` instance and instead recommends calling `create_task`. + +Our changes look like the following: +```diff + import asyncio + +- task = asyncio.Task(my_coroutine(), name="my task") ++ task = asyncio.create_task(my_coroutine(), name="my task") +``` diff --git a/src/core_codemods/fix_task_instantiation.py b/src/core_codemods/fix_task_instantiation.py index 802a41ccd..a27a93ab7 100644 --- a/src/core_codemods/fix_task_instantiation.py +++ b/src/core_codemods/fix_task_instantiation.py @@ -8,17 +8,15 @@ class FixTaskInstantiation(SimpleCodemod, NameAndAncestorResolutionMixin): metadata = Metadata( name="fix-task-instantiation", - summary="TODOReplace Comparisons to Empty Sequence with Implicit Boolean Comparison", - review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, + summary="Use high-level `asyncio.create_task` API", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, references=[ Reference( - url="todo: https://docs.python.org/3/library/stdtypes.html#truth-value-testing" + url="https://docs.python.org/3/library/asyncio-task.html#asyncio.Task" ), ], ) - change_description = ( - "TODO: Replace comparisons to empty sequence with implicit boolean comparison." - ) + change_description = "Replace instantiation of `asyncio.Task` with `create_task`" _module_name = "asyncio" def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: @@ -35,14 +33,14 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal ) if loop_type == BaseType.NONE: return self.node_create_task(original_node, updated_node) - elif loop_type in ( + if loop_type in ( BaseType.NUMBER, BaseType.LIST, BaseType.STRING, BaseType.BYTES, BaseType.BOOL, ): - # User incorrectly assigned loop to something that is not a loop. + # incorrectly assigned loop kwarg to something that is not a loop. # We won't do anything. return updated_node @@ -56,6 +54,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal def node_create_task( self, original_node: cst.Call, updated_node: cst.Call ) -> cst.Call: + """Convert `asyncio.Task(...)` to `asyncio.create_task(...)`""" self.report_change(original_node) maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) if (maybe_name := maybe_name or self._module_name) == self._module_name: @@ -70,7 +69,7 @@ def node_loop_create_task( loop_arg: cst.Arg, other_args: list[cst.Arg], ) -> cst.Call: - """todo: document""" + """Convert `asyncio.Task(..., loop=loop,...)` to `loop.create_task(...)`""" self.report_change(original_node) coroutine_arg = coroutine_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT) loop_attr = loop_arg.value @@ -82,7 +81,9 @@ def node_loop_create_task( return new_call def _find_loop_arg(self, node: cst.Call) -> tuple[Optional[cst.Arg], list[cst.Arg]]: - """dcoment args[:1: bc first arg is coroutine""" + """Find the loop kwarg from a call to `asyncio.Task(...)` + First arg is always the coroutine so we ignore it. + """ loop_arg = None other_args = [] for arg in node.args[1:]: diff --git a/tests/samples/fix_task_instantiation.py b/tests/samples/fix_task_instantiation.py new file mode 100644 index 000000000..ee7e51444 --- /dev/null +++ b/tests/samples/fix_task_instantiation.py @@ -0,0 +1,11 @@ +import asyncio + +async def my_coroutine(): + await asyncio.sleep(1) + print("Task completed") + +async def main(): + task = asyncio.Task(my_coroutine(), name="my task") + await task + +asyncio.run(main())