From 75f797f6f06410a5f86b1b6b7899b96f37e8b9c0 Mon Sep 17 00:00:00 2001 From: Ankur Chattopadhyay <39518771+chttrjeankr@users.noreply.github.com> Date: Mon, 6 Apr 2020 13:52:51 +0530 Subject: [PATCH] unit tests for utilities.py (#119) --- test_utilities.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++ utilities.py | 8 ++++--- 2 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 test_utilities.py diff --git a/test_utilities.py b/test_utilities.py new file mode 100644 index 0000000..f462f71 --- /dev/null +++ b/test_utilities.py @@ -0,0 +1,57 @@ +import utilities + + +class TestClass: + test_input = "The quick brown fox jumps over the lazy dog." + clf = utilities.classify_model() + + def test_setup_nltk(self): + result = utilities.setup_nltk() + assert result + + def test_parse_sentence(self): + triples, root = utilities.parse_sentence(self.test_input) + triples = list(triples) + assert (("jumps", "VBZ"), "nsubj", ("fox", "NN")) in triples + assert (("jumps", "VBZ"), "nmod", ("dog", "NN")) in triples + assert root == "jumps" + + def test_classify_model(self): + from features import features_dict + import hashlib + import numpy as np + + keys = [ + "id", + "wordCount", + "stemmedCount", + "stemmedEndNN", + "CD", + "NN", + "NNP", + "NNPS", + "NNS", + "PRP", + "VBG", + "VBZ", + "startTuple0", + "endTuple0", + "endTuple1", + "endTuple2", + "verbBeforeNoun", + "qMark", + "qVerbCombo", + "qTripleScore", + "sTripleScore", + "class", + ] + id = hashlib.md5(str(self.test_input).encode("utf-8")).hexdigest()[:16] + f = features_dict(id, self.test_input) + features = [f[k] for k in keys][1:-1] + features = np.array(features).reshape(1, -1) + + assert self.clf.predict(features)[0] == "S" + + def test_classify_sentence(self): + result = utilities.classify_sentence(self.clf, self.test_input) + assert result == "S" diff --git a/utilities.py b/utilities.py index de94347..a35a6b3 100644 --- a/utilities.py +++ b/utilities.py @@ -10,9 +10,11 @@ def setup_nltk(): import nltk - nltk.download("punkt") - nltk.download("averaged_perceptron_tagger") - nltk.download("stopwords") + punkt = nltk.download("punkt") + averaged_perceptron_tagger = nltk.download("averaged_perceptron_tagger") + stopwords = nltk.download("stopwords") + + return all((punkt, averaged_perceptron_tagger, stopwords)) @logger_config.logger