diff --git a/notebooks/token_labelling.ipynb b/notebooks/token_labelling.ipynb new file mode 100644 index 00000000..45423d8c --- /dev/null +++ b/notebooks/token_labelling.ipynb @@ -0,0 +1,435 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Giving tokens a label - How to categorize tokens\n", + "\n", + "\n", + "The first part of this notebook explains how to label tokens and how the labelling functions work.\n", + "\n", + "The second part shows how all tokens used by our delphi language models are labelled.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "# autoreload\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from pprint import pprint \n", + "\n", + "import spacy\n", + "from tqdm.auto import tqdm\n", + "\n", + "import delphi\n", + "\n", + "from delphi.eval import token_labelling" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "# 1) How to use the token labelling functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We analyze a simple sentence and receive the respective tokens with their analyzed attributes.\n", + "The grammatical/linguistic analysis is done by spaCy's English language model." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Peter \t PROPN \t nsubj \t PERSON\n", + "is \t AUX \t ROOT \t \n", + "a \t DET \t det \t \n", + "person \t NOUN \t attr \t \n" + ] + } + ], + "source": [ + "# Load the English model\n", + "nlp = spacy.load(\"en_core_web_sm\")\n", + "\n", + "# Create a Doc object from a given text\n", + "doc = nlp(\"Peter is a person\")\n", + "\n", + "token = doc[0]\n", + "for tok in doc:\n", + " print(tok,\"\\t\", tok.pos_, \"\\t\", tok.dep_, \"\\t\", tok.ent_type_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's get the labels for the first token of the sentence we just printed." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'Capitalized': True,\n", + " 'Is Adjective': False,\n", + " 'Is Adposition': False,\n", + " 'Is Adverb': False,\n", + " 'Is Auxiliary': False,\n", + " 'Is Coordinating conjuction': False,\n", + " 'Is Determiner': False,\n", + " 'Is Interjunction': False,\n", + " 'Is Named Entity': True,\n", + " 'Is Noun': False,\n", + " 'Is Numeral': False,\n", + " 'Is Other': False,\n", + " 'Is Particle': False,\n", + " 'Is Pronoun': False,\n", + " 'Is Proper Noun': True,\n", + " 'Is Punctuation': False,\n", + " 'Is Subordinating conjuction': False,\n", + " 'Is Symbol': False,\n", + " 'Is Verb': False,\n", + " 'Starts with space': False}\n" + ] + } + ], + "source": [ + "from delphi.eval import token_labelling\n", + "\n", + "label = token_labelling.label_single_token(token)\n", + "pprint(label)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's get an understanding of what the labels actually mean.\n", + "Use this function to receive an explanation for a single token."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-------- Explanation of token labels --------\n", + "Token text: Peter\n", + "Token dependency: nominal subject\n", + "Token POS: proper noun\n", + "---------------- Token labels ---------------\n", + " 0 Starts with space False\n", + " 1 Capitalized True\n", + " 2 Is Adjective False\n", + " 3 Is Adposition False\n", + " 4 Is Adverb False\n", + " 5 Is Auxiliary False\n", + " 6 Is Coordinating conjuction False\n", + " 7 Is Determiner False\n", + " 8 Is Interjunction False\n", + " 9 Is Noun False\n", + " 10 Is Numeral False\n", + " 11 Is Particle False\n", + " 12 Is Pronoun False\n", + " 13 Is Proper Noun True\n", + " 14 Is Punctuation False\n", + " 15 Is Subordinating conjuction False\n", + " 16 Is Symbol False\n", + " 17 Is Verb False\n", + " 18 Is Other False\n", + " 19 Is Named Entity True\n" + ] + } + ], + "source": [ + "token_labelling.explain_token_labels(token)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you are interested in all the possible labels a token can have, that spaCy is capable of assigning, then call the same function but without any argument:\n", + "```Python\n", + ">>> token_labelling.explain_token_labels()\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batched token labelling\n", + "Next, let us analyze a batch of sentences and have them labelled.\n", + "> In the example below the input sentences are not yet tokenized, so spaCy uses its internal tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: Peter\n", + "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n", + "False | True | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | True \n", + "---\n", + "Token: is\n", + "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n", + "False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False \n", + "---\n", + "Token: a\n", + "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n", + "False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False \n", + "---\n", + "Token: person\n", + "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | 
Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n", + "False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False \n", + "---\n", + "Token: .\n", + "Starts with space | Capitalized | Is Adjective | Is Adposition | Is Adverb | Is Auxiliary | Is Coordinating conjuction | Is Determiner | Is Interjunction | Is Noun | Is Numeral | Is Particle | Is Pronoun | Is Proper Noun | Is Punctuation | Is Subordinating conjuction | Is Symbol | Is Verb | Is Other | Is Named Entity\n", + "False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False \n", + "---\n", + "\n", + "\n", + "5\n", + "[{'Starts with space': False, 'Capitalized': True, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': True, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': True}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': True, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': True, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': True, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': False, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Adjective': False, 'Is Adposition': False, 'Is Adverb': False, 'Is Auxiliary': False, 'Is Coordinating conjuction': False, 'Is Determiner': False, 'Is Interjunction': False, 'Is Noun': False, 'Is Numeral': False, 'Is Particle': False, 'Is Pronoun': False, 'Is Proper Noun': False, 'Is Punctuation': True, 'Is Subordinating conjuction': False, 'Is Symbol': False, 'Is Verb': False, 'Is Other': False, 'Is Named Entity': False}]\n" + ] + } + ], + "source": [ + "sentences = [\n", + " \"Peter is a person.\"\n", + "]\n", + "labels = token_labelling.label_batch_sentences(sentences, tokenized=False, verbose=True)\n", + "\n", + "print(len(labels[0]))\n", + "print(labels[0])" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Now with our own tokenization. E.g. the one from our TinyStories models." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5\n", + "[{'Starts with space': False, 'Capitalized': True, 'Is Noun': True, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': True, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': True, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': True, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}, {'Starts with space': False, 'Capitalized': False, 'Is Noun': False, 'Is Pronoun': False, 'Is Adjective': False, 'Is Verb': False, 'Is Adverb': False, 'Is Preposition': False, 'Is Conjunction': False, 'Is Interjunction': False, 'Is Named Entity': False}]\n" + ] + } + ], + "source": [ + "sentences = [\n", + " [\"This \", \"is \", \"a \", \"sentence\", \".\"]\n", + "]\n", + "labelled_sentences = token_labelling.label_batch_sentences(sentences, tokenized=True, verbose=False)\n", + "\n", + "print(len(labelled_sentences[0]))\n", + "print(labelled_sentences[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2) Labelling all tokens in the dataset\n", + "\n", + "Now we want to label all the tokens that our tokenizer knows - its entire vocabulary.\n", + "\n", + "Using thy script in `scripts/label_all_tokens.py` we get the files:\n", + "- `src\\delphi\\eval\\all_tokens_list.txt`\n", + "- `src\\delphi\\eval\\labelled_token_ids_dict.pkl`\n", + "\n", + "Let's load the tokenizer so that we can look at the labelled tokens.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\joshu\\anaconda3\\envs\\delphi2\\lib\\site-packages\\transformers\\utils\\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. 
Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The vocab size is: 4096\n" + ] + } + ], + "source": [ + "# Get all the tokens of the tokenizer\n", + "from transformers import AutoTokenizer, PreTrainedTokenizer\n", + "\n", + "\n", + "# Decode a sentence\n", + "def decode(tokenizer: PreTrainedTokenizer, token_ids: list[int]) -> str:\n", + " return tokenizer.decode(token_ids, skip_special_tokens=True)\n", + "\n", + "model = \"delphi-suite/delphi-llama2-100k\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model)\n", + "vocab_size = tokenizer.vocab_size\n", + "print(\"The vocab size is:\", vocab_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the pickle." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "path = \"../src/delphi/eval/labelled_token_ids_dict.pkl\"\n", + "# load \n", + "with open(path, \"rb\") as f:\n", + " labelled_token_ids_dict = pickle.load(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Look at some random tokens and their labels" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The token id is: 1143\n", + "The decoded token is: has\n", + "The label is:\n", + "{'Capitalized': False,\n", + " 'Is Adjective': False,\n", + " 'Is Adposition': False,\n", + " 'Is Adverb': False,\n", + " 'Is Auxiliary': False,\n", + " 'Is Coordinating conjuction': False,\n", + " 'Is Determiner': False,\n", + " 'Is Interjunction': True,\n", + " 'Is Named Entity': False,\n", + " 'Is Noun': False,\n", + " 'Is Numeral': False,\n", + " 'Is Other': False,\n", + " 'Is Particle': False,\n", + " 'Is Pronoun': False,\n", + " 'Is Proper Noun': False,\n", + " 'Is Punctuation': False,\n", + " 'Is Subordinating conjuction': False,\n", + " 'Is Symbol': False,\n", + " 'Is Verb': False,\n", + " 'Starts with space': False}\n" + ] + } + ], + "source": [ + "import random\n", + "from pprint import pprint\n", + "# Get a random token id between 0 and 4000\n", + "token_id = random.randint(0, 4000)\n", + "# decode the token id\n", + "decoded_token = decode(tokenizer, [token_id])\n", + "# get the corresponding label\n", + "label = labelled_token_ids_dict[token_id]\n", + "# print the results\n", + "print(\"The token id is:\", token_id)\n", + "print(\"The decoded token is:\", decoded_token)\n", + "print(\"The label is:\")\n", + "pprint(label)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv_tinyevals", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 65b457a4..5fdc84c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,8 @@ black==23.12.1 jaxtyping==0.2.25 beartype==0.16.4 pre-commit==3.6.0 -isort==5.13.2 \ No newline at end of file +isort==5.13.2 +spacy==3.7.2 +chardet==5.2.0 +sentencepiece==0.1.99 +protobuf==4.25.2 \ No newline at end of file diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py new file mode 100644 index 
00000000..01bf4cf1 --- /dev/null +++ b/scripts/label_all_tokens.py @@ -0,0 +1,110 @@ +import argparse +import pickle +from pathlib import Path + +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from delphi.eval import token_labelling + + +def tokenize( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, sample_txt: str +) -> int: + # supposedly this can be different than prepending the bos token id + return tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0] + + +# Decode a sentence +def decode( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, token_ids: int | list[int] +) -> str: + return tokenizer.decode(token_ids, skip_special_tokens=True) + + +def main(): + # Setup argparse + parser = argparse.ArgumentParser(description="Tokenization and labeling utility.") + parser.add_argument( + "--model_name", + type=str, + help="Name of the model to use for tokenization and labeling.", + default="delphi-suite/delphi-llama2-100k", + required=False, + ) + args = parser.parse_args() + + # Access command-line arguments + # Directory to save the results + SAVE_DIR = Path("src/delphi/eval/") + model_name = args.model_name + + print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n") + print(f"You chose the model: {model_name}\n") + print( + f"The language model will be loaded from Huggingface and its tokenizer used to do two things:\n\t1) Create a list of all tokens in the tokenizer's vocabulary.\n\t2) Label each token with its part of speech, dependency, and named entity recognition tags.\nThe respective results will be saved to files located at: '{SAVE_DIR}'\n" + ) + + # ================ (1) ================= + print("(1) Create a list of all tokens in the tokenizer's vocabulary ...") + + # Load the tokenizer from Huggingface + tokenizer = AutoTokenizer.from_pretrained(model_name) + vocab_size = tokenizer.vocab_size + print("Loaded the tokenizer.\nThe vocab size is:", vocab_size) + + # Create a list of all tokens in the tokenizer's vocabulary + tokens_str = "" # will hold all tokens and their ids + for i in range(tokenizer.vocab_size): + tokens_str += f"{i},{decode(tokenizer, i)}\n" + + # Save the list of all tokens to a file + filename = "all_tokens_list.txt" + filepath = SAVE_DIR / filename + with open(filepath, "w", encoding="utf-8") as f: + f.write(tokens_str) + + print(f"Saved the list of all tokens to:\n\t{filepath}\n") + + # ================ (2) ================= + print("(2) Label each token ...") + + # let's label each token + labelled_token_ids_dict: dict[int, dict[str, bool]] = {} # token_id: labels + max_token_id = tokenizer.vocab_size # stop at which token id, vocab size + # we iterate over all token_ids individually + for token_id in tqdm(range(0, max_token_id), desc="Labelling tokens"): + # decode the token_ids to get a list of tokens, a 'sentence' + tokens = decode(tokenizer, token_id) # list of tokens == sentence + # put the sentence into a list, to make it a batch of sentences + sentences = [tokens] + # label the batch of sentences + labels = token_labelling.label_batch_sentences( + sentences, tokenized=True, verbose=False + ) + # create a dict with the token_ids and their labels + # update the labelled_token_ids_dict with the new dict + labelled_token_ids_dict[token_id] = labels[0][0] + + # Save the labelled tokens to a file + filename = "labelled_token_ids_dict.pkl" + filepath = SAVE_DIR / filename + with open(filepath, "wb") as f: + pickle.dump(labelled_token_ids_dict, f) + 
+ print(f"Saved the labelled tokens to:\n\t{filepath}\n") + + # sanity check that The pickled and the original dict are the same + print("Sanity check ...", end="") + # load pickle + with open(filepath, "rb") as f: + pickled = pickle.load(f) + # compare + assert labelled_token_ids_dict == pickled + print(" completed.") + + print(" END ".center(50, "=")) + + +if __name__ == "__main__": + main() diff --git a/src/delphi/eval/all_tokens_list.txt b/src/delphi/eval/all_tokens_list.txt new file mode 100644 index 00000000..438dddae Binary files /dev/null and b/src/delphi/eval/all_tokens_list.txt differ diff --git a/src/delphi/eval/labelled_token_ids_dict.pkl b/src/delphi/eval/labelled_token_ids_dict.pkl new file mode 100644 index 00000000..5fe96a39 Binary files /dev/null and b/src/delphi/eval/labelled_token_ids_dict.pkl differ diff --git a/src/delphi/eval/token_labelling.py b/src/delphi/eval/token_labelling.py new file mode 100644 index 00000000..80673e03 --- /dev/null +++ b/src/delphi/eval/token_labelling.py @@ -0,0 +1,210 @@ +from typing import Callable, Optional + +import spacy +from spacy.tokens import Doc, Token +from spacy.util import is_package + +# make sure the english language model capabilities are installed by the equivalent of: +# python -m spacy download en_core_web_sm +# Should be run once, initially. Download only starts if not already installed. +SPACY_MODEL = "en_core_web_sm" # small: "en_core_web_sm", large: "en_core_web_trf" +NLP = None # global var to hold the language model +if not is_package(SPACY_MODEL): + spacy.cli.download(SPACY_MODEL, False, False) + + +TOKEN_LABELS: dict[str, Callable] = { + # --- custom categories --- + "Starts with space": (lambda token: token.text.startswith(" ")), # bool + "Capitalized": (lambda token: token.text[0].isupper()), # bool + # --- POS (part-of-speech) categories --- + # They include the Universal POS tags (https://universaldependencies.org/u/pos/) + # -> "POS Tag": (lambda token: token.pos_), # 'NOUN', 'VB', .. + "Is Adjective": (lambda token: token.pos_ == "ADJ"), + "Is Adposition": (lambda token: token.pos_ == "ADP"), + "Is Adverb": (lambda token: token.pos_ == "ADV"), + "Is Auxiliary": (lambda token: token.pos_ == "AUX"), + "Is Coordinating conjuction": (lambda token: token.pos_ == "CCONJ"), + "Is Determiner": (lambda token: token.pos_ == "DET"), + "Is Interjunction": (lambda token: token.pos_ == "INTJ"), + "Is Noun": (lambda token: token.pos_ == "NOUN"), + "Is Numeral": (lambda token: token.pos_ == "NUM"), + "Is Particle": (lambda token: token.pos_ == "PART"), + "Is Pronoun": (lambda token: token.pos_ == "PRON"), + "Is Proper Noun": (lambda token: token.pos_ == "PROPN"), + "Is Punctuation": (lambda token: token.pos_ == "PUNCT"), + "Is Subordinating conjuction": (lambda token: token.pos_ == "SCONJ"), + "Is Symbol": (lambda token: token.pos_ == "SYM"), + "Is Verb": (lambda token: token.pos_ == "VERB"), + "Is Other": (lambda token: token.pos_ == "X"), + # --- dependency categories --- + # -> "Dependency": (lambda token: token.dep_), # 'nsubj', 'ROOT', 'dobj', .. + # "Is Subject": (lambda token: token.dep_ == "nsubj"), + # "Is Object": (lambda token: token.dep_ == "dobj"), + # "Is Root": ( + # lambda token: token.dep_ == "ROOT" + # ), # root of the sentence (often a verb) + # "Is auxiliary": (lambda token: token.dep_ == "aux"), + # --- Named entity recognition (NER) categories --- + # "Named Entity Type": (lambda token: token.ent_type_), # '', 'PERSON', 'ORG', 'GPE', .. 
+ "Is Named Entity": (lambda token: token.ent_type_ != ""), +} + + +def explain_token_labels(token: Optional[Token] = None) -> None: + """ + Prints the explanation of a specific token's labels or of ALL + possible labels (POS, dependency, NER, ...), if no token is provided. + + Parameters + ---------- + token : Optional[Token], optional + The token, whose labels should be explained. If None, all labels + possible labels are explained, by default None. + """ + if token is not None: + # get token labels + labels = label_single_token(token) + print(" Explanation of token labels ".center(45, "-")) + print("Token text:".ljust(20), token.text) + print("Token dependency:".ljust(20), spacy.glossary.explain(token.dep_)) + print("Token POS:".ljust(20), spacy.glossary.explain(token.pos_)) + print(" Token labels ".center(45, "-")) + for i, (label_name, value) in enumerate(labels.items()): + print(f" {i:2} ", label_name.ljust(20), value) + + else: + glossary = spacy.glossary.GLOSSARY + print( + f"Explanation of all {len(glossary.keys())} token labels (POS, dependency, NER, ...):" + ) + for label, key in glossary.items(): + print(" ", label.ljust(10), key) + + +def label_single_token(token: Token | None) -> dict[str, bool]: + """ + Labels a single token. A token, that has been analyzed by the spaCy + library. + + Parameters + ---------- + token : Token | None + The token to be labelled. + + Returns + ------- + dict[str, bool] + Returns a dictionary with the token's labels as keys and their + corresponding boolean values. + """ + labels = dict() # The dict holding labels of a single token + # if token is None, then it is a '' empty strong token or similar + if token is None: + for label_name, category_check in TOKEN_LABELS.items(): + labels[label_name] = False + labels["Is Other"] = True + return labels + # all other cases / normal tokens + for label_name, category_check in TOKEN_LABELS.items(): + labels[label_name] = category_check(token) + return labels + + +def label_sentence(tokens: Doc | list[Token]) -> list[dict[str, bool]]: + """ + Labels spaCy Tokens in a sentence. Takes the context of the token into account + for dependency labels (e.g. subject, object, ...), IF dependency labels are turned on. + + Parameters + ---------- + tokens : list[Token] + A list of tokens. + + Returns + ------- + list[dict[str, bool]] + Returns a list of the tokens' labels. + """ + labelled_tokens = list() # list holding labels for all tokens of sentence + # if the list is empty it is because token is '' empty string or similar + if len(tokens) == 0: + labels = label_single_token(None) + labelled_tokens.append(labels) + return labelled_tokens + # in all other cases + for token in tokens: + labels = label_single_token(token) + labelled_tokens.append(labels) + return labelled_tokens + + +def label_batch_sentences( + sentences: list[str] | list[list[str]], + tokenized: bool = True, + verbose: bool = False, +) -> list[list[dict[str, bool]]]: + """ + Labels tokens in a sentence batchwise. Takes the context of the token into + account for dependency labels (e.g. subject, object, ...). + + Parameters + ---------- + sentences : list + A batch/list of sentences, each being a list of tokens. + tokenized : bool, optional + Whether the sentences are already tokenized, by default True. If the sentences + are full strings and not lists of tokens, then set to False. If true then `sentences` must be list[list[str]]. + verbose : bool, optional + Whether to print the tokens and their labels to the console, by default False. 
+ + Returns + ------- + list[list[dict[str, bool]] + Returns a list of sentences. Each sentence contains a list of its + corresponding token length where each entry provides the labels/categories + for the token. Sentence -> Token -> Labels + """ + global NLP, SPACY_MODEL + + if NLP is None: + # Load english language model + NLP = spacy.load(SPACY_MODEL) + # labelled tokens, list holding sentences holding tokens holding corresponding token labels + labelled_sentences: list[list[dict[str, bool]]] = list() + + # go through each sentence in the batch + for sentence in sentences: + if tokenized: + # sentence is a list of tokens + doc = Doc(NLP.vocab, words=sentence) # type: ignore + # Apply the spaCy pipeline, except for the tokenizer + for name, proc in NLP.pipeline: + if name != "tokenizer": + doc = proc(doc) + else: + # sentence is a single string + doc = NLP(sentence) # type: ignore + + labelled_tokens = list() # list holding labels for all tokens of sentence + labelled_tokens = label_sentence(doc) + + # print the token and its labels to console + if verbose is True: + # go through each token in the sentence + for token, labelled_token in zip(doc, labelled_tokens): + print(f"Token: {token}") + print(" | ".join(list(TOKEN_LABELS.keys()))) + printable = [ + str(l).ljust(len(name)) for name, l in labelled_token.items() + ] + printable = " | ".join(printable) + print(printable) + print("---") + # add current sentence's tokens' labels to the list + labelled_sentences.append(labelled_tokens) + + if verbose is True: + print("\n") + + return labelled_sentences diff --git a/tests/eval/test_token_labelling.py b/tests/eval/test_token_labelling.py new file mode 100644 index 00000000..a727ddc0 --- /dev/null +++ b/tests/eval/test_token_labelling.py @@ -0,0 +1,114 @@ +import pytest +import spacy +from spacy.language import Language +from spacy.tokens import Doc + +import delphi.eval.token_labelling as tl + + +@pytest.fixture +def dummy_doc() -> tuple[str, Doc, dict[str, bool]]: + """ + Create a dummy Doc (list of Tokens) with specific attributes for testing purposes. 
+ """ + nlp_dummy = Language() + + # Assume we're creating a dummy token with specific attributes + words = ["Peter", "is", "a", "person"] + spaces = [True, True, True, True] # No space after "dummy_token" + pos_tags = ["PROPN", "AUX", "DET", "NOUN"] # Part-of-speech tag + dep_tags = ["nsubj", "ROOT", "det", "attr"] # Dependency tag + ner_tags = ["PERSON", "", "", ""] # Named entity tag + + # Ensure the length of pos_tags and dep_tags matches the length of words + assert len(words) == len(pos_tags) == len(dep_tags) == len(ner_tags) + + # Create a Doc with one dummy token + doc = Doc(nlp_dummy.vocab, words=words, spaces=spaces) + + # Manually set POS, dependency and NER tags + for token, pos, dep, ner_tag in zip(doc, pos_tags, dep_tags, ner_tags): + token.pos_, token.dep_, token.ent_type_ = pos, dep, ner_tag + + # Token labels for "Peter" in the dummy doc + PETER_TOKEN_LABEL = { + "Starts with space": False, + "Capitalized": True, + "Is Adjective": False, + "Is Adposition": False, + "Is Adverb": False, + "Is Auxiliary": False, + "Is Coordinating conjuction": False, + "Is Determiner": False, + "Is Interjunction": False, + "Is Noun": False, + "Is Numeral": False, + "Is Particle": False, + "Is Pronoun": False, + "Is Proper Noun": True, + "Is Punctuation": False, + "Is Subordinating conjuction": False, + "Is Symbol": False, + "Is Verb": False, + "Is Other": False, + "Is Named Entity": True, + } + text = " ".join(words) + return text, doc, PETER_TOKEN_LABEL + + +def test_explain_token_labels(dummy_doc): + """ + Test the explain_token_labels function. + """ + # explain all labels + tl.explain_token_labels() + # print explanations for the first token in doc + text, doc, PETER_TOKEN_LABEL = dummy_doc + tl.explain_token_labels(doc[0]) + + +def test_label_single_token(dummy_doc): + """ + Test the label_single_token function. + """ + # create a dummy token + text, doc, PETER_TOKEN_LABEL = dummy_doc + token = doc[0] + # label the token + labels = tl.label_single_token(token) + # check if the labels are correct + assert labels == PETER_TOKEN_LABEL + + +def test_label_sentence(dummy_doc): + """ + Test the label_sentence function. + """ + text, doc, PETER_TOKEN_LABEL = dummy_doc + # label the sentence + labels = tl.label_sentence(doc) + # assert the first token is labeled correctly + assert labels[0] == PETER_TOKEN_LABEL + # iterate through tokens in doc + for token, label in zip(doc, labels): + assert label == tl.label_single_token(token) + + +def test_label_batch_sentences(dummy_doc): + """ + Test the label_batch_sentences function. + """ + # create a batch of sentences + text, doc, PETER_TOKEN_LABEL = dummy_doc + text = text.split(" ") + batch = [text, text, text] + # label the batch + labels = tl.label_batch_sentences(batch, tokenized=True) + # assert the first token is labeled correctly + assert labels[0][0] == PETER_TOKEN_LABEL + assert labels[1][0] == PETER_TOKEN_LABEL + assert labels[2][0] == PETER_TOKEN_LABEL + # iterate through tokens in doc + for token, label in zip(doc, labels[0]): + assert label == tl.label_single_token(token)