diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index afb5b2d..a5f976c 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -41,11 +41,17 @@ def scan( ) return None - return self.label_results(self._scan_keras_h5_file(source)) + results = self._scan_keras_h5_file(source) + if results: + return self.label_results(results) + else: + return None - def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: + def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]: machine_learning_library_name = "Keras" operators_in_model = self._get_keras_h5_operator_names(source) + if not operators_in_model: + return None return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, @@ -55,11 +61,15 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: ]["unsafe_keras_operators"], ) - def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: + def _get_keras_h5_operator_names( + self, source: Union[str, Path] + ) -> Optional[List[str]]: # Todo: source isn't guaranteed to be a file with h5py.File(source, "r") as model_hdf5: try: + if not "model_config" in model_hdf5.attrs.keys(): + return None model_config = json.loads(model_hdf5.attrs.get("model_config", {})) layers = model_config.get("config", {}).get("layers", {}) lambda_layers = []