diff --git a/integration_tests/test_literal_or_new_object_identity.py b/integration_tests/test_literal_or_new_object_identity.py new file mode 100644 index 00000000..7e5bccbc --- /dev/null +++ b/integration_tests/test_literal_or_new_object_identity.py @@ -0,0 +1,32 @@ +from core_codemods.literal_or_new_object_identity import LiteralOrNewObjectIdentity +from integration_tests.base_test import ( + BaseIntegrationTest, + original_and_expected_from_code_path, +) + + +class TestLiteralOrNewObjectIdentity(BaseIntegrationTest): + codemod = LiteralOrNewObjectIdentity + code_path = "tests/samples/literal_or_new_object_identity.py" + original_code, expected_new_code = original_and_expected_from_code_path( + code_path, + [ + (1, """ return l == [1,2,3]\n"""), + ], + ) + + # fmt: off + expected_diff =( + """--- \n""" + """+++ \n""" + """@@ -1,2 +1,2 @@\n""" + """ def foo(l):\n""" + """- return l is [1,2,3]\n""" + """+ return l == [1,2,3]\n""" + + ) + # fmt: on + + expected_line_change = "2" + change_description = LiteralOrNewObjectIdentity.CHANGE_DESCRIPTION + num_changed_files = 1 diff --git a/src/codemodder/scripts/generate_docs.py b/src/codemodder/scripts/generate_docs.py index 1786b03e..e2e77fb6 100644 --- a/src/codemodder/scripts/generate_docs.py +++ b/src/codemodder/scripts/generate_docs.py @@ -174,6 +174,10 @@ class DocMetadata: importance="Low", guidance_explained="Removing future imports is safe and will not cause any issues.", ), + "literal-or-new-object-identity": DocMetadata( + importance="Low", + guidance_explained="Since literals and new objects have their own identities, comparisons against them using `is` operators are most likely a bug and thus we deem the change safe.", + ), } diff --git a/src/core_codemods/__init__.py b/src/core_codemods/__init__.py index b371f659..a4e1cdba 100644 --- a/src/core_codemods/__init__.py +++ b/src/core_codemods/__init__.py @@ -37,6 +37,7 @@ from .numpy_nan_equality import NumpyNanEquality from .sql_parameterization import SQLQueryParameterization from .exception_without_raise import ExceptionWithoutRaise +from .literal_or_new_object_identity import LiteralOrNewObjectIdentity registry = CodemodCollection( origin="pixee", @@ -80,5 +81,6 @@ DjangoJsonResponseType, FlaskJsonResponseType, ExceptionWithoutRaise, + LiteralOrNewObjectIdentity, ], ) diff --git a/src/core_codemods/docs/pixee_python_literal-or-new-object-identity.md b/src/core_codemods/docs/pixee_python_literal-or-new-object-identity.md new file mode 100644 index 00000000..f634c3c4 --- /dev/null +++ b/src/core_codemods/docs/pixee_python_literal-or-new-object-identity.md @@ -0,0 +1,9 @@ +The `is` and `is not` operator will only return `True` when the expression have the same `id`. In other words, `a is b` is equivalent to `id(a) == id(b)`. New objects and literals have their own identities and thus shouldn't be compared with using the `is` or `is not` operators. + +Our changes look something like this: + +```diff +def foo(l): +- return l is [1,2,3] ++ return l == [1,2,3] +``` diff --git a/src/core_codemods/literal_or_new_object_identity.py b/src/core_codemods/literal_or_new_object_identity.py new file mode 100644 index 00000000..fe98b157 --- /dev/null +++ b/src/core_codemods/literal_or_new_object_identity.py @@ -0,0 +1,63 @@ +import libcst as cst +from codemodder.codemods.api import BaseCodemod +from codemodder.codemods.base_codemod import ReviewGuidance + +from codemodder.codemods.utils_mixin import NameResolutionMixin + + +class LiteralOrNewObjectIdentity(BaseCodemod, NameResolutionMixin): + NAME = "literal-or-new-object-identity" + SUMMARY = "Replaces is operator with == for literal or new object comparisons" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_WITHOUT_REVIEW + DESCRIPTION = SUMMARY + REFERENCES = [ + { + "url": "https://docs.python.org/3/library/stdtypes.html#comparisons", + "description": "", + }, + ] + CHANGE_DESCRIPTION = "Replaces is operator with ==" + + def _is_object_creation_or_literal(self, node: cst.BaseExpression): + match node: + case cst.List() | cst.Dict() | cst.Tuple() | cst.Set() | cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString() | cst.ConcatenatedString() | cst.FormattedString(): + return True + case cst.Call(func=cst.Name() as name): + return self.is_builtin_function(node) and name.value in ( + "dict", + "list", + "tuple", + "set", + ) + return False + + def leave_Comparison( + self, original_node: cst.Comparison, updated_node: cst.Comparison + ) -> cst.BaseExpression: + if self.filter_by_path_includes_or_excludes(self.node_position(original_node)): + match original_node: + case cst.Comparison( + left=left, comparisons=[cst.ComparisonTarget() as target] + ): + if isinstance(target.operator, cst.Is | cst.IsNot): + right = target.comparator + if self._is_object_creation_or_literal( + left + ) or self._is_object_creation_or_literal(right): + self.report_change(original_node) + if isinstance(target.operator, cst.Is): + return original_node.with_deep_changes( + target, + operator=cst.Equal( + whitespace_before=target.operator.whitespace_before, + whitespace_after=target.operator.whitespace_after, + ), + ) + return original_node.with_deep_changes( + target, + operator=cst.NotEqual( + whitespace_before=target.operator.whitespace_before, + whitespace_after=target.operator.whitespace_after, + ), + ) + return updated_node diff --git a/tests/codemods/test_literal_or_new_object_identity.py b/tests/codemods/test_literal_or_new_object_identity.py new file mode 100644 index 00000000..72f53cca --- /dev/null +++ b/tests/codemods/test_literal_or_new_object_identity.py @@ -0,0 +1,184 @@ +from core_codemods.literal_or_new_object_identity import LiteralOrNewObjectIdentity +from tests.codemods.base_codemod_test import BaseCodemodTest +from textwrap import dedent + + +class TestLiteralOrNewObjectIdentity(BaseCodemodTest): + codemod = LiteralOrNewObjectIdentity + + def test_name(self): + assert self.codemod.name() == "literal-or-new-object-identity" + + def test_list(self, tmpdir): + input_code = """\ + l is [1,2,3] + """ + expected = """\ + l == [1,2,3] + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_list_lhs(self, tmpdir): + input_code = """\ + [1,2,3] is l + """ + expected = """\ + [1,2,3] == l + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_list_function(self, tmpdir): + input_code = """\ + l is list({1,2,3}) + """ + expected = """\ + l == list({1,2,3}) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_dict(self, tmpdir): + input_code = """\ + l is {1:2} + """ + expected = """\ + l == {1:2} + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_dict_function(self, tmpdir): + input_code = """\ + l is dict({1,2,3}) + """ + expected = """\ + l == dict({1,2,3}) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_tuple(self, tmpdir): + input_code = """\ + l is (1,2,3) + """ + expected = """\ + l == (1,2,3) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_tuple_function(self, tmpdir): + input_code = """\ + l is tuple({1,2,3}) + """ + expected = """\ + l == tuple({1,2,3}) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_set(self, tmpdir): + input_code = """\ + l is {1,2,3} + """ + expected = """\ + l == {1,2,3} + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_set_function(self, tmpdir): + input_code = """\ + l is set([1,2,3]) + """ + expected = """\ + l == set([1,2,3]) + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_int(self, tmpdir): + input_code = """\ + l is 1 + """ + expected = """\ + l == 1 + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_float(self, tmpdir): + input_code = """\ + l is 1.0 + """ + expected = """\ + l == 1.0 + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_imaginary(self, tmpdir): + input_code = """\ + l is 1j + """ + expected = """\ + l == 1j + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_str(self, tmpdir): + input_code = """\ + l is '1' + """ + expected = """\ + l == '1' + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_fstr(self, tmpdir): + input_code = """\ + l is f'1' + """ + expected = """\ + l == f'1' + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_concatenated_str(self, tmpdir): + input_code = """\ + l is '1' ',2' + """ + expected = """\ + l == '1' ',2' + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_negative(self, tmpdir): + input_code = """\ + l is not [1,2,3] + """ + expected = """\ + l != [1,2,3] + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(expected)) + assert len(self.file_context.codemod_changes) == 1 + + def test_do_nothing(self, tmpdir): + input_code = """\ + l == [1,2,3] + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 + + def test_do_nothing_negative(self, tmpdir): + input_code = """\ + l != [1,2,3] + """ + self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code)) + assert len(self.file_context.codemod_changes) == 0 diff --git a/tests/samples/literal_or_new_object_identity.py b/tests/samples/literal_or_new_object_identity.py new file mode 100644 index 00000000..7070b457 --- /dev/null +++ b/tests/samples/literal_or_new_object_identity.py @@ -0,0 +1,2 @@ +def foo(l): + return l is [1,2,3]