Skip to content

Commit

Permalink
Add implementation specific tests for Dataset/Model/Task (#205)
Browse files Browse the repository at this point in the history
This commit adds tests that ensure the base class constructor is called when creating any dataset/model/task object, and that it is called with the correct arguments. Also includes minor fixes to existing tests.

* Fix Typo in test class

* Fix escape warning

* Fix missing subtest context

* Add tests for base class constructor calls

* Refactor common code to test.utils
  • Loading branch information
fdalvi authored Sep 7, 2023
1 parent 6e20b54 commit dcf1074
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 17 deletions.
2 changes: 1 addition & 1 deletion llmebench/datasets/WANLP22T3Propaganda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def metadata():
return {
"language": "ar",
"citation": """@inproceedings{alam2022overview,
title={Overview of the $\{$WANLP$\}$ 2022 Shared Task on Propaganda Detection in $\{$A$\}$ rabic},
title={Overview of the $\\{$WANLP$\\}$ 2022 Shared Task on Propaganda Detection in $\\{$A$\\}$ rabic},
author={Alam, Firoj and Mubarak, Hamdy and Zaghouani, Wajdi and Da San Martino, Giovanni and Nakov, Preslav and others},
booktitle={Proceedings of the The Seventh Arabic Natural Language Processing Workshop (WANLP)},
pages={108--118},
Expand Down
25 changes: 25 additions & 0 deletions tests/datasets/test_implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import inspect
import unittest

from pathlib import Path

import llmebench.datasets as datasets

from tests.utils import base_class_constructor_checker


class TestDatasetImplementation(unittest.TestCase):
    """Implementation-level checks applied to every exported dataset class."""

    @classmethod
    def setUpClass(cls):
        """Collect every class exposed by the ``llmebench.datasets`` package."""
        # getmembers() yields (name, object) pairs; only the class objects
        # themselves are needed by the tests.
        cls.datasets = {
            member for _, member in inspect.getmembers(datasets, inspect.isclass)
        }

    def test_base_constructor(self):
        """Test if all datasets also call the base class constructor"""
        for dataset in self.datasets:
            # One subtest per class so a single failure does not hide others.
            with self.subTest(msg=dataset.__name__):
                base_class_constructor_checker(dataset, self)
31 changes: 16 additions & 15 deletions tests/datasets/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ def test_dataset_metadata(self):
"Test if all datasets export the required metadata"

for dataset in self.datasets:
self.assertIsInstance(dataset.metadata(), dict)
self.assertIn("citation", dataset.metadata())
self.assertIsInstance(dataset.metadata()["citation"], str)
self.assertIn("language", dataset.metadata())
self.assertIsInstance(dataset.metadata()["language"], (str, list))

languages = dataset.metadata()["language"]
if isinstance(languages, str):
languages = [languages]

for language in languages:
self.assertTrue(
language == "multilingual" or tag_is_valid(language),
f"{language} is not a valid language",
)
with self.subTest(msg=dataset.__name__):
self.assertIsInstance(dataset.metadata(), dict)
self.assertIn("citation", dataset.metadata())
self.assertIsInstance(dataset.metadata()["citation"], str)
self.assertIn("language", dataset.metadata())
self.assertIsInstance(dataset.metadata()["language"], (str, list))

languages = dataset.metadata()["language"]
if isinstance(languages, str):
languages = [languages]

for language in languages:
self.assertTrue(
language == "multilingual" or tag_is_valid(language),
f"{language} is not a valid language",
)
2 changes: 1 addition & 1 deletion tests/models/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from llmebench.models.model_base import ModelBase


class TestDatasetExports(unittest.TestCase):
class TestModelExports(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Search for all implemented models
Expand Down
23 changes: 23 additions & 0 deletions tests/models/test_implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import inspect
import unittest

from pathlib import Path

import llmebench.models as models

from tests.utils import base_class_constructor_checker


class TestModelImplementation(unittest.TestCase):
    """Implementation-level checks applied to every exported model class."""

    @classmethod
    def setUpClass(cls):
        """Collect every class exposed by the ``llmebench.models`` package."""
        # getmembers() yields (name, object) pairs; only the class objects
        # themselves are needed by the tests.
        cls.models = {
            member for _, member in inspect.getmembers(models, inspect.isclass)
        }

    def test_base_constructor(self):
        """Test if all models also call the base class constructor"""
        for model in self.models:
            # One subtest per class so a single failure does not hide others.
            with self.subTest(msg=model.__name__):
                base_class_constructor_checker(model, self)
23 changes: 23 additions & 0 deletions tests/tasks/test_implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import inspect
import unittest

from pathlib import Path

import llmebench.tasks as tasks

from tests.utils import base_class_constructor_checker


class TestTaskImplementation(unittest.TestCase):
    """Implementation-level checks applied to every exported task class."""

    @classmethod
    def setUpClass(cls):
        """Collect every class exposed by the ``llmebench.tasks`` package."""
        # Search for all implemented tasks. getmembers() yields
        # (name, object) pairs; only the class objects are needed.
        cls.tasks = {
            member for _, member in inspect.getmembers(tasks, inspect.isclass)
        }

    def test_base_constructor(self):
        """Test if all tasks also call the base class constructor"""
        for task in self.tasks:
            # One subtest per class so a single failure does not hide others.
            with self.subTest(msg=task.__name__):
                base_class_constructor_checker(task, self)
51 changes: 51 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import ast
import inspect


def base_class_constructor_checker(class_implementation, tester):
    """Assert that a class's ``__init__`` forwards to its base constructor.

    The class source is inspected statically with ``ast``; the class is never
    instantiated. A class that defines no ``__init__`` passes trivially.
    Otherwise the constructor must contain exactly one call of the form
    ``super(...).__init__(...)``, and that call must receive a keyword
    argument whose value is the name ``kwargs`` (e.g. ``**kwargs``).

    Args:
        class_implementation: Class object whose source can be retrieved
            with ``inspect.getsource``.
        tester: Object providing ``assertLessEqual``, ``assertEqual`` and
            ``assertTrue`` (e.g. a ``unittest.TestCase`` instance). All
            failures are reported through these methods.
    """
    # Function-scope import keeps this helper self-contained.
    import textwrap

    # getsource() preserves the original indentation, which ast.parse()
    # rejects for classes that are not defined at module level.
    source = textwrap.dedent(inspect.getsource(class_implementation))
    tree = ast.parse(source)

    constructors = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name == "__init__"
    ]
    tester.assertLessEqual(len(constructors), 1, "Multiple constructors found")

    if not constructors:
        # No constructor, base will be called by default
        return

    constructor = constructors[0]

    def is_base_constructor_call(node):
        """Return True for a call shaped like ``super(...).__init__(...)``."""
        if not isinstance(node, ast.Call):
            return False
        if not isinstance(node.func, ast.Attribute):
            return False
        receiver = node.func.value
        if not isinstance(receiver, ast.Call):
            return False
        # The receiver may itself be a method call (``self.a().b()``), in
        # which case its ``func`` is an Attribute and has no ``id``.
        if not isinstance(receiver.func, ast.Name):
            return False
        return receiver.func.id == "super" and node.func.attr == "__init__"

    # Keep only the calls of the form super(...).__init__(...)
    base_calls = [
        node for node in ast.walk(constructor) if is_base_constructor_call(node)
    ]

    tester.assertEqual(
        len(base_calls), 1, "Call to base class constructor missing"
    )

    tester.assertTrue(
        any(
            isinstance(keyword.value, ast.Name) and keyword.value.id == "kwargs"
            for keyword in base_calls[0].keywords
        ),
        "kwargs not passed to the base class constructor",
    )

0 comments on commit dcf1074

Please sign in to comment.