From d0991a3c39cd7040f37e3b4ceb090dd719e060ed Mon Sep 17 00:00:00 2001
From: menamerai <rphananh@gmail.com>
Date: Mon, 26 Feb 2024 14:07:20 -0500
Subject: [PATCH] added full path save option

---
 scripts/label_all_tokens.py | 16 +++++++++++++---
 scripts/map_tokens.py       | 33 ++++++++++++++++++++++++++++++---
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py
index 6519eaca..7c0bf262 100644
--- a/scripts/label_all_tokens.py
+++ b/scripts/label_all_tokens.py
@@ -1,4 +1,5 @@
 import argparse
+import os
 import pickle
 
 from tqdm.auto import tqdm
@@ -32,10 +33,14 @@ def main():
         default="delphi-suite/delphi-llama2-100k",
         required=False,
     )
+    parser.add_argument(
+        "--output",
+        help="Output path name. Must include at least output file name.",
+        default="labelled_token_ids_dict.pkl",
+    )
     args = parser.parse_args()
 
     # Access command-line arguments
-
     model_name = args.model_name
 
     print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
@@ -86,8 +91,13 @@ def main():
         labelled_token_ids_dict[token_id] = labels[0][0]
 
     # Save the labelled tokens to a file
-    filename = "labelled_token_ids_dict.pkl"
-    filepath = STATIC_ASSETS_DIR.joinpath(filename)
+    if os.path.split(args.output)[0] == "":
+        filepath = STATIC_ASSETS_DIR.joinpath(args.output)
+        print(f"Outputting file {args.output} to path {filepath}")
+    else:
+        filepath = os.path.expandvars(args.output)
+        print(f"Outputting to path {filepath}")
+
     with open(f"{filepath}", "wb") as f:
         pickle.dump(labelled_token_ids_dict, f)
 
diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
index 2acea2da..ad3579c3 100644
--- a/scripts/map_tokens.py
+++ b/scripts/map_tokens.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import argparse
+import os
 import pickle
 
 from delphi.constants import STATIC_ASSETS_DIR
@@ -10,14 +11,51 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
-        "dataset_name", help="Dataset from huggingface to run token_map on"
+        "dataset_name",
+        # nargs="?" makes the positional optional; without it argparse still
+        # requires the argument and the default below is never used.
+        nargs="?",
+        help="Dataset from huggingface to run token_map on. Must be tokenized.",
+        default="delphi-suite/v0-tinystories-v2-clean-tokenized",
+    )
+    parser.add_argument(
+        "--output",
+        help="Output path name. Must include at least output file name.",
+        default="token_map.pkl",
     )
-    parser.add_argument("--output", help="Output file name", default="token_map.pkl")
     args = parser.parse_args()
 
+    print("\n", " MAP TOKENS TO POSITIONS ".center(50, "="), "\n")
+    print(f"You chose the dataset: {args.dataset_name}\n")
+
+    # A bare file name is saved under STATIC_ASSETS_DIR; a value with a
+    # directory part is used as given, after expanding environment variables.
+    if os.path.split(args.output)[0] == "":
+        filepath = STATIC_ASSETS_DIR.joinpath(args.output)
+        print(f"Outputting file {args.output} to path\n\t{filepath}\n")
+    else:
+        filepath = os.path.expandvars(args.output)
+        print(f"Outputting to path\n\t{filepath}\n")
+
     dataset = load_validation_dataset(args.dataset_name)
 
     mapping = token_map(dataset)
 
-    with open(f"{STATIC_ASSETS_DIR}/{args.output}", "wb") as f:
+    with open(f"{filepath}", "wb") as f:
         pickle.dump(mapping, file=f)
+
+    print(f"Token map saved to\n\t{filepath}\n")
+    print("Sanity check ... ", end="")
+
+    # Read the file back and compare it with the in-memory mapping.
+    with open(f"{filepath}", "rb") as f:
+        pickled = pickle.load(f)
+
+    # Raise explicitly instead of asserting: asserts are stripped under
+    # `python -O`, which would skip the check yet still print "completed.".
+    if mapping != pickled:
+        raise RuntimeError(
+            f"Sanity check failed: token map read back from {filepath} differs"
+        )
+    print("completed.")
+    print(" END ".center(50, "="))