diff --git a/bin/train_and_predict_final.py b/bin/train_and_predict_final.py index a590391..f5f037d 100755 --- a/bin/train_and_predict_final.py +++ b/bin/train_and_predict_final.py @@ -209,7 +209,7 @@ def compute_cross( predictions_path, f"predictions_{args.split_id}.csv", ) - test_set.save(prediction_dataset) + test_set.to_csv(prediction_dataset) for ds in args.cross_study_datasets: if ds == "NONE.csv": continue