From 6c4edd6d562ae435f49be53e62db13bc97a9ca45 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Mon, 2 Oct 2023 09:14:31 -0300 Subject: [PATCH] Updated dependency on libcst - Updated libcst dependency to include some recent fixes. Should be updated again with a new libcst release. - Added tests to ensure imports are being added at the right place. - Removed a workaround in NameResolutionMixin --- pyproject.toml | 3 +- src/codemodder/codemods/utils_mixin.py | 13 ++--- .../codemods/test_process_creation_sandbox.py | 3 +- tests/codemods/test_url_sandbox.py | 3 +- tests/transformations/test_add_imports.py | 50 +++++++++++++++++++ 5 files changed, 57 insertions(+), 15 deletions(-) create mode 100644 tests/transformations/test_add_imports.py diff --git a/pyproject.toml b/pyproject.toml index 8bfd7f2a..c27ca03a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ license = {file = "LICENSE"} dependencies = [ "dependency-manager @ git+https://github.com/pixee/python-dependency-manager#egg=dependency-manager", "isort~=5.12.0", - "libcst~=1.0.0", + # Temp fix until the next release of libcst + "libcst @ git+https://github.com/Instagram/LibCST.git@03179b55ebe7e916f1722e18e8f0b87c01616d1f", "pylint~=2.17.0", "PyYAML~=6.0.0", "semgrep~=1.41.0", diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index f5595218..b316fef0 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -65,16 +65,9 @@ def find_assignments( Given a MetadataWrapper and a CSTNode with a possible access to it, find all the possible assignments that it refers. """ scope = self.get_metadata(ScopeProvider, node) - # TODO workaround for a bug in libcst - if matchers.matches(node, matchers.Attribute()): - for access in scope.accesses: - if access.node == node: - # pylint: disable=protected-access - return access._Access__assignments - else: - if node in scope.accesses: - # pylint: disable=protected-access - return next(iter(scope.accesses[node]))._Access__assignments + if node in scope.accesses: + # pylint: disable=protected-access + return next(iter(scope.accesses[node]))._Access__assignments return set() def find_single_assignment( diff --git a/tests/codemods/test_process_creation_sandbox.py b/tests/codemods/test_process_creation_sandbox.py index b2ac9686..c53f7f2a 100644 --- a/tests/codemods/test_process_creation_sandbox.py +++ b/tests/codemods/test_process_creation_sandbox.py @@ -102,8 +102,7 @@ def test_multifunctions(self, tmpdir): input_code = """import subprocess subprocess.run("echo 'hi'", shell=True) -subprocess.check_output(["ls", "-l"]) - """ +subprocess.check_output(["ls", "-l"])""" expected = """import subprocess from security import safe_command diff --git a/tests/codemods/test_url_sandbox.py b/tests/codemods/test_url_sandbox.py index e98e692d..2264bbd3 100644 --- a/tests/codemods/test_url_sandbox.py +++ b/tests/codemods/test_url_sandbox.py @@ -84,8 +84,7 @@ def test_requests_multifunctions(self, tmpdir): input_code = """import requests requests.get("www.google.com") -requests.status_codes.codes.FORBIDDEN - """ +requests.status_codes.codes.FORBIDDEN""" expected = """import requests from security import safe_requests diff --git a/tests/transformations/test_add_imports.py b/tests/transformations/test_add_imports.py new file mode 100644 index 00000000..9d3738e9 --- /dev/null +++ b/tests/transformations/test_add_imports.py @@ -0,0 +1,50 @@ +from libcst.codemod import CodemodTest +from libcst.codemod.visitors import AddImportsVisitor +from libcst.codemod.visitors import ImportItem + + +class TestAddImports(CodemodTest): + TRANSFORM = AddImportsVisitor + + def test_add_only_at_top(self): + before = """ + b() + import a + """ + + after = """ + import b + + b() + import a + """ + + self.assertCodemod(before, after, imports=[ImportItem("b", None, None)]) + + def test_may_duplicate_imports(self): + before = """ + a() + import a + """ + + after = """ + import a + + a() + import a + """ + self.assertCodemod(before, after, imports=[ImportItem("a", None, None)]) + + def test_may_duplicate_from_imports(self): + before = """ + y() + from a import x + """ + + after = """ + from a import y + + y() + from a import x + """ + self.assertCodemod(before, after, imports=[ImportItem("a", "y", None)])