diff --git a/integration_tests/test_remove_assertion_in_pytest_raises.py b/integration_tests/test_remove_assertion_in_pytest_raises.py new file mode 100644 index 00000000..9e2e9975 --- /dev/null +++ b/integration_tests/test_remove_assertion_in_pytest_raises.py @@ -0,0 +1,40 @@ +from core_codemods.remove_assertion_in_pytest_raises import ( + RemoveAssertionInPytestRaises, + RemoveAssertionInPytestRaisesTransformer, +) +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestRemoveAssertionInPytestRaises(BaseIntegrationTest): + codemod = RemoveAssertionInPytestRaises + code_path = "tests/samples/remove_assertion_in_pytest_raises.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, + [ + (5, """ assert 1\n"""), + (6, """ assert 2\n"""), + ], + ) + + # fmt: off + expected_diff =( + """--- \n""" + """+++ \n""" + """@@ -3,5 +3,5 @@\n""" + """ def test_foo():\n""" + """ with pytest.raises(ZeroDivisionError):\n""" + """ error = 1/0\n""" + """- assert 1\n""" + """- assert 2\n""" + """+ assert 1\n""" + """+ assert 2\n""" + ) + # fmt: on + + expected_line_change = "4" + change_description = RemoveAssertionInPytestRaisesTransformer.change_description + num_changed_files = 1 + num_changes = 1 diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index e7aedd91..91e19ffb 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -214,6 +214,10 @@ class DocMetadata: importance="Low", guidance_explained="Values compared to empty sequences should be verified in case they are falsy values that are not a sequence.", ), + "remove-assertion-in-pytest-raises": DocMetadata( + importance="Low", + guidance_explained="We believe this change is safe and will not cause any issues.", + ), } diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index 8be5f2d2..8842bbd2 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -47,6 +47,7 @@ from .flask_enable_csrf_protection import FlaskEnableCSRFProtection from .replace_flask_send_file import ReplaceFlaskSendFile from .fix_empty_sequence_comparison import FixEmptySequenceComparison +from .remove_assertion_in_pytest_raises import RemoveAssertionInPytestRaises registry = CodemodCollection( origin="pixee", @@ -100,5 +101,6 @@ FlaskEnableCSRFProtection, ReplaceFlaskSendFile, FixEmptySequenceComparison, + RemoveAssertionInPytestRaises, ], ) diff --git a/src/core_codemods/docs/pixee_python_remove-assertion-in-pytest-raises.md b/src/core_codemods/docs/pixee_python_remove-assertion-in-pytest-raises.md new file mode 100644 index 00000000..b8616179 --- /dev/null +++ b/src/core_codemods/docs/pixee_python_remove-assertion-in-pytest-raises.md @@ -0,0 +1,15 @@ +The context manager object `pytest.raises()` will assert if the code contained within its scope will raise an exception of type ``. The documentation points that the exception must be raised in the last line of its scope and any line afterwards won't be executed. +Including asserts at the end of the scope is a common error. This codemod addresses that by moving them out of the scope. +Our changes look something like this: + +```diff +import pytest + +def test_foo(): + with pytest.raises(ZeroDivisionError): + error = 1/0 +- assert 1 +- assert 2 ++ assert 1 ++ assert 2 +``` diff --git a/src/core_codemods/remove_assertion_in_pytest_raises.py b/src/core_codemods/remove_assertion_in_pytest_raises.py new file mode 100644 index 00000000..a0098b5f --- /dev/null +++ b/src/core_codemods/remove_assertion_in_pytest_raises.py @@ -0,0 +1,156 @@ +from typing import Sequence, Union +import libcst as cst +from codemodder.codemods.base_codemod import Metadata, Reference, ReviewGuidance +from codemodder.codemods.libcst_transformer import ( + LibcstResultTransformer, + LibcstTransformerPipeline, +) +from codemodder.codemods.utils_mixin import NameResolutionMixin +from core_codemods.api.core_codemod import CoreCodemod + + +class RemoveAssertionInPytestRaisesTransformer( + LibcstResultTransformer, NameResolutionMixin +): + change_description = "Moved assertion out of with statement body" + + def _all_pytest_raises(self, node: cst.With): + for item in node.items: + match item: + case cst.WithItem(item=cst.Call() as call): + maybe_call_base_name = self.find_base_name(call) + if ( + not maybe_call_base_name + or maybe_call_base_name != "pytest.raises" + ): + return False + + case _: + return False + return True + + def _build_simple_statement_line(self, node: cst.BaseSmallStatement): + return cst.SimpleStatementLine( + body=[node.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)] + ) + + def _remove_last_asserts_from_suite(self, node: Sequence[cst.BaseSmallStatement]): + assert_position = len(node) + assert_stmts = [] + new_statement_before_asserts = None + for stmt in reversed(node): + match stmt: + case cst.Assert(): + assert_position = assert_position - 1 + assert_stmts.append(self._build_simple_statement_line(stmt)) + case _: + break + if assert_position > 0: + new_statement_before_asserts = node[assert_position - 1].with_changes( + semicolon=cst.MaybeSentinel.DEFAULT + ) + return assert_stmts, assert_position, new_statement_before_asserts + + def _remove_last_asserts_from_IndentedBlock(self, node: cst.IndentedBlock): + assert_position = len(node.body) + assert_stmts = [] + new_statement_before_asserts = None + for simple_stmt in reversed(node.body): + match simple_stmt: + case cst.SimpleStatementLine(body=[*head, cst.Assert()] as body): + assert_position = assert_position - 1 + if head: + sstmts, s_pos, new_stmt = self._remove_last_asserts_from_suite( + body + ) + assert_stmts.extend(sstmts) + if new_stmt: + new_statement_before_asserts = new_stmt + new_statement_before_asserts = simple_stmt.with_changes( + body=[ + *body[: s_pos - 1], + body[s_pos - 1].with_changes( + semicolon=cst.MaybeSentinel.DEFAULT + ), + ] + ) + break + else: + assert_stmts.append(simple_stmt) + if new_statement_before_asserts: + break + case _: + if assert_position > 0: + new_statement_before_asserts = node.body[assert_position - 1] + break + assert_stmts.reverse() + return assert_stmts, assert_position, new_statement_before_asserts + + def leave_With( + self, original_node: cst.With, updated_node: cst.With + ) -> Union[ + cst.BaseStatement, cst.FlattenSentinel[cst.BaseStatement], cst.RemovalSentinel + ]: + # TODO: add filter by include or exclude that works for nodes + # that that have different start/end numbers. + + # Are all items pytest.raises? + if not self._all_pytest_raises(original_node): + return updated_node + + assert_stmts: list[cst.SimpleStatementLine] = [] + assert_position = len(original_node.body.body) + new_statement_before_asserts = None + match original_node.body: + case cst.SimpleStatementSuite(): + ( + assert_stmts, + assert_position, + new_statement_before_asserts, + ) = self._remove_last_asserts_from_suite(original_node.body.body) + assert_stmts.reverse() + case cst.IndentedBlock(): + ( + assert_stmts, + assert_position, + new_statement_before_asserts, + ) = self._remove_last_asserts_from_IndentedBlock(original_node.body) + + if assert_stmts: + # this means all the statements are asserts + if new_statement_before_asserts: + new_with = updated_node.with_changes( + body=updated_node.body.with_changes( + body=[ + *updated_node.body.body[: assert_position - 1], + new_statement_before_asserts, + ] + ) + ) + else: + new_with = updated_node.with_changes( + body=updated_node.body.with_changes( + body=[cst.SimpleStatementLine(body=[cst.Pass()])] + ) + ) + self.report_change(original_node) + return cst.FlattenSentinel([new_with, *assert_stmts]) + + return updated_node + + +RemoveAssertionInPytestRaises = CoreCodemod( + metadata=Metadata( + name="remove-assertion-in-pytest-raises", + summary="Moves assertions out of `pytest.raises` scope", + review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + references=[ + Reference( + url="https://docs.pytest.org/en/7.4.x/reference/reference.html#pytest-raises", + description="", + ), + ], + ), + transformer=LibcstTransformerPipeline(RemoveAssertionInPytestRaisesTransformer), + detector=None, +) diff --git a/tests/codemods/test_remove_assertion_in_pytest_raises.py b/tests/codemods/test_remove_assertion_in_pytest_raises.py new file mode 100644 index 00000000..a60b9abc --- /dev/null +++ b/tests/codemods/test_remove_assertion_in_pytest_raises.py @@ -0,0 +1,201 @@ +from core_codemods.remove_assertion_in_pytest_raises import ( + RemoveAssertionInPytestRaises, +) +from tests.codemods.base_codemod_test import BaseCodemodTest + + +class TestRemoveAssertionInPytestRaises(BaseCodemodTest): + codemod = RemoveAssertionInPytestRaises + + def test_name(self): + assert self.codemod.name == "remove-assertion-in-pytest-raises" + + def test_simple(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert True + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_simple_alias(self, tmpdir): + input_code = """\ + from pytest import raises as rise + def foo(): + with rise(ZeroDivisionError): + 1/0 + assert True + """ + expected = """\ + from pytest import raises as rise + def foo(): + with rise(ZeroDivisionError): + 1/0 + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_simple_from_import(self, tmpdir): + input_code = """\ + from pytest import raises + def foo(): + with raises(ZeroDivisionError): + 1/0 + assert True + """ + expected = """\ + from pytest import raises + def foo(): + with raises(ZeroDivisionError): + 1/0 + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_all_asserts(self, tmpdir): + # this is more of an edge case + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + assert True + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + pass + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_multiple_raises(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError), pytest.raises(IndexError): + 1/0 + [1,2][3] + assert True + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError), pytest.raises(IndexError): + 1/0 + [1,2][3] + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_multiple_asserts(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert 1 + assert 2 + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert 1 + assert 2 + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_multiple_asserts_mixed_early(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0; assert 1; assert 2 + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert 1 + assert 2 + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_multiple_asserts_mixed(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert 1; assert 2 + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): + 1/0 + assert 1 + assert 2 + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_simple_suite(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): 1/0; assert True + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): 1/0 + assert True + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_multiple_suite(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): 1/0; assert True; assert False; + """ + expected = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError): 1/0 + assert True + assert False + """ + self.run_and_assert(tmpdir, input_code, expected) + + def test_with_item_not_raises(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError), open('') as file: + 1/0 + assert True + """ + self.run_and_assert(tmpdir, input_code, input_code) + + def test_no_assertion_at_end(self, tmpdir): + input_code = """\ + import pytest + def foo(): + with pytest.raises(ZeroDivisionError), open('') as file: + assert True + 1/0 + """ + self.run_and_assert(tmpdir, input_code, input_code) diff --git a/tests/samples/remove_assertion_in_pytest_raises.py b/tests/samples/remove_assertion_in_pytest_raises.py new file mode 100644 index 00000000..449408c0 --- /dev/null +++ b/tests/samples/remove_assertion_in_pytest_raises.py @@ -0,0 +1,7 @@ +import pytest + +def test_foo(): + with pytest.raises(ZeroDivisionError): + error = 1/0 + assert 1 + assert 2