diff --git a/src/core_codemods/with_threading_lock.py b/src/core_codemods/with_threading_lock.py new file mode 100644 index 000000000..139471361 --- /dev/null +++ b/src/core_codemods/with_threading_lock.py @@ -0,0 +1,51 @@ +import libcst as cst +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SemgrepCodemod + + +class WithThreadingLock(SemgrepCodemod): + NAME = "bad-lock-with-statement" + SUMMARY = "Replace deprecated usage of threading.Lock context manager" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW + DESCRIPTION = "Separates threading lock instantiation and call with `with` statement into two steps." + + @classmethod + def rule(cls): + return """ + rules: + - patterns: + - pattern: | + with $BODY: + ... + - metavariable-pattern: + metavariable: $BODY + patterns: + - pattern: threading.Lock() + - pattern-inside: | + import threading + ... + - focus-metavariable: $BODY + """ + + def leave_With(self, original_node: cst.With, updated_node: cst.With): + if original_node.items and self.node_is_selected(original_node.items[0]): + # TODO: how to avoid name conflicts here? + name = cst.Name(value="lock") + assign = cst.SimpleStatementLine( + body=[ + cst.Assign( + targets=[cst.AssignTarget(target=name)], + value=updated_node.items[0].item, + ) + ] + ) + return cst.FlattenSentinel( + [ + assign, + updated_node.with_changes( + items=[cst.WithItem(name, asname=updated_node.items[0].asname)] + ), + ] + ) + + return original_node diff --git a/tests/codemods/test_with_threading_lock.py b/tests/codemods/test_with_threading_lock.py new file mode 100644 index 000000000..227224829 --- /dev/null +++ b/tests/codemods/test_with_threading_lock.py @@ -0,0 +1,63 @@ +from core_codemods.with_threading_lock import WithThreadingLock +from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest + + +class TestWithThreadingLock(BaseSemgrepCodemodTest): + codemod = WithThreadingLock + + def test_rule_ids(self): + assert self.codemod.name() == "bad-lock-with-statement" + + def test_simple_replacement(self, tmpdir): + input_code = """import threading +with threading.Lock(): + ... +""" + expected = """import threading +lock = threading.Lock() +with lock: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) + + def test_simple_replacement_with_import(self, tmpdir): + input_code = """from threading import Lock +with Lock(): + ... +""" + expected = """from threading import Lock +lock = Lock() +with lock: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) + + def test_simple_replacement_with_as(self, tmpdir): + input_code = """import threading +with threading.Lock() as foo: + ... +""" + expected = """import threading +lock = threading.Lock() +with lock as foo: + ... +""" + self.run_and_assert(tmpdir, input_code, expected) + + def test_no_effect_sanity_check(self, tmpdir): + input_code = """import threading +with threading.Lock(): + ... + +with threading_lock(): + ... +""" + expected = """import threading +lock = threading.Lock() +with lock: + ... + +with threading_lock(): + ... +""" + self.run_and_assert(tmpdir, input_code, expected)