From ba37aa46bc247496251fa32506d95e55be0569de Mon Sep 17 00:00:00 2001 From: Ahmed Besbes Date: Fri, 1 Nov 2019 15:33:14 +0100 Subject: [PATCH] add log_f1 argument --- train.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index fbb5548..49867a6 100644 --- a/train.py +++ b/train.py @@ -78,13 +78,6 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi lr = optimizer.state_dict()["param_groups"][0]["lr"] if (iter % print_every == 0) and (iter > 0): - intermediate_report = classification_report( - y_true, y_pred, output_dict=True) - - f1_by_class = 'F1 Scores by class: ' - for class_name in class_names: - f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |" - print("[Training - Epoch: {}], LR: {} , Iteration: {}/{} , Loss: {}, Accuracy: {}".format( epoch + 1, lr, @@ -93,7 +86,16 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi losses.avg, accuracies.avg )) - print(f1_by_class) + + if bool(args.log_f1): + intermediate_report = classification_report( + y_true, y_pred, output_dict=True) + + f1_by_class = 'F1 Scores by class: ' + for class_name in class_names: + f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |" + + print(f1_by_class) f1_train = f1_score(y_true, y_pred, average='weighted') @@ -403,6 +405,7 @@ def run(args, both_cases=False): parser.add_argument('--workers', type=int, default=1) parser.add_argument('--log_path', type=str, default='./logs/') parser.add_argument('--log_every', type=int, default=100) + parser.add_argument('--log_f1', type=int, default=1, choices=[0, 1]) parser.add_argument('--flush_history', type=int, default=1, choices=[0, 1]) parser.add_argument('--output', type=str, default='./models/')