From dbfb1fb5932f91aa128baea9788ce6751818183a Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Wed, 13 Mar 2024 11:47:58 -0300 Subject: [PATCH] fix missing self or cls will not change nested functions --- src/core_codemods/fix_missing_self_or_cls.py | 61 +++++++++---------- .../codemods/test_fix_missing_self_or_cls.py | 57 ++++++++++++----- 2 files changed, 72 insertions(+), 46 deletions(-) diff --git a/src/core_codemods/fix_missing_self_or_cls.py b/src/core_codemods/fix_missing_self_or_cls.py index e023ee9e..56b5d4d7 100644 --- a/src/core_codemods/fix_missing_self_or_cls.py +++ b/src/core_codemods/fix_missing_self_or_cls.py @@ -4,51 +4,50 @@ LibcstResultTransformer, LibcstTransformerPipeline, ) -from codemodder.codemods.utils_mixin import NameResolutionMixin +from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin from core_codemods.api import Metadata, ReviewGuidance from core_codemods.api.core_codemod import CoreCodemod -class FixMissingSelfOrClsTransformer(LibcstResultTransformer, NameResolutionMixin): +class FixMissingSelfOrClsTransformer( + LibcstResultTransformer, NameAndAncestorResolutionMixin +): change_description = "Add `self` or `cls` parameter to instance or class method." - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.current_class_name = None - - def visit_ClassDef(self, node: cst.ClassDef) -> bool: - self.current_class_name = node.name.value - return True - - def leave_ClassDef( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - self.current_class_name = None - return updated_node - def leave_FunctionDef( self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef ) -> cst.FunctionDef: - if self.current_class_name: - if original_node.decorators: - if self.is_staticmethod(original_node): - return updated_node - if self.is_classmethod(original_node): - if self.has_no_args(original_node): - self.report_change(original_node) - return updated_node.with_changes( - params=updated_node.params.with_changes( - params=[cst.Param(name=cst.Name("cls"))] - ) - ) - else: + # TODO: add filter by include or exclude that works for nodes + # that that have different start/end numbers. + + if not self.find_immediate_class_def(original_node): + # If `original_node` is not inside a class, nothing to do. + return original_node + + if self.find_immediate_function_def(original_node): + # If `original_node` is inside a class but also nested within a function/method + # We won't touch it. + return original_node + + if original_node.decorators: + if self.is_staticmethod(original_node): + return updated_node + if self.is_classmethod(original_node): if self.has_no_args(original_node): self.report_change(original_node) return updated_node.with_changes( params=updated_node.params.with_changes( - params=[cst.Param(name=self._pick_arg_name(original_node))] + params=[cst.Param(name=cst.Name("cls"))] ) ) + else: + if self.has_no_args(original_node): + self.report_change(original_node) + return updated_node.with_changes( + params=updated_node.params.with_changes( + params=[cst.Param(name=self._pick_arg_name(original_node))] + ) + ) return updated_node def _pick_arg_name(self, node: cst.FunctionDef) -> cst.Name: @@ -79,7 +78,7 @@ def has_no_args(self, node: cst.FunctionDef) -> bool: FixMissingSelfOrCls = CoreCodemod( metadata=Metadata( name="fix-missing-self-or-cls", - review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, + review_guidance=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, summary="Add Missing Positional Parameter for Instance and Class Methods", references=[], ), diff --git a/tests/codemods/test_fix_missing_self_or_cls.py b/tests/codemods/test_fix_missing_self_or_cls.py index e64a0135..3116c6c1 100644 --- a/tests/codemods/test_fix_missing_self_or_cls.py +++ b/tests/codemods/test_fix_missing_self_or_cls.py @@ -41,6 +41,41 @@ def __init_subclass__(cls): """ self.run_and_assert(tmpdir, input_code, expected, num_changes=4) + def test_change_not_nested(self, tmpdir): + input_code = """ + class A: + def method(): + def inner(): + pass + + @classmethod + def clsmethod(): + def other_inner(): + pass + + def wrapper(): + class B: + def method(): + pass + """ + expected = """ + class A: + def method(self): + def inner(): + pass + + @classmethod + def clsmethod(cls): + def other_inner(): + pass + + def wrapper(): + class B: + def method(): + pass + """ + self.run_and_assert(tmpdir, input_code, expected, num_changes=2) + @pytest.mark.parametrize( "code", [ @@ -56,7 +91,7 @@ def method(self, arg): @classmethod def clsmethod(cls, arg): pass - + @staticmethod def my_static(): pass @@ -68,6 +103,12 @@ def __new__(*args, **kwargs): def __init_subclass__(**kwargs): pass """, + """ + class A(): + def f(self): + def g(): + pass + """, ], ) def test_no_change(self, tmpdir, code): @@ -94,17 +135,3 @@ def kls(**kwargs): pass """ self.run_and_assert(tmpdir, input_code, input_code) - - # def test_exclude_line(self, tmpdir): - # input_code = ( - # expected - # ) = """ - # assert (1, 2) - # """ - # lines_to_exclude = [2] - # self.run_and_assert( - # tmpdir, - # input_code, - # expected, - # lines_to_exclude=lines_to_exclude, - # )