Skip to content

Commit

Permalink
Remove global state; fix dependency manager
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Sep 15, 2023
1 parent c1c4671 commit f16feed
Show file tree
Hide file tree
Showing 23 changed files with 116 additions and 95 deletions.
1 change: 0 additions & 1 deletion integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from tests.shared import reset_global_state # pylint: disable=unused-import
7 changes: 3 additions & 4 deletions integration_tests/semgrep/test_semgrep.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from codemodder import global_state
from codemodder.semgrep import run as semgrep_run
from codemodder.codemods import SecureRandom, UrlSandbox

Expand Down Expand Up @@ -32,10 +31,10 @@ def _assert_secure_random_results(self, results):
assert location["region"]["endLine"] == 3
assert location["region"]["snippet"]["text"] == "random.random()"

def test_two_codemods(self):
global_state.set_directory("tests/samples/")
def test_two_codemods(self, mocker):
context = mocker.MagicMock(directory="tests/samples")
results_by_path_and_id = semgrep_run(
{"secure-random": SecureRandom, "url-sandbox": UrlSandbox}
context, {"secure-random": SecureRandom, "url-sandbox": UrlSandbox}
)

assert sorted(results_by_path_and_id.keys()) == [
Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ mock==5.1.*
pre-commit<4
pytest==7.4.*
pytest-cov~=4.1.0
pytest-mock~=3.11.0
pytest-xdist==3.*
types-mock==5.1.*
15 changes: 5 additions & 10 deletions src/codemodder/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@
from codemodder.code_directory import file_line_patterns, match_files
from codemodder.codemods import match_codemods
from codemodder.context import CodemodExecutionContext, ChangeSet
from codemodder.dependency_manager import DependencyManager
from codemodder.dependency_manager import write_dependencies
from codemodder.report.codetf_reporter import report_default
from codemodder.semgrep import run as semgrep_run
from codemodder.sarifs import parse_sarif_files

# Must use from import here to point to latest state
from codemodder import global_state # TODO: should not use global state


def update_code(file_path, new_code):
"""
Expand All @@ -40,8 +37,7 @@ def run_codemods_for_file(
wrapper = cst.MetadataWrapper(source_tree)
codemod = codemod_kls(
CodemodContext(wrapper=wrapper),
# TODO: eventually pass execution context here
# It will be used for things like dependency management
execution_context,
file_context,
)
if not codemod.should_transform:
Expand Down Expand Up @@ -140,7 +136,6 @@ def run(argv, original_args) -> int:
# project directory doesn't exist or can’t be read
return 1

global_state.set_directory(Path(argv.directory))
context = CodemodExecutionContext(Path(argv.directory), argv.dry_run)

codemods_to_run = match_codemods(argv.codemod_include, argv.codemod_exclude)
Expand All @@ -155,7 +150,7 @@ def run(argv, original_args) -> int:
sarif_results = parse_sarif_files(argv.sarif or [])

# run semgrep and gather the results
semgrep_results = semgrep_run(codemods_to_run)
semgrep_results = semgrep_run(context, codemods_to_run)

# merge the results
sarif_results.update(semgrep_results)
Expand All @@ -164,7 +159,7 @@ def run(argv, original_args) -> int:
logger.warning("No sarif results.")

files_to_analyze = match_files(
global_state.DIRECTORY, argv.path_exclude, argv.path_include
context.directory, argv.path_exclude, argv.path_include
)
if not files_to_analyze:
logger.warning("No files matched.")
Expand All @@ -183,7 +178,7 @@ def run(argv, original_args) -> int:

results = compile_results(context, codemods_to_run)

DependencyManager().write(dry_run=context.dry_run)
write_dependencies(context)
elapsed = datetime.datetime.now() - start
elapsed_ms = int(elapsed.total_seconds() * 1000)
report_default(elapsed_ms, argv, original_args, results)
Expand Down
10 changes: 8 additions & 2 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from codemodder.codemods.base_visitor import BaseTransformer
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from codemodder.semgrep import rule_ids_from_yaml_files
from .helpers import Helpers
Expand Down Expand Up @@ -100,8 +101,13 @@ def __init_subclass__(cls):
cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA)
cls.RULE_IDS = rule_ids_from_yaml_files(cls.YAML_FILES)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
_SemgrepCodemod.__init__(self, file_context)
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
_SemgrepCodemod.__init__(self, execution_context, file_context)
BaseTransformer.__init__(
self,
codemod_context,
Expand Down
5 changes: 4 additions & 1 deletion src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
from typing import List, ClassVar

from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from codemodder.semgrep import rule_ids_from_yaml_files

Expand All @@ -26,6 +27,7 @@ class BaseCodemod:
SUMMARY: ClassVar[str] = NotImplemented
IS_SEMGREP = False

execution_context: CodemodExecutionContext
file_context: FileContext

def __init_subclass__(cls, **kwargs):
Expand All @@ -47,7 +49,8 @@ def __init_subclass__(cls, **kwargs):
if not v:
raise NotImplementedError(f"METADATA.{k} should not be None or empty")

def __init__(self, file_context):
def __init__(self, execution_context: CodemodExecutionContext, file_context):
self.execution_context = execution_context
self.file_context = file_context

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/codemodder/codemods/django_debug_flag_on.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class DjangoDebugFlagOn(SemgrepCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(self, codemod_context: CodemodContext, *args):
Codemod.__init__(self, codemod_context)
SemgrepCodemod.__init__(self, file_context)
SemgrepCodemod.__init__(self, *args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
# checks if the file we looking is a settings.py file from django's default directory structure
Expand Down
4 changes: 2 additions & 2 deletions src/codemodder/codemods/django_session_cookie_secure_off.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class DjangoSessionCookieSecureOff(SemgrepCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(self, codemod_context: CodemodContext, *args):
Codemod.__init__(self, codemod_context)
SemgrepCodemod.__init__(self, file_context)
SemgrepCodemod.__init__(self, *args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
if is_django_settings_file(self.file_context.file_path):
Expand Down
12 changes: 9 additions & 3 deletions src/codemodder/codemods/https_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
CodemodMetadata,
ReviewGuidance,
)
from codemodder.codemods.change import Change
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.file_context import FileContext
import libcst as cst
Expand All @@ -36,9 +37,14 @@ class HTTPSConnection(BaseCodemod, Codemod):
"urllib3.connectionpool.HTTPConnectionPool",
}

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
Codemod.__init__(self, codemod_context)
BaseCodemod.__init__(self, file_context)
BaseCodemod.__init__(self, execution_context, file_context)
self.line_exclude = file_context.line_exclude
self.line_include = file_context.line_include

Expand Down
13 changes: 9 additions & 4 deletions src/codemodder/codemods/order_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.codemods.transformations.clean_imports import (
GatherTopLevelImportBlocks,
OrderImportsBlocksTransform,
)
from codemodder.file_context import FileContext
import libcst as cst
from libcst.codemod import Codemod, CodemodContext
import codemodder.global_state


class OrderImports(BaseCodemod, Codemod):
Expand All @@ -26,9 +26,14 @@ class OrderImports(BaseCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
Codemod.__init__(self, codemod_context)
BaseCodemod.__init__(self, file_context)
BaseCodemod.__init__(self, execution_context, file_context)
self.line_exclude = file_context.line_exclude
self.line_include = file_context.line_include

Expand All @@ -45,7 +50,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
filtered_blocks.append(block)
if filtered_blocks:
order_transformer = OrderImportsBlocksTransform(
codemodder.global_state.DIRECTORY, filtered_blocks
self.execution_context.directory, filtered_blocks
)
result_tree = tree.visit(order_transformer)

Expand Down
3 changes: 1 addition & 2 deletions src/codemodder/codemods/process_creation_sandbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import libcst as cst
from codemodder.dependency_manager import DependencyManager
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import SemgrepCodemod

Expand Down Expand Up @@ -31,7 +30,7 @@ def rule(cls):

def on_result_found(self, original_node, updated_node):
self.add_needed_import("security", "safe_command")
DependencyManager().add(["security==1.0.1"])
self.execution_context.add_dependency("security==1.0.1")
return self.update_call_target(
updated_node,
"safe_command",
Expand Down
5 changes: 2 additions & 3 deletions src/codemodder/codemods/remove_unnecessary_f_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import libcst.matchers as m
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.api import BaseCodemod
from codemodder.file_context import FileContext


class RemoveUnnecessaryFStr(BaseCodemod, UnnecessaryFormatString):
Expand All @@ -15,9 +14,9 @@ class RemoveUnnecessaryFStr(BaseCodemod, UnnecessaryFormatString):
SUMMARY = "Remove unnecessary f-strings."
DESCRIPTION = UnnecessaryFormatString.DESCRIPTION

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(self, codemod_context: CodemodContext, *codemod_args):
UnnecessaryFormatString.__init__(self, codemod_context)
BaseCodemod.__init__(self, file_context)
BaseCodemod.__init__(self, *codemod_args)

@m.leave(m.FormattedString(parts=(m.FormattedStringText(),)))
def _check_formatted_string(
Expand Down
11 changes: 9 additions & 2 deletions src/codemodder/codemods/remove_unused_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.codemods.transformations.remove_unused_imports import (
RemoveUnusedImportsTransformer,
)
Expand All @@ -25,9 +26,15 @@ class RemoveUnusedImports(BaseCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider, QualifiedNameProvider)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
Codemod.__init__(self, codemod_context)
BaseCodemod.__init__(self, file_context)
BaseCodemod.__init__(self, execution_context, file_context)
# TODO: these should be moved to the base codemod class (as properties)
self.line_exclude = file_context.line_exclude
self.line_include = file_context.line_include

Expand Down
10 changes: 8 additions & 2 deletions src/codemodder/codemods/upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext


Expand All @@ -31,8 +32,13 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer):
PROTOCOL_ARG_INDEX = 0
PROTOCOL_KWARG_NAME = "protocol"

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
SemgrepCodemod.__init__(self, file_context)
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
SemgrepCodemod.__init__(self, execution_context, file_context)
BaseTransformer.__init__(
self,
codemod_context,
Expand Down
7 changes: 3 additions & 4 deletions src/codemodder/codemods/url_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from codemodder.file_context import FileContext

from libcst.codemod.visitors import AddImportsVisitor, ImportItem
from codemodder.dependency_manager import DependencyManager
from codemodder.change import Change
from codemodder.codemods.base_codemod import (
SemgrepCodemod,
Expand Down Expand Up @@ -39,9 +38,9 @@ class UrlSandbox(SemgrepCodemod, Codemod):

METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider)

def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
def __init__(self, codemod_context: CodemodContext, *args):
Codemod.__init__(self, codemod_context)
SemgrepCodemod.__init__(self, file_context)
SemgrepCodemod.__init__(self, *args)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
# we first gather all the nodes we want to change together with their replacements
Expand All @@ -54,7 +53,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
find_requests_visitor.changes_in_file
)
new_tree = tree.visit(ReplaceNodes(find_requests_visitor.nodes_to_change))
DependencyManager().add(["security==1.0.1"])
self.execution_context.add_dependency("security==1.0.1")
# if it finds any request.get(...), try to remove the imports
if any(
(
Expand Down
5 changes: 5 additions & 0 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ def to_json(self):

class CodemodExecutionContext:
results_by_codemod: dict[str, list[ChangeSet]] = {}
dependencies: set[str]
directory: Path
dry_run: bool = False

def __init__(self, directory, dry_run):
self.directory = directory
self.dry_run = dry_run
self.dependencies = set()
self.results_by_codemod = {}

def add_result(self, codemod_name, change_set):
self.results_by_codemod.setdefault(codemod_name, []).append(change_set)

def add_dependency(self, dependency: str):
self.dependencies.add(dependency)
15 changes: 11 additions & 4 deletions src/codemodder/dependency_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from dependency_manager import DependencyManagerAbstract
from pathlib import Path
from codemodder import global_state

from codemodder.context import CodemodExecutionContext

class DependencyManager(DependencyManagerAbstract):
def get_parent_dir(self):
return Path(global_state.DIRECTORY)

def write_dependencies(execution_context: CodemodExecutionContext):
class DependencyManager(DependencyManagerAbstract):
def get_parent_dir(self):
return Path(execution_context.directory)

dm = DependencyManager()
dm.add(list(execution_context.dependencies))
dm.write(dry_run=execution_context.dry_run)
return dm
Loading

0 comments on commit f16feed

Please sign in to comment.