-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_color_bar_classifier.py
29 lines (26 loc) · 1.12 KB
/
train_color_bar_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import argparse
from pathlib import Path
from pprint import pprint
from src.utils.training_config import TrainingConfig
from src.utils.color_bar_model_trainer import ColorBarModelTrainer
from src.utils.common_utils import log
def parse_args(argv=None):
    """Parse command-line options for the color bar classifier trainer.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. normal command-line use.

    Returns:
        argparse.Namespace with a ``config`` attribute holding the path to the
        training config (defaults to ``fixtures/test_config.json``).
    """
    parser = argparse.ArgumentParser(description='Train color bar classifier.')
    parser.add_argument('-c', '--config', dest='config', type=str,
                        default='fixtures/test_config.json',
                        help='Path to training config')
    return parser.parse_args(argv)


def main(argv=None):
    """Train, validate and test the color bar classifier end to end.

    Steps, in order: initialise the trainer from the config path, train,
    evaluate on the validation set and log the misclassified samples, run the
    offline test, print and write the test results, and log the misclassified
    test samples.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`; ``None``
            means use the real command line.
    """
    args = parse_args(argv)

    train_classifier = ColorBarModelTrainer()
    train_classifier.init_system(args.config)

    log('Training model')
    train_classifier.train_model()

    log('Validation Evaluation')
    train_classifier.validation_evaluation()
    # Presumably a CSV-formatted string (per the method name). Put it on its
    # own line so it isn't glued to the label; this matches the test log below.
    validation_incorrect = train_classifier.get_validation_incorrect_results_as_csv()
    log(f"Validation Incorrect Results\n{validation_incorrect}")

    log('Testing model')
    train_classifier.offline_test()
    pprint(train_classifier.get_test_results())
    train_classifier.write_test_results()
    log(f'Test Incorrect Results\n{train_classifier.get_test_incorrect_results_as_csv()}')


if __name__ == "__main__":
    main()