diff --git a/src/codemodder/cli.py b/src/codemodder/cli.py index dddb6841..20b18a3f 100644 --- a/src/codemodder/cli.py +++ b/src/codemodder/cli.py @@ -158,6 +158,12 @@ def parse_args(argv, codemod_registry): default=DEFAULT_INCLUDED_PATHS, help="Comma-separated set of UNIX glob patterns to include", ) + parser.add_argument( + "--max-workers", + type=int, + default=1, + help="maximum number of workers (threads) to use for processing files in parallel", + ) # At this time we don't do anything with the sarif arg. parser.add_argument( diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 00852a3f..a758e8bf 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor import datetime import difflib import logging @@ -29,18 +30,14 @@ def update_code(file_path, new_code): def apply_codemod_to_file( - execution_context: CodemodExecutionContext, + base_directory: Path, file_context, codemod_kls: CodemodExecutorWrapper, source_tree, + dry_run: bool = False, ): - name = codemod_kls.id wrapper = cst.MetadataWrapper(source_tree) - codemod = codemod_kls( - CodemodContext(wrapper=wrapper), - execution_context, - file_context, - ) + codemod = codemod_kls(CodemodContext(wrapper=wrapper), file_context) if not codemod.should_transform: return False @@ -57,18 +54,61 @@ def apply_codemod_to_file( ) change_set = ChangeSet( - str(file_context.file_path.relative_to(execution_context.directory)), + str(file_context.file_path.relative_to(base_directory)), diff, changes=file_context.codemod_changes, ) - execution_context.add_result(name, change_set) + file_context.add_result(change_set) - if not execution_context.dry_run: + if not dry_run: update_code(file_context.file_path, output_tree.code) return True +def process_file( + idx: int, + file_path: Path, + base_directory: Path, + codemod, + sarif, + cli_args, +): # pylint: disable=too-many-arguments + logger.debug("scanning file %s", file_path) + if idx and idx % 100 == 0: + logger.info("scanned %s files...", idx) # pragma: no cover + + line_exclude = file_line_patterns(file_path, cli_args.path_exclude) + line_include = file_line_patterns(file_path, cli_args.path_include) + sarif_for_file = sarif.get(str(file_path)) or {} + + file_context = FileContext( + base_directory, + file_path, + line_exclude, + line_include, + sarif_for_file, + ) + + try: + with open(file_path, "r", encoding="utf-8") as f: + source_tree = cst.parse_module(f.read()) + except Exception: + file_context.add_failure(file_path) + logger.exception("error parsing file %s", file_path) + return file_context + + apply_codemod_to_file( + base_directory, + file_context, + codemod, + source_tree, + cli_args.dry_run, + ) + + return file_context + + def analyze_files( execution_context: CodemodExecutionContext, files_to_analyze, @@ -76,40 +116,24 @@ def analyze_files( sarif, cli_args, ): - # TODO: parallelize this loop - for idx, file_path in enumerate(files_to_analyze): - logger.debug("scanning file %s", file_path) - if idx and idx % 100 == 0: - logger.info("scanned %s files...", idx) # pragma: no cover - - try: - with open(file_path, "r", encoding="utf-8") as f: - source_tree = cst.parse_module(f.read()) - except Exception: - execution_context.add_failure(codemod.id, file_path) - logger.exception("error parsing file %s", file_path) - continue - - line_exclude = file_line_patterns(file_path, cli_args.path_exclude) - line_include = file_line_patterns(file_path, cli_args.path_include) - sarif_for_file = sarif.get(str(file_path)) or {} - - # NOTE: file context will become more important if/when we parallelize this loop - file_context = FileContext( - file_path, - line_exclude, - line_include, - sarif_for_file, + with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor: + logger.debug( + "using executor with %s threads", + cli_args.max_workers, ) - - apply_codemod_to_file( - execution_context, - file_context, - codemod, - source_tree, + results = executor.map( + lambda args: process_file( + args[0], + args[1], + execution_context.directory, + codemod, + sarif, + cli_args, + ), + enumerate(files_to_analyze), ) - - execution_context.add_dependencies(codemod.id, file_context.dependencies) + executor.shutdown(wait=True) + execution_context.process_results(codemod.id, results) def run(original_args) -> int: diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py index c1e04d4e..1d141bc9 100644 --- a/src/codemodder/codemods/api/__init__.py +++ b/src/codemodder/codemods/api/__init__.py @@ -16,7 +16,6 @@ from codemodder.codemods.base_visitor import BaseTransformer from codemodder.change import Change -from codemodder.context import CodemodExecutionContext from codemodder.file_context import FileContext from .helpers import Helpers @@ -84,13 +83,8 @@ class BaseCodemod( BaseTransformer, Helpers, ): - def __init__( - self, - codemod_context: CodemodContext, - execution_context: CodemodExecutionContext, - file_context: FileContext, - ): - _BaseCodemod.__init__(self, execution_context, file_context) + def __init__(self, codemod_context: CodemodContext, file_context: FileContext): + _BaseCodemod.__init__(self, file_context) BaseTransformer.__init__(self, codemod_context, []) def report_change(self, original_node): @@ -112,14 +106,9 @@ def __init_subclass__(cls): super().__init_subclass__() cls.YAML_FILES = _create_temp_yaml_file(cls, cls.METADATA) - def __init__( - self, - codemod_context: CodemodContext, - execution_context: CodemodExecutionContext, - file_context: FileContext, - ): - BaseCodemod.__init__(self, codemod_context, execution_context, file_context) - _SemgrepCodemod.__init__(self, execution_context, file_context) + def __init__(self, codemod_context: CodemodContext, file_context: FileContext): + BaseCodemod.__init__(self, codemod_context, file_context) + _SemgrepCodemod.__init__(self, file_context) BaseTransformer.__init__(self, codemod_context, self._results) def _new_or_updated_node(self, original_node, updated_node): diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index a504a6f9..f1237031 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -5,7 +5,6 @@ from libcst._position import CodeRange from codemodder.change import Change -from codemodder.context import CodemodExecutionContext from codemodder.dependency import Dependency from codemodder.file_context import FileContext from codemodder.semgrep import run as semgrep_run @@ -46,12 +45,9 @@ class BaseCodemod: SUMMARY: ClassVar[str] = NotImplemented is_semgrep: bool = False adds_dependency: bool = False - - execution_context: CodemodExecutionContext file_context: FileContext - def __init__(self, execution_context: CodemodExecutionContext, file_context): - self.execution_context = execution_context + def __init__(self, file_context: FileContext): self.file_context = file_context @classmethod diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 3f97f89a..44984fd7 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -2,11 +2,13 @@ from pathlib import Path import itertools from textwrap import indent +from typing import List, Iterator from codemodder.change import ChangeSet from codemodder.dependency import Dependency from codemodder.dependency_manager import DependencyManager from codemodder.executor import CodemodExecutorWrapper +from codemodder.file_context import FileContext from codemodder.logging import logger, log_list from codemodder.registry import CodemodRegistry from codemodder.project_analysis.python_repo_manager import PythonRepoManager @@ -52,16 +54,16 @@ def __init__( self.registry = registry self.repo_manager = repo_manager - def add_result(self, codemod_name, change_set): - self._results_by_codemod.setdefault(codemod_name, []).append(change_set) + def add_results(self, codemod_name: str, change_sets: List[ChangeSet]): + self._results_by_codemod.setdefault(codemod_name, []).extend(change_sets) - def add_failure(self, codemod_name, file_path): - self._failures_by_codemod.setdefault(codemod_name, []).append(file_path) + def add_failures(self, codemod_name: str, failed_files: List[Path]): + self._failures_by_codemod.setdefault(codemod_name, []).extend(failed_files) def add_dependencies(self, codemod_id: str, dependencies: set[Dependency]): self.dependencies.setdefault(codemod_id, set()).update(dependencies) - def get_results(self, codemod_name): + def get_results(self, codemod_name: str): return self._results_by_codemod.get(codemod_name, []) def get_changed_files(self): @@ -71,7 +73,7 @@ def get_changed_files(self): for change_set in changes ] - def get_failures(self, codemod_name): + def get_failures(self, codemod_name: str): return self._failures_by_codemod.get(codemod_name, []) def get_failed_files(self): @@ -96,7 +98,7 @@ def process_dependencies(self, codemod_id: str): dm.add(list(dependencies)) if (changeset := dm.write(self.dry_run)) is not None: - self.add_result(codemod_id, changeset) + self.add_results(codemod_id, [changeset]) def add_description(self, codemod: CodemodExecutorWrapper): description = codemod.description @@ -105,6 +107,12 @@ def add_description(self, codemod: CodemodExecutorWrapper): return description + def process_results(self, codemod_id: str, results: Iterator[FileContext]): + for file_context in results: + self.add_results(codemod_id, file_context.results) + self.add_failures(codemod_id, file_context.failures) + self.add_dependencies(codemod_id, file_context.dependencies) + def compile_results(self, codemods: list[CodemodExecutorWrapper]): results = [] for codemod in codemods: diff --git a/src/codemodder/file_context.py b/src/codemodder/file_context.py index 103182c8..cdc9290b 100644 --- a/src/codemodder/file_context.py +++ b/src/codemodder/file_context.py @@ -2,22 +2,31 @@ from pathlib import Path from typing import Dict, List -from codemodder.change import Change +from codemodder.change import Change, ChangeSet from codemodder.dependency import Dependency @dataclass -class FileContext: +class FileContext: # pylint: disable=too-many-instance-attributes """ Extra context for running codemods on a given file based on the cli parameters. """ + base_directory: Path file_path: Path line_exclude: List[int] = field(default_factory=list) line_include: List[int] = field(default_factory=list) results_by_id: Dict = field(default_factory=dict) dependencies: set[Dependency] = field(default_factory=set) codemod_changes: List[Change] = field(default_factory=list) + results: List[ChangeSet] = field(default_factory=list) + failures: List[Path] = field(default_factory=list) def add_dependency(self, dependency: Dependency): self.dependencies.add(dependency) + + def add_result(self, result: ChangeSet): + self.results.append(result) + + def add_failure(self, filename: Path): + self.failures.append(filename) diff --git a/src/core_codemods/order_imports.py b/src/core_codemods/order_imports.py index dffed31c..17ecd232 100644 --- a/src/core_codemods/order_imports.py +++ b/src/core_codemods/order_imports.py @@ -42,7 +42,8 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: filtered_blocks.append(block) if filtered_blocks: order_transformer = OrderImportsBlocksTransform( - self.execution_context.directory, filtered_blocks + self.file_context.base_directory, + filtered_blocks, ) result_tree = tree.visit(order_transformer) diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index 1c22fe5e..f94ca0ea 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -23,7 +23,6 @@ ) from codemodder.codemods.utils import Append, ReplaceNodes, get_function_name_node from codemodder.codemods.utils_mixin import NameResolutionMixin -from codemodder.context import CodemodExecutionContext from codemodder.file_context import FileContext parameter_token = "?" @@ -61,14 +60,13 @@ class SQLQueryParameterization(BaseCodemod, UtilsMixin, Codemod): def __init__( self, context: CodemodContext, - execution_context: CodemodExecutionContext, file_context: FileContext, *codemod_args, ) -> None: self.changed_nodes: dict[ cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} - BaseCodemod.__init__(self, execution_context, file_context, *codemod_args) + BaseCodemod.__init__(self, file_context, *codemod_args) UtilsMixin.__init__(self, context, {}) Codemod.__init__(self, context) diff --git a/src/core_codemods/upgrade_sslcontext_tls.py b/src/core_codemods/upgrade_sslcontext_tls.py index 53235cd0..935500d3 100644 --- a/src/core_codemods/upgrade_sslcontext_tls.py +++ b/src/core_codemods/upgrade_sslcontext_tls.py @@ -7,7 +7,6 @@ ReviewGuidance, ) from codemodder.change import Change -from codemodder.context import CodemodExecutionContext from codemodder.file_context import FileContext @@ -43,13 +42,8 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer): PROTOCOL_ARG_INDEX = 0 PROTOCOL_KWARG_NAME = "protocol" - def __init__( - self, - codemod_context: CodemodContext, - execution_context: CodemodExecutionContext, - file_context: FileContext, - ): - SemgrepCodemod.__init__(self, execution_context, file_context) + def __init__(self, codemod_context: CodemodContext, file_context: FileContext): + SemgrepCodemod.__init__(self, file_context) BaseTransformer.__init__(self, codemod_context, self._results) # TODO: apply unused import remover diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 4bf140c7..136687f7 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -36,6 +36,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): repo_manager=mock.MagicMock(), ) self.file_context = FileContext( + root, file_path, [], [], @@ -44,7 +45,6 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( CodemodContext(wrapper=wrapper), - self.execution_context, self.file_context, ) output_tree = command_instance.transform_module(input_tree) @@ -87,6 +87,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): all_results = self.results_by_id_filepath(input_code, file_path) results = all_results[str(file_path)] self.file_context = FileContext( + root, file_path, [], [], @@ -95,7 +96,6 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( CodemodContext(wrapper=wrapper), - self.execution_context, self.file_context, ) output_tree = command_instance.transform_module(input_tree) diff --git a/tests/codemods/test_base_codemod.py b/tests/codemods/test_base_codemod.py index 3d71663b..033e8950 100644 --- a/tests/codemods/test_base_codemod.py +++ b/tests/codemods/test_base_codemod.py @@ -32,7 +32,6 @@ def run_and_assert(self, input_code, expected_output): command_instance = DoNothingCodemod( CodemodContext(), mock.MagicMock(), - mock.MagicMock(), ) output_tree = command_instance.transform_module(input_tree) diff --git a/tests/test_file_context.py b/tests/test_file_context.py index 2ec8277e..b9a365c5 100644 --- a/tests/test_file_context.py +++ b/tests/test_file_context.py @@ -2,6 +2,10 @@ def test_file_context(mocker): - file_context = FileContext(mocker.MagicMock()) + directory = mocker.MagicMock() + path = mocker.MagicMock() + file_context = FileContext(directory, path) + assert file_context.base_directory is directory + assert file_context.file_path is path assert file_context.line_exclude == [] assert file_context.line_include == []