-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
26 lines (22 loc) · 842 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# This is the main file that runs the experiment and recreates plot 1b
import json
from pathlib import Path

from src.baseline_algorithms import MaxEntAllBaseline, RandomAllBaseline
from src.classifier import LogisticRegressionClassifier
from src.plotter import Plotter
from src.seals import SEALSAlgorithm
from src.selection_strategy import MaxEntropySelectionStrategy
# Default experiment settings (identical to the values previously hard-coded
# in the ``__main__`` block).
NUM_CLASSES = 30
REPETITIONS = 3
RESULTS_PATH = Path("data") / "results.json"


def main(
    num_classes: int = NUM_CLASSES,
    repetitions: int = REPETITIONS,
    results_path: Path = RESULTS_PATH,
) -> None:
    """Run the SEALS experiment, save the scores as JSON, and plot them.

    Args:
        num_classes: Forwarded to ``SEALSAlgorithm`` as ``num_classes``.
        repetitions: Forwarded to ``SEALSAlgorithm.run``.
        results_path: File the scores are written to via ``json.dump``.
            Accepts a ``str`` or ``Path``; its parent directory is created
            if it does not already exist.
    """
    # Create the output directory before the run rather than after it:
    # previously a missing ``data/`` directory made ``open`` raise
    # FileNotFoundError only once all repetitions had finished, discarding
    # the computed scores.
    results_path = Path(results_path)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    classifier = LogisticRegressionClassifier()
    selection = MaxEntropySelectionStrategy()
    baselines = [MaxEntAllBaseline(), RandomAllBaseline()]
    seals = SEALSAlgorithm(
        classifier,
        selection,
        num_classes=num_classes,
        random_classes=False,
        baseline_algorithms=baselines,
    )
    scores = seals.run(repetitions=repetitions)

    # Explicit encoding so the JSON output does not depend on the platform
    # default locale.
    with results_path.open("w", encoding="utf-8") as fp:
        json.dump(scores, fp)

    Plotter.create_plots(scores)


if __name__ == "__main__":
    main()