Skip to content

Commit

Permalink
Implement file-level parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Oct 26, 2023
1 parent 73f5b4a commit 43b88ab
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 88 deletions.
6 changes: 6 additions & 0 deletions src/codemodder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def parse_args(argv, codemod_registry):
default=DEFAULT_INCLUDED_PATHS,
help="Comma-separated set of UNIX glob patterns to include",
)
parser.add_argument(
"--max-workers",
type=int,
default=1,
help="maximum number of workers (threads) to use for processing files in parallel",
)

# At this time we don't do anything with the sarif arg.
parser.add_argument(
Expand Down
108 changes: 66 additions & 42 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
import datetime
import difflib
import logging
Expand Down Expand Up @@ -29,18 +30,14 @@ def update_code(file_path, new_code):


def apply_codemod_to_file(
execution_context: CodemodExecutionContext,
base_directory: Path,
file_context,
codemod_kls: CodemodExecutorWrapper,
source_tree,
dry_run: bool = False,
):
name = codemod_kls.id
wrapper = cst.MetadataWrapper(source_tree)
codemod = codemod_kls(
CodemodContext(wrapper=wrapper),
execution_context,
file_context,
)
codemod = codemod_kls(CodemodContext(wrapper=wrapper), file_context)
if not codemod.should_transform:
return False

Expand All @@ -57,59 +54,86 @@ def apply_codemod_to_file(
)

change_set = ChangeSet(
str(file_context.file_path.relative_to(execution_context.directory)),
str(file_context.file_path.relative_to(base_directory)),
diff,
changes=file_context.codemod_changes,
)
execution_context.add_result(name, change_set)
file_context.add_result(change_set)

if not execution_context.dry_run:
if not dry_run:
update_code(file_context.file_path, output_tree.code)

return True


def process_file(
idx: int,
file_path: Path,
base_directory: Path,
codemod,
sarif,
cli_args,
): # pylint: disable=too-many-arguments
logger.debug("scanning file %s", file_path)
if idx and idx % 100 == 0:
logger.info("scanned %s files...", idx) # pragma: no cover

line_exclude = file_line_patterns(file_path, cli_args.path_exclude)
line_include = file_line_patterns(file_path, cli_args.path_include)
sarif_for_file = sarif.get(str(file_path)) or {}

file_context = FileContext(
base_directory,
file_path,
line_exclude,
line_include,
sarif_for_file,
)

try:
with open(file_path, "r", encoding="utf-8") as f:
source_tree = cst.parse_module(f.read())
except Exception:
file_context.add_failure(file_path)
logger.exception("error parsing file %s", file_path)
return file_context

apply_codemod_to_file(
base_directory,
file_context,
codemod,
source_tree,
cli_args.dry_run,
)

return file_context


def analyze_files(
execution_context: CodemodExecutionContext,
files_to_analyze,
codemod,
sarif,
cli_args,
):
# TODO: parallelize this loop
for idx, file_path in enumerate(files_to_analyze):
logger.debug("scanning file %s", file_path)
if idx and idx % 100 == 0:
logger.info("scanned %s files...", idx) # pragma: no cover

try:
with open(file_path, "r", encoding="utf-8") as f:
source_tree = cst.parse_module(f.read())
except Exception:
execution_context.add_failure(codemod.id, file_path)
logger.exception("error parsing file %s", file_path)
continue

line_exclude = file_line_patterns(file_path, cli_args.path_exclude)
line_include = file_line_patterns(file_path, cli_args.path_include)
sarif_for_file = sarif.get(str(file_path)) or {}

# NOTE: file context will become more important if/when we parallelize this loop
file_context = FileContext(
file_path,
line_exclude,
line_include,
sarif_for_file,
with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor:
logger.debug(
"using executor with %s threads",
cli_args.max_workers,
)

apply_codemod_to_file(
execution_context,
file_context,
codemod,
source_tree,
results = executor.map(
lambda args: process_file(
args[0],
args[1],
execution_context.directory,
codemod,
sarif,
cli_args,
),
enumerate(files_to_analyze),
)

execution_context.add_dependencies(codemod.id, file_context.dependencies)
executor.shutdown(wait=True)
execution_context.process_results(codemod.id, results)


def run(original_args) -> int:
Expand Down
21 changes: 5 additions & 16 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from codemodder.codemods.base_visitor import BaseTransformer
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from .helpers import Helpers

Expand Down Expand Up @@ -84,13 +83,8 @@ class BaseCodemod(
BaseTransformer,
Helpers,
):
def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
_BaseCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
_BaseCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, [])

def report_change(self, original_node):
Expand All @@ -112,14 +106,9 @@ def __init_subclass__(cls):
super().__init_subclass__()
cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA)

def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
BaseCodemod.__init__(self, codemod_context, execution_context, file_context)
_SemgrepCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
BaseCodemod.__init__(self, codemod_context, file_context)
_SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)

def _new_or_updated_node(self, original_node, updated_node):
Expand Down
6 changes: 1 addition & 5 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from libcst._position import CodeRange

from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.file_context import FileContext
from codemodder.semgrep import run as semgrep_run
Expand Down Expand Up @@ -46,12 +45,9 @@ class BaseCodemod:
SUMMARY: ClassVar[str] = NotImplemented
is_semgrep: bool = False
adds_dependency: bool = False

execution_context: CodemodExecutionContext
file_context: FileContext

def __init__(self, execution_context: CodemodExecutionContext, file_context):
self.execution_context = execution_context
def __init__(self, file_context: FileContext):
self.file_context = file_context

@classmethod
Expand Down
22 changes: 15 additions & 7 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from pathlib import Path
import itertools
from textwrap import indent
from typing import List, Iterator

from codemodder.change import ChangeSet
from codemodder.dependency import Dependency
from codemodder.dependency_manager import DependencyManager
from codemodder.executor import CodemodExecutorWrapper
from codemodder.file_context import FileContext
from codemodder.logging import logger, log_list
from codemodder.registry import CodemodRegistry
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -52,16 +54,16 @@ def __init__(
self.registry = registry
self.repo_manager = repo_manager

def add_result(self, codemod_name, change_set):
self._results_by_codemod.setdefault(codemod_name, []).append(change_set)
def add_results(self, codemod_name: str, change_sets: List[ChangeSet]):
self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets)

def add_failure(self, codemod_name, file_path):
self._failures_by_codemod.setdefault(codemod_name, []).append(file_path)
def add_failures(self, codemod_name: str, failed_files: List[Path]):
self._failures_by_codemod.setdefault(codemod_name, []).extend(failed_files)

def add_dependencies(self, codemod_id: str, dependencies: set[Dependency]):
self.dependencies.setdefault(codemod_id, set()).update(dependencies)

def get_results(self, codemod_name):
def get_results(self, codemod_name: str):
return self._results_by_codemod.get(codemod_name, [])

def get_changed_files(self):
Expand All @@ -71,7 +73,7 @@ def get_changed_files(self):
for change_set in changes
]

def get_failures(self, codemod_name):
def get_failures(self, codemod_name: str):
return self._failures_by_codemod.get(codemod_name, [])

def get_failed_files(self):
Expand All @@ -96,7 +98,7 @@ def process_dependencies(self, codemod_id: str):

dm.add(list(dependencies))
if (changeset := dm.write(self.dry_run)) is not None:
self.add_result(codemod_id, changeset)
self.add_results(codemod_id, [changeset])

def add_description(self, codemod: CodemodExecutorWrapper):
description = codemod.description
Expand All @@ -105,6 +107,12 @@ def add_description(self, codemod: CodemodExecutorWrapper):

return description

def process_results(self, codemod_id: str, results: Iterator[FileContext]):
for file_context in results:
self.add_results(codemod_id, file_context.results)
self.add_failures(codemod_id, file_context.failures)
self.add_dependencies(codemod_id, file_context.dependencies)

def compile_results(self, codemods: list[CodemodExecutorWrapper]):
results = []
for codemod in codemods:
Expand Down
13 changes: 11 additions & 2 deletions src/codemodder/file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@
from pathlib import Path
from typing import Dict, List

from codemodder.change import Change
from codemodder.change import Change, ChangeSet
from codemodder.dependency import Dependency


@dataclass
class FileContext:
class FileContext: # pylint: disable=too-many-instance-attributes
"""
Extra context for running codemods on a given file based on the cli parameters.
"""

base_directory: Path
file_path: Path
line_exclude: List[int] = field(default_factory=list)
line_include: List[int] = field(default_factory=list)
results_by_id: Dict = field(default_factory=dict)
dependencies: set[Dependency] = field(default_factory=set)
codemod_changes: List[Change] = field(default_factory=list)
results: List[ChangeSet] = field(default_factory=list)
failures: List[Path] = field(default_factory=list)

def add_dependency(self, dependency: Dependency):
self.dependencies.add(dependency)

def add_result(self, result: ChangeSet):
self.results.append(result)

def add_failure(self, filename: Path):
self.failures.append(filename)
3 changes: 2 additions & 1 deletion src/core_codemods/order_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
filtered_blocks.append(block)
if filtered_blocks:
order_transformer = OrderImportsBlocksTransform(
self.execution_context.directory, filtered_blocks
self.file_context.base_directory,
filtered_blocks,
)
result_tree = tree.visit(order_transformer)

Expand Down
4 changes: 1 addition & 3 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext

parameter_token = "?"
Expand Down Expand Up @@ -61,14 +60,13 @@ class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod):
def __init__(
self,
context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
*codemod_args,
) -> None:
self.changed_nodes: dict[
cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel
] = {}
BaseCodemod.__init__(self, execution_context, file_context, *codemod_args)
BaseCodemod.__init__(self, file_context, *codemod_args)
UtilsMixin.__init__(self, context, {})
Codemod.__init__(self, context)

Expand Down
10 changes: 2 additions & 8 deletions src/core_codemods/upgrade_sslcontext_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
ReviewGuidance,
)
from codemodder.change import Change
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext


Expand Down Expand Up @@ -43,13 +42,8 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer):
PROTOCOL_ARG_INDEX = 0
PROTOCOL_KWARG_NAME = "protocol"

def __init__(
self,
codemod_context: CodemodContext,
execution_context: CodemodExecutionContext,
file_context: FileContext,
):
SemgrepCodemod.__init__(self, execution_context, file_context)
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)

# TODO: apply unused import remover
Expand Down
Loading

0 comments on commit 43b88ab

Please sign in to comment.