34 manual token labeling #40
Changes from 17 commits
@@ -1,11 +1,17 @@
import pickle
from pathlib import Path

import pytest
import spacy
from spacy.language import Language
from spacy.tokens import Doc
from transformers import AutoTokenizer

import delphi.eval.token_labelling as tl
Review comment: How did you rename the files? If you right-click and rename in VSCode, it should also rename all references.

labelled_token_ids_dict: dict[int, dict[str, bool]] = {}


@pytest.mark.skip(reason="These tests are slow")
Review comment: I didn't realize how many there are and that they're all in a single file. I believe you can just do pytest.skip() at the top level, just after the imports, to skip the whole file. The reason is "tests are slow and we're not using this module currently".
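A minimal sketch of the module-level skip suggested above; pytest requires allow_module_level=True when pytest.skip is called outside of a test:

import pytest

# skips collection of every test in this module
pytest.skip(
    "tests are slow and we're not using this module currently",
    allow_module_level=True,
)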
@pytest.fixture
def dummy_doc() -> tuple[str, Doc, dict[str, bool]]:
Review comment: this is a fixture, not a test, so you probably don't have to skip it.
    """
@@ -57,6 +63,7 @@ def dummy_doc() -> tuple[str, Doc, dict[str, bool]]:
    return text, doc, PETER_TOKEN_LABEL


@pytest.mark.skip(reason="These tests are slow")
def test_explain_token_labels(dummy_doc):
""" | ||
Test the explain_token_labels function. | ||
|
@@ -68,6 +75,7 @@ def test_explain_token_labels(dummy_doc):
    tl.explain_token_labels(doc[0])


@pytest.mark.skip(reason="These tests are slow")
def test_label_single_token(dummy_doc):
""" | ||
Test the label_single_token function. | ||
|
@@ -81,6 +89,7 @@ def test_label_single_token(dummy_doc):
    assert labels == PETER_TOKEN_LABEL


@pytest.mark.skip(reason="These tests are slow")
def test_label_sentence(dummy_doc):
""" | ||
Test the label_sentence function. | ||
|
@@ -95,6 +104,7 @@ def test_label_sentence(dummy_doc):
    assert label == tl.label_single_token(token)


@pytest.mark.skip(reason="These tests are slow")
def test_label_batch_sentences(dummy_doc):
""" | ||
Test the label_batch_sentences function. | ||
|
@@ -112,3 +122,70 @@ def test_label_batch_sentences(dummy_doc):
    # iterate through tokens in doc
    for token, label in zip(doc, labels[0]):
        assert label == tl.label_single_token(token)


def is_valid_structure(obj: dict[int, dict[str, bool]]) -> bool:
    """
    Checks whether obj fits the structure `dict[int, dict[str, bool]]`.
    Returns True if it fits, False otherwise.
    """
    if not isinstance(obj, dict):
        print(f"Main structure is not dict! Instead is type {type(obj)}")
        return False
    for key, value in obj.items():
        if not isinstance(key, int) or not isinstance(value, dict):
            print(
                f"Main structure is dict, but its keys are either not int or its values are not dicts. Instead key is type {type(key)} and value is type {type(value)}"
            )
            return False
        for sub_key, sub_value in value.items():
            if not isinstance(sub_key, str) or not isinstance(sub_value, bool):
                print(
                    f"The structure dict[int, dict[X, Y]] holds, but either X is not str or Y is not bool. Instead X is type {type(sub_key)} and Y is type {type(sub_value)}"
                )
                return False
    return True
Comment on lines +127 to +146
Review comment: beartype is doing this automatically
Reply: Ha, nice. But weirdly enough I needed this to catch a bug, because beartype did not throw an error when the type was
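For context, a minimal sketch of the beartype approach the reviewer mentions (check_labels is a hypothetical name, not from this PR). Note that beartype's default strategy spot-checks container items rather than walking every entry, which can let a deeply nested type error slip through:

from beartype import beartype

@beartype
def check_labels(labels: dict[int, dict[str, bool]]) -> None:
    # beartype validates the annotation at call time; for containers it
    # samples items instead of exhaustively checking every entry
    pass

check_labels({0: {"Is Noun": True}})  # accepted by the runtime check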


@pytest.mark.skip(reason="These tests are slow")
def test_label_tokens_from_tokenizer():
    """
    Simple test checking whether downloading the tokenizer and labelling all tokens in its vocabulary works.
    """
    global labelled_token_ids_dict
    # get a tokenizer
    model_name = "delphi-suite/delphi-llama2-100k"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = tokenizer.vocab_size

    tokens_str, labelled_token_ids_dict = tl.label_tokens_from_tokenizer(tokenizer)
    # count the number of lines in tokens_str
    assert tokens_str.count("\n") == (vocab_size + 1)  # + 1, because of token '\n'
    assert len(labelled_token_ids_dict.keys()) == vocab_size
    assert is_valid_structure(labelled_token_ids_dict)


@pytest.mark.parametrize("path", [Path("temp/token_labels.csv")])
def test_import_token_labels(path: Path):
""" | ||
Simple test, checking if the import of token labels works. | ||
|
||
Note: Because we want to use pure pytest and not install any extra dependencies (e.g. pytest-depencency) we recreate the `labelled_tokens_dict` in this test as we did in `test_label_tokens_from_tokenizer`. This duplication is not ideal, but it is the best quick&dirty solution for now. | ||
""" | ||
    # create the labelled_token_ids_dict
    model_name = "delphi-suite/delphi-llama2-100k"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    _, labelled_token_ids_dict = tl.label_tokens_from_tokenizer(tokenizer)

    # create the directory for the file
    path.parent.mkdir(parents=True, exist_ok=True)
    # save the labels as CSV
    df = tl.convert_label_dict_to_df(labelled_token_ids_dict)
    df.to_csv(path, index=False)

    # load the file with the function under test
    loaded_dict = tl.import_token_labels(path)

    # assure that the content and structure are correct
    assert loaded_dict == labelled_token_ids_dict
    assert is_valid_structure(loaded_dict)
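As an aside, the duplication noted in the docstring could also be avoided with plain pytest via a module-scoped fixture; a sketch under that assumption (tokenizer_labels is not part of this PR):

@pytest.fixture(scope="module")
def tokenizer_labels():
    # computed once per module and shared by every test that requests it
    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/delphi-llama2-100k")
    _, labelled = tl.label_tokens_from_tokenizer(tokenizer)
    return labelled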
Review comment: this file should be renamed to spacy_* as well