Skip to content

Commit

Permalink
django dunder str codemod can detect if parent class has a dunder str
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Feb 28, 2024
1 parent 435dc45 commit e73b592
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 51 deletions.
10 changes: 10 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ def find_accesses(self, node) -> Collection[Access]:
return scope.accesses[node]
return {}

def class_has_method(self, classdef: cst.ClassDef, method_name: str) -> bool:
"""Check if a given class definition implements a method of name `method_name`."""
for node in classdef.body.body:
match node:
case cst.FunctionDef(
name=cst.Name(value=value)
) if value == method_name:
return True
return False


class AncestorPatternsMixin(MetadataDependent):
METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = (ParentNodeProvider,)
Expand Down
11 changes: 8 additions & 3 deletions src/core_codemods/django_model_without_dunder_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,14 @@ def leave_ClassDef(
return updated_node.with_changes(body=new_body)

def implements_dunder_str(self, original_node: cst.ClassDef) -> bool:
for node in original_node.body.body:
match node:
case cst.FunctionDef(name=cst.Name(value="__str__")):
"""Check if a ClassDef or its bases implement `__str__`"""
if self.class_has_method(original_node, "__str__"):
return True

for base in original_node.bases:
if maybe_assignment := self.find_single_assignment(base.value):
classdef = maybe_assignment.node
if self.class_has_method(classdef, "__str__"):
return True
return False

Expand Down
64 changes: 16 additions & 48 deletions tests/codemods/test_django_model_without_dunder_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,51 +57,19 @@ def something():
"""
self.run_and_assert(tmpdir, input_code, expected)

# def test_simple_alias(self, tmpdir):
# input_code = """
# from django.dispatch import receiver as rec
#
# @csrf_exempt
# @rec(request_finished)
# def foo():
# pass
# """
# expected = """
# from django.dispatch import receiver as rec
#
# @rec(request_finished)
# @csrf_exempt
# def foo():
# pass
# """
# self.run_and_assert(tmpdir, input_code, expected)
#
# def test_no_receiver(self, tmpdir):
# input_code = """
# @csrf_exempt
# def foo():
# pass
# """
# self.run_and_assert(tmpdir, input_code, input_code)
#
# def test_receiver_but_not_djangos(self, tmpdir):
# input_code = """
# from not_django import receiver
#
# @csrf_exempt
# @receiver(request_finished)
# def foo():
# pass
# """
# self.run_and_assert(tmpdir, input_code, input_code)
#
# def test_receiver_on_top(self, tmpdir):
# input_code = """
# from django.dispatch import receiver
#
# @receiver(request_finished)
# @csrf_exempt
# def foo():
# pass
# """
# self.run_and_assert(tmpdir, input_code, input_code)
def test_model_inherits_dunder_str(self, tmpdir):
input_code = """
from django.db import models
class Custom:
def __str__(self):
pass
class User(Custom, models.Model):
name = models.CharField(max_length=100)
phone = models.IntegerField(blank=True)
def something():
pass
"""
self.run_and_assert(tmpdir, input_code, input_code)

0 comments on commit e73b592

Please sign in to comment.