diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index e0b86e22..2d68ee4f 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -35,7 +35,7 @@ def update_code(file_path, new_code): def find_semgrep_results( context: CodemodExecutionContext, codemods: list[CodemodExecutorWrapper], -) -> set[str]: +) -> ResultSet: """Run semgrep once with all configuration files from all codemods and return a set of applicable rule IDs""" yaml_files = list( itertools.chain.from_iterable( @@ -43,10 +43,9 @@ def find_semgrep_results( ) ) if not yaml_files: - return set() + return ResultSet() - results = run_semgrep(context, yaml_files) - return set(results.keys()) + return run_semgrep(context, yaml_files) def create_diff(original_tree: cst.Module, new_tree: cst.Module) -> str: @@ -173,6 +172,29 @@ def analyze_files( execution_context.process_results(codemod.id, analysis_results) +def log_report(context, argv, elapsed_ms, files_to_analyze): + log_section("report") + logger.info("scanned: %s files", len(files_to_analyze)) + all_failures = context.get_failed_files() + logger.info( + "failed: %s files (%s unique)", + len(all_failures), + len(set(all_failures)), + ) + all_changes = context.get_changed_files() + logger.info( + "changed: %s files (%s unique)", + len(all_changes), + len(set(all_changes)), + ) + logger.info("report file: %s", argv.output) + logger.info("total elapsed: %s ms", elapsed_ms) + logger.info(" semgrep: %s ms", context.timer.get_time_ms("semgrep")) + logger.info(" parse: %s ms", context.timer.get_time_ms("parse")) + logger.info(" transform: %s ms", context.timer.get_time_ms("transform")) + logger.info(" write: %s ms", context.timer.get_time_ms("write")) + + def run(original_args) -> int: start = datetime.datetime.now() @@ -225,17 +247,17 @@ def run(original_args) -> int: return 0 full_names = [str(path) for path in files_to_analyze] - logger.debug("matched files:") log_list(logging.DEBUG, "matched files", full_names) - semgrep_results: set[str] = find_semgrep_results(context, codemods_to_run) + semgrep_results: ResultSet = find_semgrep_results(context, codemods_to_run) + semgrep_finding_ids = semgrep_results.all_rule_ids() log_section("scanning") # run codemods one at a time making sure to respect the given sequence for codemod in codemods_to_run: # Unfortunately the IDs from semgrep are not fully specified # TODO: eventually we need to be able to use fully specified IDs here - if codemod.is_semgrep and codemod.name not in semgrep_results: + if codemod.is_semgrep and codemod.name not in semgrep_finding_ids: logger.debug( "no results from semgrep for %s, skipping analysis", codemod.id, @@ -243,7 +265,9 @@ def run(original_args) -> int: continue logger.info("running codemod %s", codemod.id) - results = codemod.apply(context) + semgrep_files = semgrep_results.files_for_rule(codemod.name) + # Non-semgrep codemods ignore the semgrep results + results = codemod.apply(context, semgrep_files) analyze_files( context, files_to_analyze, @@ -260,27 +284,7 @@ def run(original_args) -> int: elapsed_ms = int(elapsed.total_seconds() * 1000) report_default(elapsed_ms, argv, original_args, results) - log_section("report") - logger.info("scanned: %s files", len(files_to_analyze)) - all_failures = context.get_failed_files() - logger.info( - "failed: %s files (%s unique)", - len(all_failures), - len(set(all_failures)), - ) - all_changes = context.get_changed_files() - logger.info( - "changed: %s files (%s unique)", - len(all_changes), - len(set(all_changes)), - ) - logger.info("report file: %s", argv.output) - logger.info("total elapsed: %s ms", elapsed_ms) - logger.info("semgrep: %s ms", context.timer.get_time_ms("semgrep")) - logger.info("parse: %s ms", context.timer.get_time_ms("parse")) - logger.info("transform: %s ms", context.timer.get_time_ms("transform")) - logger.info("write: %s ms", context.timer.get_time_ms("write")) - + log_report(context, argv, elapsed_ms, files_to_analyze) return 0 diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 86f7430b..b83f00c4 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -112,8 +112,9 @@ def apply_rule(cls, context, *args, **kwargs) -> ResultSet: Apply semgrep to gather rule results """ yaml_files = kwargs.get("yaml_files") or args[0] + files_to_analyze = kwargs.get("files_to_analyze") or args[1] with context.timer.measure("semgrep"): - return semgrep_run(context, yaml_files) + return semgrep_run(context, yaml_files, files_to_analyze) @property def should_transform(self): diff --git a/src/codemodder/executor.py b/src/codemodder/executor.py index 47910b9f..b5420b30 100644 --- a/src/codemodder/executor.py +++ b/src/codemodder/executor.py @@ -1,4 +1,5 @@ from importlib.abc import Traversable +from pathlib import Path from wrapt import CallableObjectProxy @@ -22,13 +23,17 @@ def __init__( self.docs_module = docs_module self.semgrep_config_module = semgrep_config_module - def apply(self, context): + def apply(self, context, files: list[Path]): """ Wraps the codemod's apply method to inject additional arguments. Not all codemods will need these arguments. """ - return self.apply_rule(context, yaml_files=self.yaml_files) + return self.apply_rule( + context, + yaml_files=self.yaml_files, + files_to_analyze=files, + ) @property def name(self): diff --git a/src/codemodder/result.py b/src/codemodder/result.py index 4566d7be..2409a7b6 100644 --- a/src/codemodder/result.py +++ b/src/codemodder/result.py @@ -34,13 +34,16 @@ class Result(ABCDataclass): locations: list[Location] -class ResultSet(dict[str, list[Result]]): +class ResultSet(dict[str, dict[Path, list[Result]]]): def add_result(self, result: Result): - self.setdefault(result.rule_id, []).append(result) + for loc in result.locations: + self.setdefault(result.rule_id, {}).setdefault(loc.file, []).append(result) def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]: - return [ - result - for result in self.get(rule_id, []) - if result.locations[0].file == file - ] + return self.get(rule_id, {}).get(file, []) + + def files_for_rule(self, rule_id: str) -> list[Path]: + return list(self.get(rule_id, {}).keys()) + + def all_rule_ids(self) -> list[str]: + return list(self.keys()) diff --git a/src/codemodder/semgrep.py b/src/codemodder/semgrep.py index 8138cb9e..e4fc32a3 100644 --- a/src/codemodder/semgrep.py +++ b/src/codemodder/semgrep.py @@ -1,7 +1,7 @@ import subprocess import itertools from tempfile import NamedTemporaryFile -from typing import Iterable +from typing import Iterable, Optional from pathlib import Path from codemodder.context import CodemodExecutionContext from codemodder.sarifs import SarifResultSet @@ -11,6 +11,7 @@ def run( execution_context: CodemodExecutionContext, yaml_files: Iterable[Path], + files_to_analyze: Optional[Iterable[Path]] = None, ) -> SarifResultSet: """ Runs Semgrep and outputs a dict with the results organized by rule_id. @@ -34,7 +35,7 @@ def run( map(lambda f: ["--config", str(f)], yaml_files) ) ) - command.append(str(execution_context.directory)) + command.extend(map(str, files_to_analyze or [execution_context.directory])) logger.debug("semgrep command: `%s`", " ".join(command)) subprocess.run( command, diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index ecef8479..78d7fbf0 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -84,12 +84,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): ) input_tree = cst.parse_module(input_code) all_results = self.results_by_id_filepath(input_code, file_path) - results = [ - result - for entry in all_results.values() - for result in entry - if result.rule_id == self.codemod.name() - ] + results = all_results.results_for_rule_and_file(self.codemod.name(), file_path) self.file_context = FileContext( root, file_path, diff --git a/tests/test_codemodder.py b/tests/test_codemodder.py index 02f5149f..dd73f6a9 100644 --- a/tests/test_codemodder.py +++ b/tests/test_codemodder.py @@ -6,6 +6,7 @@ from codemodder.codemodder import create_diff, run, find_semgrep_results from codemodder.semgrep import run as semgrep_run from codemodder.registry import load_registered_codemods +from codemodder.result import ResultSet class TestRun: @@ -190,7 +191,7 @@ def test_find_semgrep_results_no_yaml(self, mocker): codemod_include=["use-defusedxml"] ) result = find_semgrep_results(mocker.MagicMock(), codemods) - assert result == set() + assert result == ResultSet() assert run_semgrep.call_count == 0 def test_diff_newline_edge_case(self): diff --git a/tests/test_sarif_processing.py b/tests/test_sarif_processing.py index 6f837c9a..da927a18 100644 --- a/tests/test_sarif_processing.py +++ b/tests/test_sarif_processing.py @@ -36,7 +36,12 @@ def test_results_by_rule_id(self): assert list(results.keys()) == [expected_rule] expected_path = Path("tests/samples/insecure_random.py") - assert expected_path == results[expected_rule][0].locations[0].file + assert list(results[expected_rule].keys()) == [expected_path] + + assert results[expected_rule][expected_path][0].rule_id == expected_rule + assert ( + results[expected_rule][expected_path][0].locations[0].file == expected_path + ) def test_codeql_sarif_input(self, tmpdir): completed_process = subprocess.run(