From 4f81d659fa28a2812bb7dc776ce56c4765e34853 Mon Sep 17 00:00:00 2001 From: Florian Cafiero Date: Wed, 18 Dec 2024 15:41:29 +0100 Subject: [PATCH] Update test_train_svm.py Test for the rolling plot fonction --- tests/test_train_svm.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_train_svm.py b/tests/test_train_svm.py index e5335e96..23473f76 100644 --- a/tests/test_train_svm.py +++ b/tests/test_train_svm.py @@ -2,6 +2,7 @@ import superstyl import os import pandas +import tempfile THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -102,4 +103,26 @@ def test_train_svm(self): # This is only the first minimal tests for this function + def test_plot_rolling_stylometry(self): + train = pandas.DataFrame({ + 'author': {'Text_0-1000': 'A', 'Text_1000-2000': 'A', 'Text_2000-3000': 'B'}, + 'lang': {'Text_0-1000': 'NA', 'Text_1000-2000': 'NA', 'Text_2000-3000': 'NA'}, + 'word1': {'Text_0-1000': 0.5, 'Text_1000-2000': 0.5, 'Text_2000-3000': 0.0}, + 'word2': {'Text_0-1000': 0.0, 'Text_1000-2000': 0.5, 'Text_2000-3000': 0.5}, + 'word3': {'Text_0-1000': 0.5, 'Text_1000-2000': 0.0, 'Text_2000-3000': 0.5} + }) + test = train.copy() + results = superstyl.train_svm(train, test, final_pred=True) + final_preds = results["final_predictions"] + + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmpfile: + temp_path = tmpfile.name + final_preds.to_csv(temp_path, index=False) + try: + superstyl.plot_rolling_stylometry(temp_path, smoothing=3) + except Exception as e: + self.fail(f"plot_rolling_stylometry raised an exception: {e}") + finally: + if os.path.exists(temp_path): + os.remove(temp_path)