
Commit

added full path save option
menamerai committed Feb 27, 2024
1 parent a48370b commit 4e490ac
Showing 2 changed files with 42 additions and 6 deletions.
scripts/label_all_tokens.py (16 changes: 13 additions & 3 deletions)
@@ -1,4 +1,5 @@
 import argparse
+import os
 import pickle
 
 from tqdm.auto import tqdm
@@ -32,10 +33,14 @@ def main():
default="delphi-suite/delphi-llama2-100k",
required=False,
)
parser.add_argument(
"--output",
help="Output path name. Must include at least output file name.",
default="labelled_token_ids_dict.pkl",
)
args = parser.parse_args()

# Access command-line arguments

model_name = args.model_name

print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
@@ -86,8 +91,13 @@ def main():
         labelled_token_ids_dict[token_id] = labels[0][0]
 
     # Save the labelled tokens to a file
-    filename = "labelled_token_ids_dict.pkl"
-    filepath = STATIC_ASSETS_DIR.joinpath(filename)
+    if os.path.split(args.output)[0] == "":
+        filepath = STATIC_ASSETS_DIR.joinpath(args.output)
+        print(f"Outputting file {args.output} to path {filepath}")
+    else:
+        filepath = os.path.expandvars(args.output)
+        print(f"Outputting to path {filepath}")
 
     with open(f"{filepath}", "wb") as f:
         pickle.dump(labelled_token_ids_dict, f)

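Both scripts now resolve the output path the same way: a bare filename is stored under STATIC_ASSETS_DIR, while anything with a directory component is taken as a full path after expanding environment variables. A minimal standalone sketch of that branch (the helper name and the stub value of STATIC_ASSETS_DIR are illustrative, not from the repo):

import os
from pathlib import Path

# Assumption: stand-in for the constant the scripts import from the project.
STATIC_ASSETS_DIR = Path("static_assets")

def resolve_output_path(output: str) -> str:
    # No directory component: store the file under the static assets dir.
    if os.path.split(output)[0] == "":
        return str(STATIC_ASSETS_DIR.joinpath(output))
    # Directory component present: expand vars like $HOME and use the path as-is.
    return os.path.expandvars(output)

print(resolve_output_path("token_map.pkl"))            # static_assets/token_map.pkl
print(resolve_output_path("$HOME/out/token_map.pkl"))  # e.g. /home/user/out/token_map.pkl
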
scripts/map_tokens.py (32 changes: 29 additions & 3 deletions)
@@ -11,14 +11,40 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
-        "dataset_name", help="Dataset from huggingface to run token_map on"
+        "dataset_name",
+        help="Dataset from huggingface to run token_map on. Must be tokenized.",
+        default="delphi-suite/v0-tinystories-v2-clean-tokenized",
     )
-    parser.add_argument("--output", help="Output file name", default="token_map.pkl")
+    parser.add_argument(
+        "--output",
+        help="Output path name. Must include at least output file name.",
+        default="token_map.pkl",
+    )
     args = parser.parse_args()
 
+    print("\n", " MAP TOKENS TO POSITIONS ".center(50, "="), "\n")
+    print(f"You chose the dataset: {args.dataset_name}\n")
+
+    if os.path.split(args.output)[0] == "":
+        filepath = STATIC_ASSETS_DIR.joinpath(args.output)
+        print(f"Outputting file {args.output} to path\n\t{filepath}\n")
+    else:
+        filepath = os.path.expandvars(args.output)
+        print(f"Outputting to path\n\t{filepath}\n")
+
     dataset = load_validation_dataset(args.dataset_name)
 
     mapping = token_map(dataset)
 
-    with open(f"{STATIC_ASSETS_DIR}/{args.output}", "wb") as f:
+    with open(f"{filepath}", "wb") as f:
         pickle.dump(mapping, file=f)
 
+    print(f"Token map saved to\n\t{filepath}\n")
+    print("Sanity check ... ", end="")
+
+    with open(f"{filepath}", "rb") as f:
+        pickled = pickle.load(f)
+
+    assert mapping == pickled
+    print("completed.")
+    print(" END ".center(50, "="))
