From a79452a90975573ba632e276c50d687eb0dc16be Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Thu, 7 Dec 2023 16:20:18 -0500 Subject: [PATCH] Refactor find_base_name utility --- src/codemodder/codemods/utils_mixin.py | 39 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 0d6faa5a..9d80baf0 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -17,24 +17,41 @@ class NameResolutionMixin(MetadataDependent): METADATA_DEPENDENCIES: Tuple[Any, ...] = (ScopeProvider,) + def _find_imported_name(self, node: cst.Name): + match self.find_single_assignment(node): + case ImportAssignment( + name=node.value, + node=( + cst.Import(names=names) | cst.ImportFrom(names=names) + ) as import_node, + ) as assignment: + match names: + case cst.ImportStar(): + return node.value + + for alias in names: + if assignment.name in ( + alias.evaluated_alias, + alias.evaluated_name, + ): + return self.base_name_for_import(import_node, alias) + return node.value + def find_base_name(self, node): """ - Given a node, solve its name to its basest form. For now it can only solve names that are imported. For example, in what follows, the base name for exec.capitalize() is sys.executable.capitalize. + Given a node, resolve its name to its basest form. + + For now it can only solve names that are imported. For example, in what + follows, the base name for exec.capitalize() is sys.executable.capitalize. + + ``` from sys import executable as exec exec.capitalize() + ``` """ match node: case cst.Name(): - maybe_assignment = self.find_single_assignment(node) - if maybe_assignment and isinstance(maybe_assignment, ImportAssignment): - import_node = maybe_assignment.node - for alias in import_node.names: - if maybe_assignment.name in ( - alias.evaluated_alias, - alias.evaluated_name, - ): - return self.base_name_for_import(import_node, alias) - return node.value + return self._find_imported_name(node) case cst.Attribute(): maybe_name = self.find_base_name(node.value)