diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py index 6519eaca..7c0bf262 100644 --- a/scripts/label_all_tokens.py +++ b/scripts/label_all_tokens.py @@ -1,4 +1,5 @@ import argparse +import os import pickle from tqdm.auto import tqdm @@ -32,10 +33,14 @@ def main(): default="delphi-suite/delphi-llama2-100k", required=False, ) + parser.add_argument( + "--output", + help="Output path name. Must include at least output file name.", + default="labelled_token_ids_dict.pkl", + ) args = parser.parse_args() # Access command-line arguments - model_name = args.model_name print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n") @@ -86,8 +91,13 @@ def main(): labelled_token_ids_dict[token_id] = labels[0][0] # Save the labelled tokens to a file - filename = "labelled_token_ids_dict.pkl" - filepath = STATIC_ASSETS_DIR.joinpath(filename) + if os.path.split(args.output)[0] == "": + filepath = STATIC_ASSETS_DIR.joinpath(args.output) + print(f"Outputting file {args.output} to path {filepath}") + else: + filepath = os.path.expandvars(args.output) + print(f"Outputting to path {filepath}") + with open(f"{filepath}", "wb") as f: pickle.dump(labelled_token_ids_dict, f) diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py index 2acea2da..ad3579c3 100644 --- a/scripts/map_tokens.py +++ b/scripts/map_tokens.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import os import pickle from delphi.constants import STATIC_ASSETS_DIR @@ -10,14 +11,40 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="") parser.add_argument( - "dataset_name", help="Dataset from huggingface to run token_map on" + "dataset_name", + help="Dataset from huggingface to run token_map on. Must be tokenized.", + default="delphi-suite/v0-tinystories-v2-clean-tokenized", + ) + parser.add_argument( + "--output", + help="Output path name. Must include at least output file name.", + default="token_map.pkl", ) - parser.add_argument("--output", help="Output file name", default="token_map.pkl") args = parser.parse_args() + print("\n", " MAP TOKENS TO POSITIONS ".center(50, "="), "\n") + print(f"You chose the dataset: {args.dataset_name}\n") + + if os.path.split(args.output)[0] == "": + filepath = STATIC_ASSETS_DIR.joinpath(args.output) + print(f"Outputting file {args.output} to path\n\t{filepath}\n") + else: + filepath = os.path.expandvars(args.output) + print(f"Outputting to path\n\t{filepath}\n") + dataset = load_validation_dataset(args.dataset_name) mapping = token_map(dataset) - with open(f"{STATIC_ASSETS_DIR}/{args.output}", "wb") as f: + with open(f"{filepath}", "wb") as f: pickle.dump(mapping, file=f) + + print(f"Token map saved to\n\t{filepath}\n") + print("Sanity check ... ", end="") + + with open(f"{filepath}", "rb") as f: + pickled = pickle.load(f) + + assert mapping == pickled + print("completed.") + print(" END ".center(50, "="))