-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add implementation specific tests for Dataset/Model/Task (#205)
This commit adds tests that ensure the base class constructor is called when creating any dataset/model/task object, and that it is called with the correct arguments. Also includes minor fixes to existing tests. * Fix Typo in test class * Fix escape warning * Fix missing subtest context * Add tests for base class constructor calls * Refactor common code to test.utils
- Loading branch information
Showing
7 changed files
with
140 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import inspect | ||
import unittest | ||
|
||
from pathlib import Path | ||
|
||
import llmebench.datasets as datasets | ||
|
||
from tests.utils import base_class_constructor_checker | ||
|
||
|
||
class TestDatasetImplementation(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
# Search for all implemented datasets | ||
framework_dir = Path("llmebench") | ||
cls.datasets = set( | ||
[m[1] for m in inspect.getmembers(datasets, inspect.isclass)] | ||
) | ||
|
||
def test_base_constructor(self): | ||
"Test if all datasets also call the base class constructor" | ||
|
||
for dataset in self.datasets: | ||
with self.subTest(msg=dataset.__name__): | ||
base_class_constructor_checker(dataset, self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import inspect | ||
import unittest | ||
|
||
from pathlib import Path | ||
|
||
import llmebench.models as models | ||
|
||
from tests.utils import base_class_constructor_checker | ||
|
||
|
||
class TestModelImplementation(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
# Search for all implemented models | ||
framework_dir = Path("llmebench") | ||
cls.models = set([m[1] for m in inspect.getmembers(models, inspect.isclass)]) | ||
|
||
def test_base_constructor(self): | ||
"Test if all models also call the base class constructor" | ||
|
||
for model in self.models: | ||
with self.subTest(msg=model.__name__): | ||
base_class_constructor_checker(model, self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import inspect | ||
import unittest | ||
|
||
from pathlib import Path | ||
|
||
import llmebench.tasks as tasks | ||
|
||
from tests.utils import base_class_constructor_checker | ||
|
||
|
||
class TestTaskImplementation(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
# Search for all implemented models | ||
framework_dir = Path("llmebench") | ||
cls.tasks = set([m[1] for m in inspect.getmembers(tasks, inspect.isclass)]) | ||
|
||
def test_base_constructor(self): | ||
"Test if all tasks also call the base class constructor" | ||
|
||
for task in self.tasks: | ||
with self.subTest(msg=task.__name__): | ||
base_class_constructor_checker(task, self) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import ast | ||
import inspect | ||
|
||
|
||
def base_class_constructor_checker(class_implementation, tester): | ||
tree = ast.parse(inspect.getsource(class_implementation)) | ||
constructors = list( | ||
n | ||
for n in ast.walk(tree) | ||
if isinstance(n, ast.FunctionDef) and n.name == "__init__" | ||
) | ||
tester.assertLessEqual(len(constructors), 1, "Multiple constructors found") | ||
|
||
if len(constructors) == 0: | ||
# No constructor, base will be called by default | ||
return | ||
|
||
constructor = constructors[0] | ||
|
||
# Collect all function calls inside the constructor | ||
fn_calls = list( | ||
n | ||
for n in ast.walk(constructor) | ||
if isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) | ||
) | ||
|
||
# For each, check if there is something of the form super(...).__init(...) | ||
def is_base_constructor_call(node): | ||
fn_call = node.func | ||
if not isinstance(fn_call.value, ast.Call): | ||
return False | ||
if not fn_call.value.func.id == "super": | ||
return False | ||
if not fn_call.attr == "__init__": | ||
return False | ||
|
||
return True | ||
|
||
filtered_fn_calls = list(filter(is_base_constructor_call, fn_calls)) | ||
|
||
tester.assertEqual( | ||
len(filtered_fn_calls), 1, "Call to base class constructor missing" | ||
) | ||
|
||
tester.assertTrue( | ||
any( | ||
isinstance(k.value, ast.Name) and k.value.id == "kwargs" | ||
for k in filtered_fn_calls[0].keywords | ||
), | ||
"kwargs not passed to the base class constructor", | ||
) |