From 579be9fd61759a1948d52cb4f0fea8875e2a9a9c Mon Sep 17 00:00:00 2001 From: Florian Cafiero Date: Wed, 18 Dec 2024 08:41:47 +0100 Subject: [PATCH] Update train_svm.py --- train_svm.py | 47 +++++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/train_svm.py b/train_svm.py index 119569de..ea89c054 100755 --- a/train_svm.py +++ b/train_svm.py @@ -14,41 +14,38 @@ default=None) parser.add_argument('--cross_validate', action='store', help="perform cross validation (test_path will be used only for final prediction)." - "If group_k-fold is chosen, each source file will be considered a group " + " If group_k-fold is chosen, each source file will be considered a group " "(only relevant if sampling was performed and more than one file per class was provided)", default=None, choices=['leave-one-out', 'k-fold', 'group-k-fold'], type=str) parser.add_argument('--k', action='store', help="k for k-fold", default=10, type=int) - parser.add_argument('--dim_reduc', action='store', choices=['pca'], help="optional dimensionality " - "reduction of input data", default=None) + parser.add_argument('--dim_reduc', action='store', choices=['pca'], + help="optional dimensionality reduction of input data", default=None) parser.add_argument('--norms', action='store_true', help="perform normalisations? (default: True)", default=True) parser.add_argument('--balance', action='store', choices=["downsampling", "Tomek", "upsampling", "SMOTE", "SMOTETomek"], - help="which " - "strategy to use in case of imbalanced datasets: " - "downsampling (random without replacement), " - "Tomek (downs. by removing Tomek links), " - # "ENN (EditedNearestNeighbours, downs. by removing samples close to the decision boundary), " - "upsampling (random over sampling with replacement), " - "SMOTE (upsampling with SMOTE - Synthetic Minority Over-sampling Technique), " - "SMOTETomek (over+undersampling with SMOTE+Tomek)", + help="which strategy to use in case of imbalanced datasets: " + "downsampling, Tomek, upsampling, SMOTE, SMOTETomek", default=None) parser.add_argument('--class_weights', action='store_true', - help="whether to use class weights in imbalanced datasets " - "(inversely proportional to total class samples)", - default=False - ) + help="use class weights in imbalanced datasets", + default=False) parser.add_argument('--kernel', action='store', - help="type of kernel to use (default and recommended choice is LinearSVC; " - "possible alternatives are linear, sigmoid, rbf and poly, as per sklearn.svm.SVC)", - default="LinearSVC", choices=['LinearSVC', 'linear', 'sigmoid', 'rbf', 'poly'], type=str) + help="type of kernel to use (default LinearSVC)", default="LinearSVC", + choices=['LinearSVC', 'linear', 'sigmoid', 'rbf', 'poly'], type=str) parser.add_argument('--final', action='store_true', help="final analysis on unknown dataset (no evaluation)?", default=False) parser.add_argument('--get_coefs', action='store_true', - help="switch to write to disk and plot the most important coefficients" - " for the training feats for each class", + help="write to disk and plot most important coefficients", default=False) + # My new arguments for rolling stylometry plotting + parser.add_argument('--plot_rolling', action='store_true', + help="If final predictions are produced, also plot rolling stylometry.") + parser.add_argument('--plot_smoothing', action='store', type=int, default=3, + help="Smoothing window size for rolling stylometry plot (default:3)." + "Set to 0 or None to disable smoothing.") + args = parser.parse_args() print(".......... loading data ........") @@ -56,7 +53,6 @@ if args.test_path is not None: test = pandas.read_csv(args.test_path, index_col=0) - else: test = None @@ -69,7 +65,7 @@ else: args.o = '' - + # Save confusion matrix and misattributions if we have a cross-validate scenario or a test but not final if args.cross_validate is not None or (args.test_path is not None and not args.final): svm["confusion_matrix"].to_csv(args.o+"confusion_matrix.csv") svm["misattributions"].to_csv(args.o+"misattributions.csv") @@ -80,6 +76,13 @@ print(".......... Writing final predictions to " + args.o + "FINAL_PREDICTIONS.csv ........") svm["final_predictions"].to_csv(args.o+"FINAL_PREDICTIONS.csv") + # To get plot for rolling stylometry + if args.plot_rolling: + print(".......... Plotting rolling stylometry ........") + final_pred_path = args.o+"FINAL_PREDICTIONS.csv" + smoothing = args.plot_smoothing if args.plot_smoothing is not None else 0 + superstyl.svm.plot_rolling_stylometry(final_pred_path, smoothing=smoothing) + if args.get_coefs: print(".......... Writing coefficients to disk ........") svm["coefficients"].to_csv(args.o+"coefficients.csv")