Skip to content

Commit

Permalink
Merge branch 'main' into update-notebooks-settingsfile
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Feb 14, 2024
2 parents 79f6c5b + fd71c7b commit e9542e0
Show file tree
Hide file tree
Showing 9 changed files with 573 additions and 112 deletions.
2 changes: 1 addition & 1 deletion modelscan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def scan(
if not pathlibPath.exists():
raise FileNotFoundError(f"Path {path} does not exist")
else:
modelscan.scan(pathlibPath)
modelscan.scan(path)
else:
raise click.UsageError("Command line must include a path")

Expand Down
13 changes: 11 additions & 2 deletions modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __eq__(self, other: Any) -> bool:
and self.details.module == other.details.module # type: ignore[attr-defined]
and self.details.operator == other.details.operator # type: ignore[attr-defined]
and str(self.details.source) == str(other.details.source) # type: ignore[attr-defined]
and self.details.severity == other.severity # type: ignore[attr-defined]
)

def __repr__(self) -> str:
Expand All @@ -76,6 +77,7 @@ def __hash__(self) -> int:
+ str(self.details.module) # type: ignore[attr-defined]
+ str(self.details.operator) # type: ignore[attr-defined]
+ str(self.details.source) # type: ignore[attr-defined]
+ str(self.details.severity) # type: ignore[attr-defined]
)

def print(self) -> None:
Expand Down Expand Up @@ -121,11 +123,17 @@ def group_by_severity(self) -> Dict[str, List[Issue]]:

class OperatorIssueDetails(IssueDetails):
def __init__(
self, module: str, operator: str, source: Union[Path, str], scanner: str = ""
self,
module: str,
operator: str,
severity: IssueSeverity,
source: Union[Path, str],
scanner: str = "",
) -> None:
self.module = module
self.operator = operator
self.source = source
self.severity = severity
self.scanner = scanner

def output_lines(self) -> List[str]:
Expand All @@ -141,7 +149,8 @@ def output_json(self) -> Dict[str, str]:
"module": f"{self.module}",
"source": f"{str(self.source)}",
"scanner": f"{self.scanner}",
"severity": f"{self.severity.name}",
}

def __repr__(self) -> str:
return f"<OperatorIssueDetails(module={self.module}, operator={self.operator}, source={str(self.source)})>"
return f"<OperatorIssueDetails(module={self.module}, operator={self.operator}, severity={self.severity.name}, source={str(self.source)})>"
49 changes: 29 additions & 20 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def scan(
self._skipped = []
self._scanned = []
self._input_path = str(path)

self._scan_path(Path(path))
pathlibPath = Path().cwd() if path == "." else Path(path).absolute()
self._scan_path(Path(pathlibPath))
return self._generate_results()

def _scan_path(
Expand Down Expand Up @@ -138,13 +138,13 @@ def _scan_zip(
def _generate_results(self) -> Dict[str, Any]:
report: Dict[str, Any] = {}

absolute_path = Path(self._input_path).resolve()
if Path(self._input_path).is_file():
absolute_path = Path(absolute_path).parent

issues_by_severity = self._issues.group_by_severity()
total_issue_count = len(self._issues.all_issues)

report["modelscan_version"] = __version__
report["timestamp"] = datetime.now().isoformat()
report["input_path"] = self._input_path
report["total_issues"] = total_issue_count
report["summary"] = {"total_issues_by_severity": {}}
for severity in IssueSeverity:
if severity.name in issues_by_severity:
Expand All @@ -154,24 +154,33 @@ def _generate_results(self) -> Dict[str, Any]:
else:
report["summary"]["total_issues_by_severity"][severity.name] = 0

report["issues_by_severity"] = {}
for issue_key in issues_by_severity.keys():
report["issues_by_severity"][issue_key] = [
issue.details.output_json() for issue in issues_by_severity[issue_key]
]

report["errors"] = [str(error) for index, error in enumerate(self._errors)]

report["scanned"] = {"total_scanned": len(self._scanned)}
report["scanned"]["scanned_files"] = [
str(file_name) for file_name in self._scanned
report["summary"]["total_issues"] = total_issue_count
report["summary"]["input_path"] = str(self._input_path)
report["summary"]["absolute_path"] = str(absolute_path)
report["summary"]["modelscan_version"] = __version__
report["summary"]["timestamp"] = datetime.now().isoformat()
report["summary"]["skipped"] = {"total_skipped": len(self._skipped)}
report["summary"]["skipped"]["skipped_files"] = [
str(Path(file_name).relative_to(Path(absolute_path)))
for file_name in self._skipped
]
report["summary"]["scanned"] = {"total_scanned": len(self._scanned)}
report["summary"]["scanned"]["scanned_files"] = [
str(Path(file_name).relative_to(Path(absolute_path)))
for file_name in self._scanned
]

report["skipped"] = {"total_skipped": len(self._skipped)}
report["skipped"]["skipped_files"] = [
str(file_name) for file_name in self._skipped
report["issues"] = [
issue.details.output_json() for issue in self._issues.all_issues
]

for issue in report["issues"]:
issue["source"] = str(
Path(issue["source"]).relative_to(Path(absolute_path))
)

report["errors"] = [str(error) for index, error in enumerate(self._errors)]

return report

def is_compatible(self, path: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def generate(
) -> None:
report: Dict[str, Any] = scan._generate_results()
if not settings["show_skipped"]:
del report["skipped"]
del report["summary"]["skipped"]

print(json.dumps(report))

Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def scan(
def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_h5_operator_names(source)
if not operators_in_model:
if operators_in_model is None:
return None
return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
Expand Down
5 changes: 4 additions & 1 deletion modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def _check_for_unsafe_tf_keras_operator(
code=IssueCode.UNSAFE_OPERATOR,
severity=severity,
details=OperatorIssueDetails(
module=module_name, operator=op, source=source
module=module_name,
operator=op,
source=source,
severity=severity,
),
)
)
Expand Down
5 changes: 4 additions & 1 deletion modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ def scan_pickle_bytes(
code=IssueCode.UNSAFE_OPERATOR,
severity=severity,
details=OperatorIssueDetails(
module=global_module, operator=global_name, source=source
module=global_module,
operator=global_name,
source=source,
severity=severity,
),
)
)
Expand Down
Loading

0 comments on commit e9542e0

Please sign in to comment.