From f6481feb34dd607dbab94e533fb83ba803eb3f26 Mon Sep 17 00:00:00 2001 From: menamerai Date: Tue, 27 Feb 2024 17:54:00 -0500 Subject: [PATCH] removed static dir refs in labeled token output --- scripts/label_all_tokens.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py index 7c0bf262..aa1ceadb 100644 --- a/scripts/label_all_tokens.py +++ b/scripts/label_all_tokens.py @@ -34,8 +34,8 @@ def main(): required=False, ) parser.add_argument( - "--output", - help="Output path name. Must include at least output file name.", + "--output_path", + help="Output path name.", default="labelled_token_ids_dict.pkl", ) args = parser.parse_args() @@ -65,7 +65,7 @@ def main(): # Save the list of all tokens to a file filename = "all_tokens_list.txt" filepath = STATIC_ASSETS_DIR.joinpath(filename) - with open(f"{filepath}", "w", encoding="utf-8") as f: + with open(str(filepath), "w", encoding="utf-8") as f: f.write(tokens_str) print(f"Saved the list of all tokens to:\n\t{filepath}\n") @@ -91,14 +91,14 @@ def main(): labelled_token_ids_dict[token_id] = labels[0][0] # Save the labelled tokens to a file - if os.path.split(args.output)[0] == "": - filepath = STATIC_ASSETS_DIR.joinpath(args.output) - print(f"Outputting file {args.output} to path {filepath}") + if os.path.dirname(args.output_path) == "": + filepath = os.path.join(os.getcwd(), args.output_path) else: - filepath = os.path.expandvars(args.output) - print(f"Outputting to path {filepath}") + filepath = args.output_path - with open(f"{filepath}", "wb") as f: + filepath = os.path.abspath(filepath) + + with open(filepath, "wb") as f: pickle.dump(labelled_token_ids_dict, f) print(f"Saved the labelled tokens to:\n\t{filepath}\n") @@ -106,7 +106,7 @@ def main(): # sanity check that The pickled and the original dict are the same print("Sanity check ...", end="") # load pickle - with open(f"{filepath}", "rb") as f: + with open(filepath, "rb") as f: pickled = pickle.load(f) # compare assert labelled_token_ids_dict == pickled