Skip to content

Commit

Permalink
Working on #18 : save AnomalyDetectionAlgorithm instance as json, fix…
Browse files Browse the repository at this point in the history
… 🐛 in histogram reading from json input
  • Loading branch information
sam-may committed Feb 4, 2022
1 parent 854f5dd commit b84b107
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 10 deletions.
11 changes: 11 additions & 0 deletions autodqm_ml/algorithms/anomaly_detection_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas
import numpy
import awkward
import json

from autodqm_ml import utils
from autodqm_ml.data_formats.histogram import Histogram
Expand Down Expand Up @@ -162,3 +163,13 @@ def save(self):
self.output_file = "%s/%s.parquet" % (self.output_dir, self.input_file.split("/")[-1].replace(".parquet", ""))
logger.info("[AnomalyDetectionAlgorithm : save] Saving output with additional fields to file '%s'." % (self.output_file))
awkward.to_parquet(self.df, self.output_file)

self.config_file = "%s/%s_%s.json" % (self.output_dir, self.name, self.tag)
config = {}
for k,v in vars(self).items():
if utils.is_json_serializable(v):
config[k] = v

logger.info("[AnomalyDetectionAlgorithm : save] Saving AnomalyDetectionAlgorithm config to file '%s'." % (self.config_file))
with open(self.config_file, "w") as f_out:
json.dump(config, f_out, sort_keys = True, indent = 4)
2 changes: 1 addition & 1 deletion autodqm_ml/algorithms/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"early_stopping" : True,
"early_stopping_rounds" : 3,
"n_hidden_layers" : 2,
"n_nodes" : 10,
"n_nodes" : 25,
"n_components" : 3,
"kernel_1d" : 3,
"kernel_2d" : 3,
Expand Down
10 changes: 10 additions & 0 deletions autodqm_ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import copy
import subprocess
import json

import logging
from rich.logging import RichHandler
Expand Down Expand Up @@ -146,3 +147,12 @@ def check_proxy():
return proxy


def is_json_serializable(x):
"""
Returns True if `x` is json serializable, False if not
"""
try:
json.dumps(x)
return True
except:
return False
27 changes: 18 additions & 9 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,29 @@
from autodqm_ml.utils import expand_path

parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument(
"--algorithm",
help = "name of algorithm ('PCA' or 'Autoencoder' or 'StatisticalTester') to train with default options OR path to json filed specifying particular options for training a given algorithm.",
type = str,
required = True
)

# Optional arguments
parser.add_argument(
"--output_dir",
help = "output directory to place files in",
type = str,
required = False,
default = "output"
default = None
)
parser.add_argument(
"--tag",
help = "tag to identify output files",
type = str,
required = False,
default = "test"
)
parser.add_argument(
"--algorithm",
help = "name of algorithm ('PCA' or 'Autoencoder' or 'StatisticalTester') to train with default options OR path to json filed specifying particular options for training a given algorithm.",
type = str,
required = True
default = None
)
parser.add_argument(
"--input_file",
Expand Down Expand Up @@ -126,8 +130,13 @@

if args.histograms is not None:
histograms = {x : { "normalize" : True} for x in args.histograms.split(",")}
else:
elif isinstance(config["histograms"], str):
histograms = {x : { "normalize" : True} for x in config["histograms"].split(",")}
elif isinstance(config["histograms"], dict):
histograms = config["histograms"]
else:
logger.exception("[train.py] The `histograms` argument should either be a csv list of histogram names (str) or a dictionary (if provided through a json config).")
raise RuntimeError()

# Load data
algorithm.load_data(
Expand Down

0 comments on commit b84b107

Please sign in to comment.