Skip to content

Commit

Permalink
add name resolution to threading codemod (#71)
Browse files Browse the repository at this point in the history
* add name resolution to threading codemod

* gather all names in global scope

* add no cover to just-in-case lines

* add test cases for multiple locks in a module

* use type-specific names

* keep track of additional locks added

* add unit test
  • Loading branch information
clavedeluna authored Oct 17, 2023
1 parent 8128ee6 commit 8b93bae
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 13 deletions.
33 changes: 33 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from libcst import MetadataDependent, matchers
from libcst.helpers import get_full_name_for_node
from libcst.metadata import Assignment, BaseAssignment, ImportAssignment, ScopeProvider
from libcst.metadata.scope_provider import GlobalScope


class NameResolutionMixin(MetadataDependent):
Expand Down Expand Up @@ -70,6 +71,30 @@ def find_assignments(
return set(next(iter(scope.accesses[node])).referents)
return set()

def find_used_names_in_module(self):
    """
    Collect the names reachable from the module's global-scope assignments.

    Every assignment recorded in the global scope has its node walked with
    a fresh ``GatherNamesVisitor``; the ``cst.Name`` values found are
    concatenated in visit order. Duplicates are not removed.

    Returns a list of strings, or an empty list when no global scope
    can be found.
    """
    global_scope = self.find_global_scope()
    if global_scope is None:
        return []  # pragma: no cover

    collected = []
    for assignment in global_scope.assignments:
        # One visitor per assigned node; its findings are appended in order.
        gatherer = GatherNamesVisitor()
        assignment.node.visit(gatherer)
        collected += gatherer.names
    return collected

def find_global_scope(self):
    """
    Return the first ``GlobalScope`` among the scopes resolved by
    ``ScopeProvider`` on ``self.context.wrapper``, or ``None`` if there
    is none.
    """
    resolved_scopes = self.context.wrapper.resolve(ScopeProvider).values()
    return next(
        (
            candidate
            for candidate in resolved_scopes
            if isinstance(candidate, GlobalScope)
        ),
        None,
    )

def find_single_assignment(
self,
node: Union[cst.Name, cst.Attribute, cst.Call, cst.Subscript, cst.Decorator],
Expand Down Expand Up @@ -112,3 +137,11 @@ def _get_name(node: Union[cst.Import, cst.ImportFrom]) -> str:
if matchers.matches(node, matchers.Import()):
return get_full_name_for_node(node.names[0].name)
return ""


class GatherNamesVisitor(cst.CSTVisitor):
    """Visitor that records the ``value`` of every ``cst.Name`` node it visits."""

    def __init__(self):
        # Names are stored in visit order; duplicates are kept.
        self.names = []

    def visit_Name(self, node: cst.Name) -> None:
        """Append the identifier text of ``node`` to ``self.names``."""
        self.names.append(node.value)
36 changes: 33 additions & 3 deletions src/core_codemods/with_threading_lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import libcst as cst
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod
from codemodder.codemods.utils_mixin import NameResolutionMixin


class WithThreadingLock(SemgrepCodemod):
class WithThreadingLock(SemgrepCodemod, NameResolutionMixin):
NAME = "bad-lock-with-statement"
SUMMARY = "Separate Lock Instantiation from `with` Call"
DESCRIPTION = (
Expand Down Expand Up @@ -44,15 +45,35 @@ def rule(cls):
- focus-metavariable: $BODY
"""

def __init__(self, *args):
    """
    Initialise both base classes explicitly, then snapshot the names
    already present in the module.

    ``*args`` is forwarded unchanged to ``SemgrepCodemod.__init__``.
    """
    SemgrepCodemod.__init__(self, *args)
    NameResolutionMixin.__init__(self)
    # Collected once up front; _create_new_variable appends to this list
    # as it introduces new variables.
    self.names_in_module = self.find_used_names_in_module()

def _create_new_variable(self, original_node: cst.With):
    """
    Create an appropriately named variable for the new
    lock, condition, or semaphore.

    The base name comes from ``_get_node_name(original_node)``. If that
    name is already in ``self.names_in_module``, a numeric suffix
    (``_1``, ``_2``, ...) is appended until an unused name is found.
    The chosen name is recorded in ``self.names_in_module`` in case
    there are other additions later on.

    Returns a ``cst.Name`` holding the chosen identifier.
    """
    # _get_node_name returns "" when it cannot derive a name; an empty
    # string is not a valid identifier, so fall back to the generic "lock".
    base_name = _get_node_name(original_node) or "lock"
    value = base_name
    counter = 1
    while value in self.names_in_module:
        value = f"{base_name}_{counter}"
        counter += 1

    # Remember the name so subsequent calls do not hand it out again.
    self.names_in_module.append(value)
    return cst.Name(value=value)

def leave_With(self, original_node: cst.With, updated_node: cst.With):
# We deliberately restrict ourselves to simple cases where there's only one with clause for now.
# Semgrep appears to be insufficiently expressive to match multiple clauses correctly.
# We should probably just rewrite this codemod using libcst without semgrep.
if len(original_node.items) == 1 and self.node_is_selected(
original_node.items[0]
):
# TODO: how to avoid name conflicts here?
name = cst.Name(value="lock")
name = self._create_new_variable(original_node)
assign = cst.SimpleStatementLine(
body=[
cst.Assign(
Expand All @@ -72,3 +93,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
)

return original_node


def _get_node_name(original_node: cst.With):
    """
    Derive a lowercase variable name from the call in the first ``with`` item.

    For an attribute call such as ``threading.Lock()`` the attribute name
    is used; for a bare call such as ``Lock()`` the function name is used.
    Returns an empty string for any other kind of callee.
    """
    callee = original_node.items[0].item.func
    if isinstance(callee, cst.Attribute):
        return callee.attr.value.lower()
    if isinstance(callee, cst.Name):
        return callee.value.lower()
    return ""  # pragma: no cover
6 changes: 4 additions & 2 deletions tests/codemods/base_codemod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
[],
defaultdict(list),
)
wrapper = cst.MetadataWrapper(input_tree)
command_instance = self.codemod(
CodemodContext(),
CodemodContext(wrapper=wrapper),
self.execution_context,
self.file_context,
)
Expand Down Expand Up @@ -83,8 +84,9 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected):
[],
results,
)
wrapper = cst.MetadataWrapper(input_tree)
command_instance = self.codemod(
CodemodContext(),
CodemodContext(wrapper=wrapper),
self.execution_context,
self.file_context,
)
Expand Down
121 changes: 113 additions & 8 deletions tests/codemods/test_with_threading_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def test_import(self, tmpdir, klass):
...
"""
expected = f"""import threading
lock = threading.{klass}()
with lock:
{klass.lower()} = threading.{klass}()
with {klass.lower()}:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
Expand All @@ -40,8 +40,8 @@ def test_from_import(self, tmpdir, klass):
...
"""
expected = f"""from threading import {klass}
lock = {klass}()
with lock:
{klass.lower()} = {klass}()
with {klass.lower()}:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
Expand All @@ -53,8 +53,8 @@ def test_simple_replacement_with_as(self, tmpdir, klass):
...
"""
expected = f"""import threading
lock = threading.{klass}()
with lock as foo:
{klass.lower()} = threading.{klass}()
with {klass.lower()} as foo:
...
"""
self.run_and_assert(tmpdir, input_code, expected)
Expand All @@ -69,8 +69,8 @@ def test_no_effect_sanity_check(self, tmpdir, klass):
...
"""
expected = f"""import threading
lock = threading.{klass}()
with lock:
{klass.lower()} = threading.{klass}()
with {klass.lower()}:
...
with threading_lock():
Expand All @@ -86,3 +86,108 @@ def test_no_effect_multiple_with_clauses(self, tmpdir, klass):
...
"""
self.run_and_assert(tmpdir, input_code, input_code)


class TestThreadingNameResolution(BaseSemgrepCodemodTest):
codemod = WithThreadingLock

@pytest.mark.parametrize(
"input_code,expected_code",
[
(
"""from threading import Lock
lock = 1
with Lock():
...
""",
"""from threading import Lock
lock = 1
lock_1 = Lock()
with lock_1:
...
""",
),
(
"""from threading import Lock
from something import lock
with Lock():
...
""",
"""from threading import Lock
from something import lock
lock_1 = Lock()
with lock_1:
...
""",
),
(
"""import threading
lock = 1
def f(l):
with threading.Lock():
return [lock_1 for lock_1 in l]
""",
"""import threading
lock = 1
def f(l):
lock_2 = threading.Lock()
with lock_2:
return [lock_1 for lock_1 in l]
""",
),
(
"""import threading
with threading.Lock():
int("1")
with threading.Lock():
print()
var = 1
with threading.Lock():
print()
""",
"""import threading
lock = threading.Lock()
with lock:
int("1")
lock_1 = threading.Lock()
with lock_1:
print()
var = 1
lock_2 = threading.Lock()
with lock_2:
print()
""",
),
(
"""import threading
with threading.Lock():
with threading.Lock():
print()
""",
"""import threading
lock_1 = threading.Lock()
with lock_1:
lock = threading.Lock()
with lock:
print()
""",
),
(
"""import threading
def my_func():
lock = "whatever"
with threading.Lock():
foo()
""",
"""import threading
def my_func():
lock = "whatever"
lock_1 = threading.Lock()
with lock_1:
foo()
""",
),
],
)
def test_name_resolution(self, tmpdir, input_code, expected_code):
self.run_and_assert(tmpdir, input_code, expected_code)

0 comments on commit 8b93bae

Please sign in to comment.