Skip to content

Commit

Permalink
Update test_train_svm.py
Browse files Browse the repository at this point in the history
Test for the rolling plot fonction
  • Loading branch information
floriancafiero authored Dec 18, 2024
1 parent aaef4dc commit 4f81d65
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/test_train_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import superstyl
import os
import pandas
import tempfile

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -102,4 +103,26 @@ def test_train_svm(self):

# This is only the first minimal tests for this function

def test_plot_rolling_stylometry(self):
train = pandas.DataFrame({
'author': {'Text_0-1000': 'A', 'Text_1000-2000': 'A', 'Text_2000-3000': 'B'},
'lang': {'Text_0-1000': 'NA', 'Text_1000-2000': 'NA', 'Text_2000-3000': 'NA'},
'word1': {'Text_0-1000': 0.5, 'Text_1000-2000': 0.5, 'Text_2000-3000': 0.0},
'word2': {'Text_0-1000': 0.0, 'Text_1000-2000': 0.5, 'Text_2000-3000': 0.5},
'word3': {'Text_0-1000': 0.5, 'Text_1000-2000': 0.0, 'Text_2000-3000': 0.5}
})
test = train.copy()
results = superstyl.train_svm(train, test, final_pred=True)
final_preds = results["final_predictions"]

with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmpfile:
temp_path = tmpfile.name
final_preds.to_csv(temp_path, index=False)

try:
superstyl.plot_rolling_stylometry(temp_path, smoothing=3)
except Exception as e:
self.fail(f"plot_rolling_stylometry raised an exception: {e}")
finally:
if os.path.exists(temp_path):
os.remove(temp_path)

0 comments on commit 4f81d65

Please sign in to comment.