added token label script
joshuawe committed Feb 13, 2024
1 parent d50e206 commit cf0f174
Showing 2 changed files with 135 additions and 18 deletions.
116 changes: 116 additions & 0 deletions scripts/label_all_tokens.py
@@ -0,0 +1,116 @@
import pickle
import sys
from pathlib import Path

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from delphi.eval import token_labelling


def tokenize(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, sample_txt: str
) -> torch.Tensor:
    # Note: encoding the BOS token as part of the text can give a different
    # result than prepending the BOS token id to the already-encoded text.
    return tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0]


# Decode a sentence
def decode(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, token_ids: list[int]
) -> str:
    return tokenizer.decode(token_ids, skip_special_tokens=True)


def main():
print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
# Access command-line arguments
args = sys.argv[1:]
# Directory to save the results
SAVE_DIR = Path("src/delphi/eval/")

# Check if arguments are provided
if len(args) == 0:
print("No arguments provided.")
return

if len(args) > 1:
print("Too many arguments provided.")
return

# Process arguments
model_name = args[0]

print(f"You chose the model: {model_name}\n")
print(
f"The language model will be loaded from Huggingface and its tokenizer used to do two things:\n\t1) Create a list of all tokens in the tokenizer's vocabulary.\n\t2) Label each token with its part of speech, dependency, and named entity recognition tags.\nThe respective results will be saved to files located at: '{SAVE_DIR}'\n"
)

    # ================ (1) =================
    print("(1) Create a list of all tokens in the tokenizer's vocabulary ...")

    # Load the tokenizer from Huggingface
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = tokenizer.vocab_size
    print("Loaded the tokenizer.\nThe vocab size is:", vocab_size)

    # Create a list of all tokens in the tokenizer's vocabulary
    tokens_str = ""  # will hold all tokens and their ids
    for i in range(tokenizer.vocab_size):
        tokens_str += f"{i},{decode(tokenizer, [i])}\n"

    # Save the list of all tokens to a file
    filename = "all_tokens_" + model_name.replace("/", "-") + ".txt"
    filepath = SAVE_DIR / filename
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(tokens_str)

    print(f"Saved the list of all tokens to:\n\t{filepath}\n")

    # ================ (2) =================
    print("(2) Label each token ...")

    # let's label each token
    labelled_token_ids_dict: dict[int, dict[str, bool]] = {}  # token_id: labels
    max_token_id = tokenizer.vocab_size  # stop at which token id, vocab size
    batch_size = 500
    # we iterate (batchwise) over all token_ids; labelling them individually takes too much time
    for start in tqdm(range(0, max_token_id, batch_size), desc="Labelling tokens"):
        # create a batch of token_ids
        end = min(start + batch_size, max_token_id)
        token_ids = list(range(start, end))
        # decode the batch of token_ids into a single string (treated as one 'sentence')
        tokens = decode(tokenizer, token_ids)
        # put the sentence into a list, to make it a batch of sentences
        sentences = [tokens]
        # label the batch of sentences
        labels = token_labelling.label_batch_sentences(
            sentences, tokenized=True, verbose=False
        )
        # create a dict with the token_ids and their labels
        labelled_sentence_dict = dict(zip(token_ids, labels[0]))
        # update the labelled_token_ids_dict with the new dict
        labelled_token_ids_dict.update(labelled_sentence_dict)

    # Save the labelled tokens to a file
    filename = "labelled_token_ids_dict_" + model_name.replace("/", "-") + ".pkl"
    filepath = SAVE_DIR / filename
    with open(filepath, "wb") as f:
        pickle.dump(labelled_token_ids_dict, f)

    print(f"Saved the labelled tokens to:\n\t{filepath}\n")

    # sanity check that the pickled and the original dict are the same
    print("Sanity check ...", end="")
    # load pickle
    with open(filepath, "rb") as f:
        pickled = pickle.load(f)
    # compare
    assert labelled_token_ids_dict == pickled
    print(" completed.")

    print(" END ".center(50, "="))


if __name__ == "__main__":
    main()
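
The script expects a single Hugging Face model name on the command line (python scripts/label_all_tokens.py <model_name>). Below is a minimal sketch of reading the pickled result back, assuming the script has already been run; the model name used here is purely a hypothetical placeholder:

import pickle
from pathlib import Path

# Hypothetical model name; substitute whatever was passed to label_all_tokens.py.
model_name = "org/model-name"

save_dir = Path("src/delphi/eval/")  # matches SAVE_DIR in the script
filepath = save_dir / ("labelled_token_ids_dict_" + model_name.replace("/", "-") + ".pkl")

with open(filepath, "rb") as f:
    # token_id -> {label name: bool}, as produced by the labelling loop above
    labelled_token_ids_dict: dict[int, dict[str, bool]] = pickle.load(f)

print(labelled_token_ids_dict[0])  # labels assigned to token id 0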
37 changes: 19 additions & 18 deletions src/delphi/eval/token_labelling.py
@@ -21,32 +21,33 @@
"Starts with space": (lambda token: token.text.startswith(" ")), # bool
"Capitalized": (lambda token: token.text[0].isupper()), # bool
# --- POS (part-of-speech) categories ---
# They include the Universal POS tags (https://universaldependencies.org/u/pos/)
# -> "POS Tag": (lambda token: token.pos_), # 'NOUN', 'VB', ..
"Is Adjective": (lambda token: token.pos_ == "ADJ"), # redundant
"Is Adposition": (lambda token: token.pos_ == "ADP"), # redundant
"Is Adverb": (lambda token: token.pos_ == "ADV"), # redundant
"Is Auxiliary": (lambda token: token.pos_ == "AUX"), # redundant
"Is Coordinating conjuction": (lambda token: token.pos_ == "CCONJ"), # redundant
"Is Determiner": (lambda token: token.pos_ == "DET"), # redundant
"Is Interjunction": (lambda token: token.pos_ == "INTJ"), # redundant
"Is Noun": (lambda token: token.pos_ == "NOUN"), # redundant
"Is Numeral": (lambda token: token.pos_ == "NUM"), # redundant
"Is Particle": (lambda token: token.pos_ == "PART"), # redundant
"Is Pronoun": (lambda token: token.pos_ == "PRON"), # redundant
"Is Proper Noun": (lambda token: token.pos_ == "PROPN"), # redundant
"Is Punctuation": (lambda token: token.pos_ == "PUNCT"), # redundant
"Is Subordinating conjuction": (lambda token: token.pos_ == "SCONJ"), # redundant
"Is Symbol": (lambda token: token.pos_ == "SYM"), # redundant
"Is Verb": (lambda token: token.pos_ == "VERB"), # redundant
"Is Other": (lambda token: token.pos_ == "X"), # redundant
"Is Adjective": (lambda token: token.pos_ == "ADJ"),
"Is Adposition": (lambda token: token.pos_ == "ADP"),
"Is Adverb": (lambda token: token.pos_ == "ADV"),
"Is Auxiliary": (lambda token: token.pos_ == "AUX"),
"Is Coordinating conjuction": (lambda token: token.pos_ == "CCONJ"),
"Is Determiner": (lambda token: token.pos_ == "DET"),
"Is Interjunction": (lambda token: token.pos_ == "INTJ"),
"Is Noun": (lambda token: token.pos_ == "NOUN"),
"Is Numeral": (lambda token: token.pos_ == "NUM"),
"Is Particle": (lambda token: token.pos_ == "PART"),
"Is Pronoun": (lambda token: token.pos_ == "PRON"),
"Is Proper Noun": (lambda token: token.pos_ == "PROPN"),
"Is Punctuation": (lambda token: token.pos_ == "PUNCT"),
"Is Subordinating conjuction": (lambda token: token.pos_ == "SCONJ"),
"Is Symbol": (lambda token: token.pos_ == "SYM"),
"Is Verb": (lambda token: token.pos_ == "VERB"),
"Is Other": (lambda token: token.pos_ == "X"),
# --- dependency categories ---
# -> "Dependency": (lambda token: token.dep_), # 'nsubj', 'ROOT', 'dobj', ..
# "Is Subject": (lambda token: token.dep_ == "nsubj"),
# "Is Object": (lambda token: token.dep_ == "dobj"),
# "Is Root": (
# lambda token: token.dep_ == "ROOT"
# ), # root of the sentence (often a verb)
# "Is auxiliary": (lambda token: token.dep_ == "aux"), # redundant
# "Is auxiliary": (lambda token: token.dep_ == "aux"),
# --- Named entity recognition (NER) categories ---
# "Named Entity Type": (lambda token: token.ent_type_), # '', 'PERSON', 'ORG', 'GPE', ..
"Is Named Entity": (lambda token: token.ent_type_ != ""),
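
These label functions operate on token objects exposing spaCy-style attributes (pos_, dep_, ent_type_). As a minimal sketch of how such a check behaves, assuming spaCy and its small English model en_core_web_sm are installed (both are assumptions, not something this diff states):

import spacy

nlp = spacy.load("en_core_web_sm")  # assumes the model has been downloaded
doc = nlp("Peter quickly eats apples.")

for token in doc:
    # Mirrors entries such as "Is Verb" and "Is Named Entity" from the dict above.
    print(token.text, token.pos_ == "VERB", token.ent_type_ != "")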
