diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index ee64e699..4715a1b4 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -281,6 +281,16 @@ def find_accesses(self, node) -> Collection[Access]: return scope.accesses[node] return {} + def class_has_method(self, classdef: cst.ClassDef, method_name: str) -> bool: + """Check if a given class definition implements a method of name `method_name`.""" + for node in classdef.body.body: + match node: + case cst.FunctionDef( + name=cst.Name(value=value) + ) if value == method_name: + return True + return False + class AncestorPatternsMixin(MetadataDependent): METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ParentNodeProvider,) diff --git a/src/core_codemods/django_model_without_dunder_str.py b/src/core_codemods/django_model_without_dunder_str.py index 983b4223..120bfff3 100644 --- a/src/core_codemods/django_model_without_dunder_str.py +++ b/src/core_codemods/django_model_without_dunder_str.py @@ -43,9 +43,14 @@ def leave_ClassDef( return updated_node.with_changes(body=new_body) def implements_dunder_str(self, original_node: cst.ClassDef) -> bool: - for node in original_node.body.body: - match node: - case cst.FunctionDef(name=cst.Name(value="__str__")): + """Check if a ClassDef or its bases implement `__str__`""" + if self.class_has_method(original_node, "__str__"): + return True + + for base in original_node.bases: + if maybe_assignment := self.find_single_assignment(base.value): + classdef = maybe_assignment.node + if self.class_has_method(classdef, "__str__"): return True return False diff --git a/tests/codemods/test_django_model_without_dunder_str.py b/tests/codemods/test_django_model_without_dunder_str.py index 73839318..14505491 100644 --- a/tests/codemods/test_django_model_without_dunder_str.py +++ b/tests/codemods/test_django_model_without_dunder_str.py @@ -57,51 +57,19 @@ def something(): """ self.run_and_assert(tmpdir, input_code, expected) - # def test_simple_alias(self, tmpdir): - # input_code = """ - # from django.dispatch import receiver as rec - # - # @csrf_exempt - # @rec(request_finished) - # def foo(): - # pass - # """ - # expected = """ - # from django.dispatch import receiver as rec - # - # @rec(request_finished) - # @csrf_exempt - # def foo(): - # pass - # """ - # self.run_and_assert(tmpdir, input_code, expected) - # - # def test_no_receiver(self, tmpdir): - # input_code = """ - # @csrf_exempt - # def foo(): - # pass - # """ - # self.run_and_assert(tmpdir, input_code, input_code) - # - # def test_receiver_but_not_djangos(self, tmpdir): - # input_code = """ - # from not_django import receiver - # - # @csrf_exempt - # @receiver(request_finished) - # def foo(): - # pass - # """ - # self.run_and_assert(tmpdir, input_code, input_code) - # - # def test_receiver_on_top(self, tmpdir): - # input_code = """ - # from django.dispatch import receiver - # - # @receiver(request_finished) - # @csrf_exempt - # def foo(): - # pass - # """ - # self.run_and_assert(tmpdir, input_code, input_code) + def test_model_inherits_dunder_str(self, tmpdir): + input_code = """ + from django.db import models + + class Custom: + def __str__(self): + pass + + class User(Custom, models.Model): + name = models.CharField(max_length=100) + phone = models.IntegerField(blank=True) + + def something(): + pass + """ + self.run_and_assert(tmpdir, input_code, input_code)