Skip to content

Commit

Permalink
Updated dependency on libcst
Browse files Browse the repository at this point in the history
- Updated libcst dependency to include some recent fixes. Should be
  updated again with a new libcst release.
- Added tests to ensure imports are being added at the right place.
- Removed a workaround in NameResolutionMixin
  • Loading branch information
andrecsilva committed Oct 2, 2023
1 parent e2dabb5 commit 6c4edd6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ license = {file = "LICENSE"}
dependencies = [
"dependency-manager @ git+https://github.com/pixee/python-dependency-manager#egg=dependency-manager",
"isort~=5.12.0",
"libcst~=1.0.0",
# Temp fix until the next release of libcst
"libcst @ git+https://github.com/Instagram/LibCST.git@03179b55ebe7e916f1722e18e8f0b87c01616d1f",
"pylint~=2.17.0",
"PyYAML~=6.0.0",
"semgrep~=1.41.0",
Expand Down
13 changes: 3 additions & 10 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,9 @@ def find_assignments(
Given a MetadataWrapper and a CSTNode with a possible access to it, find all the possible assignments that it refers.
"""
scope = self.get_metadata(ScopeProvider, node)
# TODO workaround for a bug in libcst
if matchers.matches(node, matchers.Attribute()):
for access in scope.accesses:
if access.node == node:
# pylint: disable=protected-access
return access._Access__assignments
else:
if node in scope.accesses:
# pylint: disable=protected-access
return next(iter(scope.accesses[node]))._Access__assignments
if node in scope.accesses:
# pylint: disable=protected-access
return next(iter(scope.accesses[node]))._Access__assignments
return set()

def find_single_assignment(
Expand Down
3 changes: 1 addition & 2 deletions tests/codemods/test_process_creation_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def test_multifunctions(self, tmpdir):
input_code = """import subprocess
subprocess.run("echo 'hi'", shell=True)
subprocess.check_output(["ls", "-l"])
"""
subprocess.check_output(["ls", "-l"])"""

expected = """import subprocess
from security import safe_command
Expand Down
3 changes: 1 addition & 2 deletions tests/codemods/test_url_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def test_requests_multifunctions(self, tmpdir):
input_code = """import requests
requests.get("www.google.com")
requests.status_codes.codes.FORBIDDEN
"""
requests.status_codes.codes.FORBIDDEN"""

expected = """import requests
from security import safe_requests
Expand Down
50 changes: 50 additions & 0 deletions tests/transformations/test_add_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from libcst.codemod import CodemodTest
from libcst.codemod.visitors import AddImportsVisitor
from libcst.codemod.visitors import ImportItem


class TestAddImports(CodemodTest):
TRANSFORM = AddImportsVisitor

def test_add_only_at_top(self):
before = """
b()
import a
"""

after = """
import b
b()
import a
"""

self.assertCodemod(before, after, imports=[ImportItem("b", None, None)])

def test_may_duplicate_imports(self):
before = """
a()
import a
"""

after = """
import a
a()
import a
"""
self.assertCodemod(before, after, imports=[ImportItem("a", None, None)])

def test_may_duplicate_from_imports(self):
before = """
y()
from a import x
"""

after = """
from a import y
y()
from a import x
"""
self.assertCodemod(before, after, imports=[ImportItem("a", "y", None)])

0 comments on commit 6c4edd6

Please sign in to comment.