diff --git a/pyproject.toml b/pyproject.toml index 2a0b667a..38bf8034 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,7 @@ license = {file = "LICENSE"} dependencies = [ "dependency-manager @ git+https://github.com/pixee/python-dependency-manager#egg=dependency-manager", "isort~=5.12.0", - # Temp fix until the next release of libcst - "libcst @ git+https://github.com/Instagram/LibCST.git@03179b55ebe7e916f1722e18e8f0b87c01616d1f", + "libcst~=1.1.0", "pylint~=3.0.0", "PyYAML~=6.0.0", "semgrep~=1.43.0", diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index b316fef0..98f3d658 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -2,7 +2,7 @@ import libcst as cst from libcst import MetadataDependent, matchers from libcst.helpers import get_full_name_for_node -from libcst.metadata import Assignment, ImportAssignment, ScopeProvider +from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider class NameResolutionMixin(MetadataDependent): @@ -60,14 +60,14 @@ def _is_direct_call_from_imported_module( def find_assignments( self, node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator], - ) -> set[Assignment]: + ) -> set[BaseAssignment]: """ Given a MetadataWrapper and a CSTNode with a possible access to it, find all the possible assignments that it refers. """ scope = self.get_metadata(ScopeProvider, node) if node in scope.accesses: # pylint: disable=protected-access - return next(iter(scope.accesses[node]))._Access__assignments + return set(next(iter(scope.accesses[node])).referents) return set() def find_single_assignment( diff --git a/tests/transformations/test_add_imports.py b/tests/transformations/test_add_imports.py deleted file mode 100644 index 9d3738e9..00000000 --- a/tests/transformations/test_add_imports.py +++ /dev/null @@ -1,50 +0,0 @@ -from libcst.codemod import CodemodTest -from libcst.codemod.visitors import AddImportsVisitor -from libcst.codemod.visitors import ImportItem - - -class TestAddImports(CodemodTest): - TRANSFORM = AddImportsVisitor - - def test_add_only_at_top(self): - before = """ - b() - import a - """ - - after = """ - import b - - b() - import a - """ - - self.assertCodemod(before, after, imports=[ImportItem("b", None, None)]) - - def test_may_duplicate_imports(self): - before = """ - a() - import a - """ - - after = """ - import a - - a() - import a - """ - self.assertCodemod(before, after, imports=[ImportItem("a", None, None)]) - - def test_may_duplicate_from_imports(self): - before = """ - y() - from a import x - """ - - after = """ - from a import y - - y() - from a import x - """ - self.assertCodemod(before, after, imports=[ImportItem("a", "y", None)])