Skip to content

Commit

Permalink
add test_import_token_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Feb 28, 2024
1 parent 130418e commit 7a16cc8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tqdm==4.66.1
ipywidgets==8.1.1
nbformat==5.9.2
pytest==7.4.4
pytest-dependency==0.6.0
black==23.12.1
jaxtyping==0.2.25
beartype==0.16.4
Expand Down
52 changes: 43 additions & 9 deletions tests/eval/test_token_labelling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from pathlib import Path

import pytest
Expand All @@ -7,6 +8,8 @@

import delphi.eval.token_labelling as tl

# Module-level cache: filled by test_label_tokens_from_tokenizer and reused by
# tests that declare a pytest-dependency on it (see pytest-dependency markers).
labelled_token_ids_dict: dict[int, dict[str, bool]] = {}


@pytest.fixture
def dummy_doc() -> tuple[str, Doc, dict[str, bool]]:
Expand Down Expand Up @@ -121,37 +124,68 @@ def is_valid_structure(obj: dict[int, dict[str, bool]]) -> bool:
Checks whether the obj fits the structure of `dict[int, dict[str, bool]]`. Returns True, if it fits, False otherwise.
"""
if not isinstance(obj, dict):
print(f"Main structure is not dict! Instead is type {type(obj)}")
return False
for key, value in obj.items():
if not isinstance(key, int) or not isinstance(value, dict):
print(
f"Main structure is dict, but its keys are either not int or its values are not dicts. Instead key is type {type(key)} and value is type {type(value)}"
)
return False
for sub_key, sub_value in value.items():
if not isinstance(sub_key, str) or not isinstance(sub_value, bool):
print(
f"The structure dict[int, dict[X, Y]] is True, but either X is not str or Y is not bool. Instead X is type {type(sub_key)} and Y is type {type(sub_value)}"
)
return False
return True


@pytest.mark.dependency()
def test_label_tokens_from_tokenizer():
    """
    Simple test, checking if download of tokenizer and the labelling of all
    tokens in its vocabulary works.

    Fills the module-level `labelled_token_ids_dict` so dependent tests
    (marked with pytest-dependency) can reuse the result.
    """
    global labelled_token_ids_dict
    # get a tokenizer
    model_name = "delphi-suite/delphi-llama2-100k"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = tokenizer.vocab_size

    tokens_str, labelled_token_ids_dict = tl.label_tokens_from_tokenizer(tokenizer)
    # one line per vocab entry; + 1 because the token '\n' itself adds a newline
    assert tokens_str.count("\n") == (vocab_size + 1)
    assert len(labelled_token_ids_dict.keys()) == vocab_size
    assert is_valid_structure(labelled_token_ids_dict)


@pytest.mark.dependency(depends=["test_label_tokens_from_tokenizer"])
@pytest.mark.parametrize(
    "path", [Path("temp/token_labels.csv"), Path("temp/token_labels.pkl")]
)
def test_import_token_labels(path: Path):
    """
    Round-trip the labelled-token dict through a file and re-import it.

    Saves `labelled_token_ids_dict` (filled by the test this one depends on)
    to *path* in the format implied by its suffix, then checks that
    `tl.import_token_labels` restores an identical, well-formed dict.
    """
    global labelled_token_ids_dict
    assert (
        labelled_token_ids_dict is not None
    ), "It should be filled for the test to run. Check test-dependency."
    assert (
        labelled_token_ids_dict != {}
    ), "It should be filled for the test to run. Check test-dependency."
    # create the directory the file will be written into
    path.parent.mkdir(parents=True, exist_ok=True)
    # save the file in the format implied by its suffix
    if path.suffix == ".pkl":
        with open(path, "wb") as file:
            pickle.dump(labelled_token_ids_dict, file)
    elif path.suffix == ".csv":
        df = tl.convert_label_dict_to_df(labelled_token_ids_dict)
        df.to_csv(path, index=False)
    else:
        raise ValueError("The file ending is incorrect.")

    # load the file with our function to be tested
    loaded_dict = tl.import_token_labels(path)

    # assure that the structure is correct
    assert loaded_dict == labelled_token_ids_dict
    assert is_valid_structure(loaded_dict)

0 comments on commit 7a16cc8

Please sign in to comment.