Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kjysmu authored Nov 3, 2023
1 parent 6d3521f commit 030cfc9
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@

from utilities.run_model_vevo import train_epoch, eval_model

# CSV_HEADER = ["Epoch", "Learn rate", "Avg Train loss", "Train Accuracy", "Avg Eval loss", "Eval accuracy"]
CSV_HEADER = ["Epoch", "Learn rate",
"Avg Train loss (total)", "Avg Train loss (chord)", "Avg Train loss (emotion)", "Train Accuracy",
"Avg Eval loss (total)", "Avg Eval loss (chord)", "Avg Eval loss (emotion)", "Eval accuracy"]

# Baseline is an untrained epoch that we evaluate as a baseline loss and accuracy
BASELINE_EPOCH = -1

version = VERSION
Expand All @@ -44,7 +42,6 @@ def main( vm = "" , isPrintArgs = True ):
if isPrintArgs:
print_train_args(args)
if vm != "":
#VIS_MODELS = vm
args.vis_models = vm

if args.is_video:
Expand Down Expand Up @@ -182,7 +179,6 @@ def main( vm = "" , isPrintArgs = True ):

##### TRAIN LOOP #####
for epoch in range(start_epoch, args.epochs):
# Baseline has no training and acts as a base loss and accuracy (epoch 0 in a sense)
if(epoch > BASELINE_EPOCH):
print(SEPERATOR)
print("NEW EPOCH:", epoch+1)
Expand Down Expand Up @@ -224,7 +220,6 @@ def main( vm = "" , isPrintArgs = True ):
eval_h3 = eval_metric_dict["avg_h3"]
eval_h5 = eval_metric_dict["avg_h5"]

# Learn rate
lr = get_lr(opt)

print("Epoch:", epoch+1)
Expand Down

0 comments on commit 030cfc9

Please sign in to comment.