Skip to content

Commit

Permalink
Refactor: semgrep results use specific datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Nov 6, 2023
1 parent f20e61f commit bd3e1d4
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 105 deletions.
20 changes: 12 additions & 8 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from codemodder.executor import CodemodExecutorWrapper
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
from codemodder.report.codetf_reporter import report_default
from codemodder.result import ResultSet
from codemodder.semgrep import run as run_semgrep


Expand All @@ -45,7 +46,7 @@ def find_semgrep_results(
return set()

results = run_semgrep(context, yaml_files)
return {rule_id for file_changes in results.values() for rule_id in file_changes}
return set(results.keys())


def create_diff(original_tree: cst.Module, new_tree: cst.Module) -> str:
Expand Down Expand Up @@ -103,7 +104,7 @@ def process_file(
file_path: Path,
base_directory: Path,
codemod,
sarif,
results: ResultSet,
cli_args,
): # pylint: disable=too-many-arguments
logger.debug("scanning file %s", file_path)
Expand All @@ -112,14 +113,17 @@ def process_file(

line_exclude = file_line_patterns(file_path, cli_args.path_exclude)
line_include = file_line_patterns(file_path, cli_args.path_include)
sarif_for_file = sarif.get(str(file_path)) or {}
findings_for_rule = results.results_for_rule_and_file(
codemod.name, # TODO: should be full ID
file_path,
)

file_context = FileContext(
base_directory,
file_path,
line_exclude,
line_include,
sarif_for_file,
findings_for_rule,
)

try:
Expand All @@ -146,27 +150,27 @@ def analyze_files(
execution_context: CodemodExecutionContext,
files_to_analyze,
codemod,
sarif,
results: ResultSet,
cli_args,
):
with ThreadPoolExecutor(max_workers=cli_args.max_workers) as executor:
logger.debug(
"using executor with %s threads",
cli_args.max_workers,
)
results = executor.map(
analysis_results = executor.map(
lambda args: process_file(
args[0],
args[1],
execution_context.directory,
codemod,
sarif,
results,
cli_args,
),
enumerate(files_to_analyze),
)
executor.shutdown(wait=True)
execution_context.process_results(codemod.id, results)
execution_context.process_results(codemod.id, analysis_results)


def run(original_args) -> int:
Expand Down
2 changes: 1 addition & 1 deletion src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init_subclass__(cls):
def __init__(self, codemod_context: CodemodContext, file_context: FileContext):
BaseCodemod.__init__(self, codemod_context, file_context)
_SemgrepCodemod.__init__(self, file_context)
BaseTransformer.__init__(self, codemod_context, self._results)
BaseTransformer.__init__(self, codemod_context, file_context.findings)

def _new_or_updated_node(self, original_node, updated_node):
if self.node_is_selected(original_node):
Expand Down
25 changes: 7 additions & 18 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from codemodder.change import Change
from codemodder.dependency import Dependency
from codemodder.file_context import FileContext
from codemodder.result import ResultSet
from codemodder.semgrep import run as semgrep_run


Expand Down Expand Up @@ -51,14 +52,14 @@ def __init__(self, file_context: FileContext):
self.file_context = file_context

@classmethod
def apply_rule(cls, context, *args, **kwargs):
def apply_rule(cls, context, *args, **kwargs) -> ResultSet:
"""
Apply rule associated with this codemod and gather results
Does nothing by default. Subclasses may override for custom rule logic.
"""
del context, args, kwargs
return {}
return ResultSet()

@classmethod
def name(cls):
Expand Down Expand Up @@ -105,28 +106,16 @@ class SemgrepCodemod(BaseCodemod):
YAML_FILES: ClassVar[List[str]] = NotImplemented
is_semgrep = True

def __init__(self, *args):
super().__init__(*args)
self._results = (
self.file_context.results_by_id.get(
self.METADATA.NAME # pylint: disable=no-member
)
or []
)

@classmethod
def apply_rule(
cls, context, yaml_files, *args, **kwargs
): # pylint: disable=arguments-differ
def apply_rule(cls, context, *args, **kwargs) -> ResultSet:
"""
Apply semgrep to gather rule results
"""
del args, kwargs
yaml_files = kwargs.get("yaml_files") or args[0]
with context.timer.measure("semgrep"):
return semgrep_run(context, yaml_files)

@property
def should_transform(self):
"""Semgrep codemods should attempt transform only if there are
semgrep results"""
return bool(self._results)
"""Semgrep codemods should attempt transform only if there are semgrep results"""
return bool(self.file_context.findings)
52 changes: 22 additions & 30 deletions src/codemodder/codemods/base_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
from libcst.codemod import ContextAwareVisitor, VisitorBasedCodemodCommand
from libcst.metadata import PositionProvider

from codemodder.result import Result


class UtilsMixin:
METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,)
results: list[Result]

def __init__(self, context, results):
super().__init__(context)
def __init__(self, results: list[Result]):
self.results = results

def filter_by_result(self, pos_to_match):
all_pos = [extract_pos_from_result(result) for result in self.results]
return any(match_pos(pos_to_match, position) for position in all_pos)
return any(
location.match(pos_to_match)
for result in self.results
for location in result.locations
)

def filter_by_path_includes_or_excludes(self, pos_to_match):
"""
Expand All @@ -28,6 +32,7 @@ def filter_by_path_includes_or_excludes(self, pos_to_match):
def node_is_selected(self, node) -> bool:
if not self.results:
return False

pos_to_match = self.node_position(node)
return self.filter_by_result(
pos_to_match
Expand All @@ -41,34 +46,21 @@ def lineno_for_node(self, node):
return self.node_position(node).start.line


class BaseTransformer(UtilsMixin, VisitorBasedCodemodCommand):
...
class BaseTransformer(VisitorBasedCodemodCommand, UtilsMixin):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,)

def __init__(self, context, results: list[Result]):
super().__init__(context)
UtilsMixin.__init__(self, results)


class BaseVisitor(ContextAwareVisitor, UtilsMixin):
METADATA_DEPENDENCIES: Tuple[Any, ...] = (PositionProvider,)

class BaseVisitor(UtilsMixin, ContextAwareVisitor):
...
def __init__(self, context, results: list[Result]):
super().__init__(context)
UtilsMixin.__init__(self, results)


def match_line(pos, line):
return pos.start.line == line and pos.end.line == line


def extract_pos_from_result(result):
region = result["locations"][0]["physicalLocation"]["region"]
# TODO it may be the case some of these attributes do not exist
return (
region.get("startLine"),
region["startColumn"],
region.get("endLine") or region.get("startLine"),
region["endColumn"],
)


def match_pos(pos, x):
# needs some leeway because the semgrep and libcst won't exactly match
return (
pos.start.line == x[0]
and (pos.start.column in (x[1] - 1, x[1]))
and pos.end.line == x[2]
and (pos.end.column in (x[3] - 1, x[3]))
)
5 changes: 3 additions & 2 deletions src/codemodder/file_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List
from typing import List

from codemodder.change import Change, ChangeSet
from codemodder.dependency import Dependency
from codemodder.result import Result
from codemodder.utils.timer import Timer


Expand All @@ -17,7 +18,7 @@ class FileContext: # pylint: disable=too-many-instance-attributes
file_path: Path
line_exclude: List[int] = field(default_factory=list)
line_include: List[int] = field(default_factory=list)
results_by_id: Dict = field(default_factory=dict)
findings: List[Result] = field(default_factory=list)
dependencies: set[Dependency] = field(default_factory=set)
codemod_changes: List[Change] = field(default_factory=list)
results: List[ChangeSet] = field(default_factory=list)
Expand Down
46 changes: 46 additions & 0 deletions src/codemodder/result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from pathlib import Path

from .utils.abc_dataclass import ABCDataclass


@dataclass
class LineInfo:
line: int
column: int
snippet: str | None


@dataclass
class Location(ABCDataclass):
file: Path
start: LineInfo
end: LineInfo

def match(self, pos):
start_column = self.start.column
end_column = self.end.column
return (
pos.start.line == self.start.line
and (pos.start.column in (start_column - 1, start_column))
and pos.end.line == self.end.line
and (pos.end.column in (end_column - 1, end_column))
)


@dataclass
class Result(ABCDataclass):
rule_id: str
locations: list[Location]


class ResultSet(dict[str, list[Result]]):
def add_result(self, result: Result):
self.setdefault(result.rule_id, []).append(result)

def results_for_rule_and_file(self, rule_id: str, file: Path) -> list[Result]:
return [
result
for result in self.get(rule_id, [])
if result.locations[0].file == file
]
80 changes: 52 additions & 28 deletions src/codemodder/sarifs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections import defaultdict
import json
from typing import Union
from pathlib import Path
from typing import Optional

from typing_extensions import Self

def extract_rule_id(result, sarif_run) -> Union[str, None]:
from .result import ResultSet, Result, Location, LineInfo


def extract_rule_id(result, sarif_run) -> Optional[str]:
if "ruleId" in result:
# semgrep preprends the folders into the rule-id, we want the base name only
return result["ruleId"].rsplit(".")[-1]
Expand All @@ -17,28 +21,48 @@ def extract_rule_id(result, sarif_run) -> Union[str, None]:
return None


def results_by_path_and_rule_id(sarif_file):
"""
Extract all the results of a sarif file and organize them by id.
"""
with open(sarif_file, "r", encoding="utf-8") as f:
data = json.load(f)

path_and_ruleid_dict = defaultdict(lambda: defaultdict(list))
for sarif_run in data["runs"]:
results = sarif_run["results"]

path_dict = defaultdict(list)
for r in results:
path = r["locations"][0]["physicalLocation"]["artifactLocation"]["uri"]
path_dict.setdefault(path, []).append(r)

for path in path_dict.keys():
rule_id_dict = defaultdict(list)
for r in path_dict.get(path):
# semgrep preprends the folders into the rule-id, we want the base name only
rule_id = extract_rule_id(r, sarif_run)
if rule_id:
rule_id_dict.setdefault(rule_id, []).append(r)
path_and_ruleid_dict[path].update(rule_id_dict)
return path_and_ruleid_dict
class SarifLocation(Location):
@classmethod
def from_sarif(cls, sarif_location) -> Self:
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
file = Path(artifact_location["uri"])
start = LineInfo(
line=sarif_location["physicalLocation"]["region"]["startLine"],
column=sarif_location["physicalLocation"]["region"]["startColumn"],
snippet=sarif_location["physicalLocation"]["region"]["snippet"]["text"],
)
end = LineInfo(
line=sarif_location["physicalLocation"]["region"]["endLine"],
column=sarif_location["physicalLocation"]["region"]["endColumn"],
snippet=sarif_location["physicalLocation"]["region"]["snippet"]["text"],
)
return cls(file=file, start=start, end=end)


class SarifResult(Result):
@classmethod
def from_sarif(cls, sarif_result, sarif_run) -> Self:
rule_id = extract_rule_id(sarif_result, sarif_run)
if not rule_id:
raise ValueError("Could not extract rule id from sarif result.")

locations: list[Location] = []
for location in sarif_result["locations"]:
artifact_location = SarifLocation.from_sarif(location)
locations.append(artifact_location)
return cls(rule_id=rule_id, locations=locations)


class SarifResultSet(ResultSet):
@classmethod
def from_sarif(cls, sarif_file: str | Path) -> Self:
with open(sarif_file, "r", encoding="utf-8") as f:
data = json.load(f)

result_set = cls()
for sarif_run in data["runs"]:
for result in sarif_run["results"]:
sarif_result = SarifResult.from_sarif(result, sarif_run)
result_set.add_result(sarif_result)

return result_set
Loading

0 comments on commit bd3e1d4

Please sign in to comment.