From bd3e1d438e866fcafd74c07d6d1bc3be0c47540d Mon Sep 17 00:00:00 2001 From: Daniel D'Avella Date: Tue, 31 Oct 2023 10:39:24 -0400 Subject: [PATCH] Refactor: semgrep results use specific datatypes --- src/codemodder/codemodder.py | 20 +++-- src/codemodder/codemods/api/__init__.py | 2 +- src/codemodder/codemods/base_codemod.py | 25 ++---- src/codemodder/codemods/base_visitor.py | 52 +++++------- src/codemodder/file_context.py | 5 +- src/codemodder/result.py | 46 +++++++++++ src/codemodder/sarifs.py | 80 ++++++++++++------- src/codemodder/semgrep.py | 9 ++- src/codemodder/utils/abc_dataclass.py | 13 +++ src/core_codemods/django_debug_flag_on.py | 2 +- .../django_session_cookie_secure_off.py | 2 +- src/core_codemods/sql_parameterization.py | 2 +- src/core_codemods/url_sandbox.py | 4 +- tests/codemods/base_codemod_test.py | 10 ++- tests/test_sarif_processing.py | 15 ++-- 15 files changed, 182 insertions(+), 105 deletions(-) create mode 100644 src/codemodder/result.py create mode 100644 src/codemodder/utils/abc_dataclass.py diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 774f32b2..e0b86e22 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -20,6 +20,7 @@ from codemodder.executor import CodemodExecutorWrapper from codemodder.project_analysis.python_repo_manager import PythonRepoManager from codemodder.report.codetf_reporter import report_default +from codemodder.result import ResultSet from codemodder.semgrep import run as run_semgrep @@ -45,7 +46,7 @@ def find_semgrep_results( return set() results = run_semgrep(context, yaml_files) - return {rule_id for file_changes in results.values() for rule_id in file_changes} + return set(results.keys()) def create_diff(original_tree: cst.Module, new_tree: cst.Module) -> str: @@ -103,7 +104,7 @@ def process_file( file_path: Path, base_directory: Path, codemod, - sarif, + results: ResultSet, cli_args, ): # pylint: disable=too-many-arguments logger.debug("scanning file %s", file_path) @@ -112,14 +113,17 @@ def process_file( line_exclude = file_line_patterns(file_path, cli_args.path_exclude) line_include = file_line_patterns(file_path, cli_args.path_include) - sarif_for_file = sarif.get(str(file_path)) or {} + findings_for_rule = results.results_for_rule_and_file( + codemod.name, # TODO: should be full ID + file_path, + ) file_context = FileContext( base_directory, file_path, line_exclude, line_include, - sarif_for_file, + findings_for_rule, ) try: @@ -146,7 +150,7 @@ def analyze_files( execution_context: CodemodExecutionContext, files_to_analyze, codemod, - sarif, + results: ResultSet, cli_args, ): with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor: @@ -154,19 +158,19 @@ def analyze_files( "using executor with %s threads", cli_args.max_workers, ) - results = executor.map( + analysis_results = executor.map( lambda args: process_file( args[0], args[1], execution_context.directory, codemod, - sarif, + results, cli_args, ), enumerate(files_to_analyze), ) executor.shutdown(wait=True) - execution_context.process_results(codemod.id, results) + execution_context.process_results(codemod.id, analysis_results) def run(original_args) -> int: diff --git a/src/codemodder/codemods/api/__init__.py b/src/codemodder/codemods/api/__init__.py index 1d141bc9..a3a102c3 100644 --- a/src/codemodder/codemods/api/__init__.py +++ b/src/codemodder/codemods/api/__init__.py @@ -109,7 +109,7 @@ def __init_subclass__(cls): def __init__(self, codemod_context: CodemodContext, file_context: FileContext): BaseCodemod.__init__(self, codemod_context, file_context) _SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, self._results) + BaseTransformer.__init__(self, codemod_context, file_context.findings) def _new_or_updated_node(self, original_node, updated_node): if self.node_is_selected(original_node): diff --git a/src/codemodder/codemods/base_codemod.py b/src/codemodder/codemods/base_codemod.py index 127d9a9b..86f7430b 100644 --- a/src/codemodder/codemods/base_codemod.py +++ b/src/codemodder/codemods/base_codemod.py @@ -7,6 +7,7 @@ from codemodder.change import Change from codemodder.dependency import Dependency from codemodder.file_context import FileContext +from codemodder.result import ResultSet from codemodder.semgrep import run as semgrep_run @@ -51,14 +52,14 @@ def __init__(self, file_context: FileContext): self.file_context = file_context @classmethod - def apply_rule(cls, context, *args, **kwargs): + def apply_rule(cls, context, *args, **kwargs) -> ResultSet: """ Apply rule associated with this codemod and gather results Does nothing by default. Subclasses may override for custom rule logic. """ del context, args, kwargs - return {} + return ResultSet() @classmethod def name(cls): @@ -105,28 +106,16 @@ class SemgrepCodemod(BaseCodemod): YAML_FILES: ClassVar[List[str]] = NotImplemented is_semgrep = True - def __init__(self, *args): - super().__init__(*args) - self._results = ( - self.file_context.results_by_id.get( - self.METADATA.NAME # pylint: disable=no-member - ) - or [] - ) - @classmethod - def apply_rule( - cls, context, yaml_files, *args, **kwargs - ): # pylint: disable=arguments-differ + def apply_rule(cls, context, *args, **kwargs) -> ResultSet: """ Apply semgrep to gather rule results """ - del args, kwargs + yaml_files = kwargs.get("yaml_files") or args[0] with context.timer.measure("semgrep"): return semgrep_run(context, yaml_files) @property def should_transform(self): - """Semgrep codemods should attempt transform only if there are - semgrep results""" - return bool(self._results) + """Semgrep codemods should attempt transform only if there are semgrep results""" + return bool(self.file_context.findings) diff --git a/src/codemodder/codemods/base_visitor.py b/src/codemodder/codemods/base_visitor.py index 07b22b69..65d30d1d 100644 --- a/src/codemodder/codemods/base_visitor.py +++ b/src/codemodder/codemods/base_visitor.py @@ -2,17 +2,21 @@ from libcst.codemod import ContextAwareVisitor, VisitorBasedCodemodCommand from libcst.metadata import PositionProvider +from codemodder.result import Result + class UtilsMixin: - METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,) + results: list[Result] - def __init__(self, context, results): - super().__init__(context) + def __init__(self, results: list[Result]): self.results = results def filter_by_result(self, pos_to_match): - all_pos = [extract_pos_from_result(result) for result in self.results] - return any(match_pos(pos_to_match, position) for position in all_pos) + return any( + location.match(pos_to_match) + for result in self.results + for location in result.locations + ) def filter_by_path_includes_or_excludes(self, pos_to_match): """ @@ -28,6 +32,7 @@ def filter_by_path_includes_or_excludes(self, pos_to_match): def node_is_selected(self, node) -> bool: if not self.results: return False + pos_to_match = self.node_position(node) return self.filter_by_result( pos_to_match @@ -41,34 +46,21 @@ def lineno_for_node(self, node): return self.node_position(node).start.line -class BaseTransformer(UtilsMixin, VisitorBasedCodemodCommand): - ... +class BaseTransformer(VisitorBasedCodemodCommand, UtilsMixin): + METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,) + + def __init__(self, context, results: list[Result]): + super().__init__(context) + UtilsMixin.__init__(self, results) + +class BaseVisitor(ContextAwareVisitor, UtilsMixin): + METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,) -class BaseVisitor(UtilsMixin, ContextAwareVisitor): - ... + def __init__(self, context, results: list[Result]): + super().__init__(context) + UtilsMixin.__init__(self, results) def match_line(pos, line): return pos.start.line == line and pos.end.line == line - - -def extract_pos_from_result(result): - region = result["locations"][0]["physicalLocation"]["region"] - # TODO it may be the case some of these attributes do not exist - return ( - region.get("startLine"), - region["startColumn"], - region.get("endLine") or region.get("startLine"), - region["endColumn"], - ) - - -def match_pos(pos, x): - # needs some leeway because the semgrep and libcst won't exactly match - return ( - pos.start.line == x[0] - and (pos.start.column in (x[1] - 1, x[1])) - and pos.end.line == x[2] - and (pos.end.column in (x[3] - 1, x[3])) - ) diff --git a/src/codemodder/file_context.py b/src/codemodder/file_context.py index 14de36b8..9cd5e403 100644 --- a/src/codemodder/file_context.py +++ b/src/codemodder/file_context.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List +from typing import List from codemodder.change import Change, ChangeSet from codemodder.dependency import Dependency +from codemodder.result import Result from codemodder.utils.timer import Timer @@ -17,7 +18,7 @@ class FileContext: # pylint: disable=too-many-instance-attributes file_path: Path line_exclude: List[int] = field(default_factory=list) line_include: List[int] = field(default_factory=list) - results_by_id: Dict = field(default_factory=dict) + findings: List[Result] = field(default_factory=list) dependencies: set[Dependency] = field(default_factory=set) codemod_changes: List[Change] = field(default_factory=list) results: List[ChangeSet] = field(default_factory=list) diff --git a/src/codemodder/result.py b/src/codemodder/result.py new file mode 100644 index 00000000..4566d7be --- /dev/null +++ b/src/codemodder/result.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +from pathlib import Path + +from .utils.abc_dataclass import ABCDataclass + + +@dataclass +class LineInfo: + line: int + column: int + snippet: str | None + + +@dataclass +class Location(ABCDataclass): + file: Path + start: LineInfo + end: LineInfo + + def match(self, pos): + start_column = self.start.column + end_column = self.end.column + return ( + pos.start.line == self.start.line + and (pos.start.column in (start_column - 1, start_column)) + and pos.end.line == self.end.line + and (pos.end.column in (end_column - 1, end_column)) + ) + + +@dataclass +class Result(ABCDataclass): + rule_id: str + locations: list[Location] + + +class ResultSet(dict[str, list[Result]]): + def add_result(self, result: Result): + self.setdefault(result.rule_id, []).append(result) + + def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]: + return [ + result + for result in self.get(rule_id, []) + if result.locations[0].file == file + ] diff --git a/src/codemodder/sarifs.py b/src/codemodder/sarifs.py index 86da018a..56a38080 100644 --- a/src/codemodder/sarifs.py +++ b/src/codemodder/sarifs.py @@ -1,9 +1,13 @@ -from collections import defaultdict import json -from typing import Union +from pathlib import Path +from typing import Optional +from typing_extensions import Self -def extract_rule_id(result, sarif_run) -> Union[str, None]: +from .result import ResultSet, Result, Location, LineInfo + + +def extract_rule_id(result, sarif_run) -> Optional[str]: if "ruleId" in result: # semgrep preprends the folders into the rule-id, we want the base name only return result["ruleId"].rsplit(".")[-1] @@ -17,28 +21,48 @@ def extract_rule_id(result, sarif_run) -> Union[str, None]: return None -def results_by_path_and_rule_id(sarif_file): - """ - Extract all the results of a sarif file and organize them by id. - """ - with open(sarif_file, "r", encoding="utf-8") as f: - data = json.load(f) - - path_and_ruleid_dict = defaultdict(lambda: defaultdict(list)) - for sarif_run in data["runs"]: - results = sarif_run["results"] - - path_dict = defaultdict(list) - for r in results: - path = r["locations"][0]["physicalLocation"]["artifactLocation"]["uri"] - path_dict.setdefault(path, []).append(r) - - for path in path_dict.keys(): - rule_id_dict = defaultdict(list) - for r in path_dict.get(path): - # semgrep preprends the folders into the rule-id, we want the base name only - rule_id = extract_rule_id(r, sarif_run) - if rule_id: - rule_id_dict.setdefault(rule_id, []).append(r) - path_and_ruleid_dict[path].update(rule_id_dict) - return path_and_ruleid_dict +class SarifLocation(Location): + @classmethod + def from_sarif(cls, sarif_location) -> Self: + artifact_location = sarif_location["physicalLocation"]["artifactLocation"] + file = Path(artifact_location["uri"]) + start = LineInfo( + line=sarif_location["physicalLocation"]["region"]["startLine"], + column=sarif_location["physicalLocation"]["region"]["startColumn"], + snippet=sarif_location["physicalLocation"]["region"]["snippet"]["text"], + ) + end = LineInfo( + line=sarif_location["physicalLocation"]["region"]["endLine"], + column=sarif_location["physicalLocation"]["region"]["endColumn"], + snippet=sarif_location["physicalLocation"]["region"]["snippet"]["text"], + ) + return cls(file=file, start=start, end=end) + + +class SarifResult(Result): + @classmethod + def from_sarif(cls, sarif_result, sarif_run) -> Self: + rule_id = extract_rule_id(sarif_result, sarif_run) + if not rule_id: + raise ValueError("Could not extract rule id from sarif result.") + + locations: list[Location] = [] + for location in sarif_result["locations"]: + artifact_location = SarifLocation.from_sarif(location) + locations.append(artifact_location) + return cls(rule_id=rule_id, locations=locations) + + +class SarifResultSet(ResultSet): + @classmethod + def from_sarif(cls, sarif_file: str | Path) -> Self: + with open(sarif_file, "r", encoding="utf-8") as f: + data = json.load(f) + + result_set = cls() + for sarif_run in data["runs"]: + for result in sarif_run["results"]: + sarif_result = SarifResult.from_sarif(result, sarif_run) + result_set.add_result(sarif_result) + + return result_set diff --git a/src/codemodder/semgrep.py b/src/codemodder/semgrep.py index 1c08ea60..8138cb9e 100644 --- a/src/codemodder/semgrep.py +++ b/src/codemodder/semgrep.py @@ -4,11 +4,14 @@ from typing import Iterable from pathlib import Path from codemodder.context import CodemodExecutionContext -from codemodder.sarifs import results_by_path_and_rule_id +from codemodder.sarifs import SarifResultSet from codemodder.logging import logger -def run(execution_context: CodemodExecutionContext, yaml_files: Iterable[Path]) -> dict: +def run( + execution_context: CodemodExecutionContext, + yaml_files: Iterable[Path], +) -> SarifResultSet: """ Runs Semgrep and outputs a dict with the results organized by rule_id. """ @@ -40,5 +43,5 @@ def run(execution_context: CodemodExecutionContext, yaml_files: Iterable[Path]) stdout=None if execution_context.verbose else subprocess.PIPE, stderr=None if execution_context.verbose else subprocess.PIPE, ) - results = results_by_path_and_rule_id(temp_sarif_file.name) + results = SarifResultSet.from_sarif(temp_sarif_file.name) return results diff --git a/src/codemodder/utils/abc_dataclass.py b/src/codemodder/utils/abc_dataclass.py new file mode 100644 index 00000000..2b581b22 --- /dev/null +++ b/src/codemodder/utils/abc_dataclass.py @@ -0,0 +1,13 @@ +from abc import ABC +from dataclasses import dataclass + + +@dataclass +class ABCDataclass(ABC): + """Inspired by https://stackoverflow.com/a/60669138""" + + def __new__(cls, *args, **kwargs): + del args, kwargs + if cls == ABCDataclass or cls.__bases__[0] == ABCDataclass: + raise TypeError("Cannot instantiate abstract class.") + return super().__new__(cls) diff --git a/src/core_codemods/django_debug_flag_on.py b/src/core_codemods/django_debug_flag_on.py index 0d18c36a..b7d18d1b 100644 --- a/src/core_codemods/django_debug_flag_on.py +++ b/src/core_codemods/django_debug_flag_on.py @@ -45,7 +45,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: # checks if the file we looking is a settings.py file from django's default directory structure if is_django_settings_file(self.file_context.file_path): debug_flag_transformer = DebugFlagTransformer( - self.context, self.file_context, self._results + self.context, self.file_context, self.file_context.findings ) new_tree = debug_flag_transformer.transform_module(tree) if debug_flag_transformer.changes_in_file: diff --git a/src/core_codemods/django_session_cookie_secure_off.py b/src/core_codemods/django_session_cookie_secure_off.py index c5abdf69..c6c3a0c7 100644 --- a/src/core_codemods/django_session_cookie_secure_off.py +++ b/src/core_codemods/django_session_cookie_secure_off.py @@ -44,7 +44,7 @@ def __init__(self, codemod_context: CodemodContext, *args): def transform_module_impl(self, tree: cst.Module) -> cst.Module: if is_django_settings_file(self.file_context.file_path): transformer = SessionCookieSecureTransformer( - self.context, self.file_context, self._results + self.context, self.file_context, self.file_context.findings ) new_tree = transformer.transform_module(tree) if transformer.changes_in_file: diff --git a/src/core_codemods/sql_parameterization.py b/src/core_codemods/sql_parameterization.py index f94ca0ea..54be3829 100644 --- a/src/core_codemods/sql_parameterization.py +++ b/src/core_codemods/sql_parameterization.py @@ -67,7 +67,7 @@ def __init__( cst.CSTNode, cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel ] = {} BaseCodemod.__init__(self, file_context, *codemod_args) - UtilsMixin.__init__(self, context, {}) + UtilsMixin.__init__(self, []) Codemod.__init__(self, context) def _build_param_element(self, middle, index: int) -> cst.BaseExpression: diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index 45c078be..4c9bf218 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -66,7 +66,9 @@ def __init__(self, codemod_context: CodemodContext, *args): def transform_module_impl(self, tree: cst.Module) -> cst.Module: # we first gather all the nodes we want to change together with their replacements find_requests_visitor = FindRequestCallsAndImports( - self.context, self.file_context, self._results + self.context, + self.file_context, + self.file_context.findings, ) tree.visit(find_requests_visitor) if find_requests_visitor.nodes_to_change: diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 136687f7..ecef8479 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -1,5 +1,4 @@ # pylint: disable=no-member,not-callable,attribute-defined-outside-init -from collections import defaultdict import os from pathlib import Path from textwrap import dedent @@ -40,7 +39,7 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): file_path, [], [], - defaultdict(list), + [], ) wrapper = cst.MetadataWrapper(input_tree) command_instance = self.codemod( @@ -85,7 +84,12 @@ def run_and_assert_filepath(self, root, file_path, input_code, expected): ) input_tree = cst.parse_module(input_code) all_results = self.results_by_id_filepath(input_code, file_path) - results = all_results[str(file_path)] + results = [ + result + for entry in all_results.values() + for result in entry + if result.rule_id == self.codemod.name() + ] self.file_context = FileContext( root, file_path, diff --git a/tests/test_sarif_processing.py b/tests/test_sarif_processing.py index fe1f7a2b..6f837c9a 100644 --- a/tests/test_sarif_processing.py +++ b/tests/test_sarif_processing.py @@ -1,5 +1,5 @@ from codemodder.sarifs import extract_rule_id -from codemodder.sarifs import results_by_path_and_rule_id +from codemodder.sarifs import SarifResultSet from pathlib import Path import subprocess import json @@ -28,16 +28,15 @@ def test_extract_rule_id_semgrep(self): rule_id = extract_rule_id(result, sarif_run) assert rule_id == "secure-random" - def test_results_by_path_and_rule_id(self): + def test_results_by_rule_id(self): sarif_file = Path("tests") / "samples" / "semgrep.sarif" - results = results_by_path_and_rule_id(sarif_file) - expected_path = "tests/samples/insecure_random.py" - assert list(results.keys()) == [expected_path] - + results = SarifResultSet.from_sarif(sarif_file) expected_rule = "secure-random" - rule_in_path = next(iter(results[expected_path].keys())) - assert expected_rule == rule_in_path + assert list(results.keys()) == [expected_rule] + + expected_path = Path("tests/samples/insecure_random.py") + assert expected_path == results[expected_rule][0].locations[0].file def test_codeql_sarif_input(self, tmpdir): completed_process = subprocess.run(