diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py
index eaa2ea41..e054fa07 100644
--- a/scripts/label_all_tokens.py
+++ b/scripts/label_all_tokens.py
@@ -1,5 +1,5 @@
+import argparse
 import pickle
-import sys
 from pathlib import Path
 
 from tqdm.auto import tqdm
@@ -17,30 +17,29 @@ def tokenize(
 
 # Decode a sentence
 def decode(
-    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, token_ids: list[int]
+    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, token_ids: int | list[int]
 ) -> str:
     return tokenizer.decode(token_ids, skip_special_tokens=True)
 
 
 def main():
-    print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
+    # Setup argparse
+    parser = argparse.ArgumentParser(description="Tokenization and labeling utility.")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of the model to use for tokenization and labeling.",
+        default="delphi-suite/delphi-llama2-100k",
+        required=False,
+    )
+    args = parser.parse_args()
+
     # Access command-line arguments
-    args = sys.argv[1:]
     # Directory to save the results
     SAVE_DIR = Path("src/delphi/eval/")
+    model_name = args.model_name
 
-    # Check if arguments are provided
-    if len(args) == 0:
-        print("No arguments provided.")
-        return
-
-    if len(args) > 1:
-        print("Too many arguments provided.")
-        return
-
-    # Process arguments
-    model_name = args[0]
-
+    print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
     print(f"You chose the model: {model_name}\n")
     print(
         f"The language model will be loaded from Huggingface and its tokenizer used to do two things:\n\t1) Create a list of all tokens in the tokenizer's vocabulary.\n\t2) Label each token with its part of speech, dependency, and named entity recognition tags.\nThe respective results will be saved to files located at: '{SAVE_DIR}'\n"
@@ -73,27 +72,22 @@ def main():
     # let's label each token
     labelled_token_ids_dict: dict[int, dict[str, bool]] = {}  # token_id: labels
     max_token_id = tokenizer.vocab_size  # stop at which token id, vocab size
 
-    batch_size = 500  # we iterate (batchwise) over all token_ids, individually takes too much time
-    for start in tqdm(range(0, max_token_id, batch_size), desc="Labelling tokens"):
-        # create a batch of token_ids
-        end = min(start + batch_size, max_token_id)
-        token_ids = list(range(start, end))
+    for token_id in tqdm(range(0, max_token_id), desc="Labelling tokens"):
         # decode the token_ids to get a list of tokens, a 'sentence'
-        tokens = decode(tokenizer, token_ids)  # list of tokens == sentence
+        tokens = decode(tokenizer, token_id)  # list of tokens == sentence
         # put the sentence into a list, to make it a batch of sentences
-        sentences = [tokens]  # CHECK AGAIN
+        sentences = [tokens]
         # label the batch of sentences
         labels = token_labelling.label_batch_sentences(
             sentences, tokenized=True, verbose=False
         )
         # create a dict with the token_ids and their labels
-        labelled_sentence_dict = dict(zip(token_ids, labels[0]))
         # update the labelled_token_ids_dict with the new dict
-        labelled_token_ids_dict.update(labelled_sentence_dict)
+        labelled_token_ids_dict[token_id] = labels[0][0]
 
     # Save the labelled tokens to a file
-    filename = "labelled_token_ids_dict_.pkl"
+    filename = "labelled_token_ids_dict.pkl"
     filepath = SAVE_DIR / filename
     with open(filepath, "wb") as f:
        pickle.dump(labelled_token_ids_dict, f)
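
For reference, a minimal usage sketch (an assumption based on the patch above, not part of the diff): after this change the script takes an optional --model_name flag instead of a positional argument, and the labelled tokens are written to src/delphi/eval/labelled_token_ids_dict.pkl, which can be loaded back with pickle.

    # Assumed invocations (default model vs. explicit override):
    #   python scripts/label_all_tokens.py
    #   python scripts/label_all_tokens.py --model_name delphi-suite/delphi-llama2-100k
    import pickle
    from pathlib import Path

    # Path and filename match SAVE_DIR and `filename` in the patched script.
    filepath = Path("src/delphi/eval/") / "labelled_token_ids_dict.pkl"
    with open(filepath, "rb") as f:
        # dict[int, dict[str, bool]], keyed by token_id, per the script's annotation
        labelled_token_ids_dict = pickle.load(f)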