From 4797c0cb8dcf52064dbf7860364bccd8e154e4ab Mon Sep 17 00:00:00 2001 From: Joshua Wendland <80349780+joshuawe@users.noreply.github.com> Date: Tue, 5 Mar 2024 18:24:53 +0100 Subject: [PATCH] remove the save as pickle option for token label file --- scripts/label_all_tokens.py | 76 +++++++++--------------------- src/delphi/eval/token_labelling.py | 33 +++++-------- tests/eval/test_token_labelling.py | 14 ++---- 3 files changed, 38 insertions(+), 85 deletions(-) diff --git a/scripts/label_all_tokens.py b/scripts/label_all_tokens.py index f4655e80..1d340a0f 100644 --- a/scripts/label_all_tokens.py +++ b/scripts/label_all_tokens.py @@ -37,13 +37,6 @@ def main(): parser.add_argument( "--save-dir", type=str, help="Directory to save the results.", required=True ) - parser.add_argument( - "--output-format", - type=str, - help="Format to save the results in. Options: csv, pkl. Default: csv.", - default="csv", - required=False, - ) args = parser.parse_args() # Access command-line arguments @@ -51,11 +44,6 @@ def main(): save_dir = Path(args.save_dir) save_dir.mkdir(parents=True, exist_ok=True) # create directory if it does not exist model_name = args.model_name - output_format = args.output_format - assert output_format in [ - "csv", - "pkl", - ], f"Invalid output format. Allowed: csv, pkl. 
Got: {output_format}" print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n") print(f"You chose the model: {model_name}\n") @@ -85,48 +73,28 @@ def main(): # ================ (2) ================= print("(2) Label each token ...") - if output_format == "pkl": - # Save the labelled tokens to a file - filename = "labelled_token_ids.pkl" - filepath = save_dir / filename - with open(filepath, "wb") as f: - pickle.dump(labelled_token_ids_dict, f) - - print(f"Saved the labelled tokens to:\n\t{filepath}\n") - - # sanity check that The pickled and the original dict are the same - print("Sanity check ...", end="") - # load pickle - with open(filepath, "rb") as f: - pickled = pickle.load(f) - # compare - assert labelled_token_ids_dict == pickled - print(" completed.") - - # ----------- CSV ------------------------ - if output_format == "csv": - print("\nCreating the CSV ...") - - df = token_labelling.convert_label_dict_to_df(labelled_token_ids_dict) - - print("Sanity check pandas csv ...", end="") - # Perform sanity check, that the table was created correctly - for row_index, row_values in df.iterrows(): - token_id = row_values.iloc[0] - label_pandas = list( - row_values.iloc[1:] - ) # we exclude the token_id from the colum - label_dict = list(labelled_token_ids_dict[token_id].values())[:] - assert ( - label_pandas == label_dict - ), f"The dataframes are not equal for row {token_id}\n{label_pandas}\n{label_dict}" - print(" completed.") - - # save the dataframe to a csv - filename = "labelled_token_ids.csv" - filepath = save_dir / filename - df.to_csv(filepath, index=False) - print(f"Saved the labelled tokens as CSV to:\n\t{filepath}\n") + print("\nCreating the CSV ...") + + df = token_labelling.convert_label_dict_to_df(labelled_token_ids_dict) + + print("Sanity check pandas csv ...", end="") + # Perform sanity check, that the table was created correctly + for row_index, row_values in df.iterrows(): + token_id = row_values.iloc[0] + label_pandas = list( + row_values.iloc[1:] 
+ ) # we exclude the token_id from the column + label_dict = list(labelled_token_ids_dict[token_id].values())[:] + assert ( + label_pandas == label_dict + ), f"The dataframes are not equal for row {token_id}\n{label_pandas}\n{label_dict}" + print(" completed.") + + # save the dataframe to a csv + filename = "labelled_token_ids.csv" + filepath = save_dir / filename + df.to_csv(filepath, index=False) + print(f"Saved the labelled tokens as CSV to:\n\t{filepath}\n") print(" END ".center(50, "=")) diff --git a/src/delphi/eval/token_labelling.py b/src/delphi/eval/token_labelling.py index cf7ffa26..40cb9f4a 100644 --- a/src/delphi/eval/token_labelling.py +++ b/src/delphi/eval/token_labelling.py @@ -265,7 +265,7 @@ def decode( def import_token_labels(path: str | Path): """ - Imports token labels from a file. May be a .pkl or a .csv + Imports token labels from a *.csv file. Parameters ---------- @@ -281,28 +281,21 @@ def import_token_labels(path: str | Path): path = Path(path) # make sure the file_type is compatible file_type = path.suffix - assert file_type in [ - ".csv", - ".pkl", - ], f"Invalid file type. Allowed: csv, pkl. Got: {file_type}" + assert ( + file_type == ".csv" + ), f"Invalid file type. Allowed: csv. 
Got: {file_type}" # make sure file exists if not path.exists(): raise FileNotFoundError(f"There is no file under {path}") - # load the file if CSV - if file_type == ".csv": - df = pd.read_csv(str(path)) - categories = list(df.columns[1:]) # excluding first column: token_id - loaded_label_dict: dict[int, dict[str, bool]] = {} - # go through each row and construct the dict - for _, row in df.iterrows(): - token_id = int(row["token_id"]) - labels = {cat: bool(row[cat] == 1) for cat in categories} - loaded_label_dict[token_id] = labels - - # load the file if a pickle - elif file_type == ".pkl": - with open(path, "rb") as f: - loaded_label_dict = pickle.load(f) + + df = pd.read_csv(str(path)) + categories = list(df.columns[1:]) # excluding first column: token_id + loaded_label_dict: dict[int, dict[str, bool]] = {} + # go through each row and construct the dict + for _, row in df.iterrows(): + token_id = int(row["token_id"]) + labels = {cat: bool(row[cat] == 1) for cat in categories} + loaded_label_dict[token_id] = labels return loaded_label_dict diff --git a/tests/eval/test_token_labelling.py b/tests/eval/test_token_labelling.py index 3c9f4e3e..e566057b 100644 --- a/tests/eval/test_token_labelling.py +++ b/tests/eval/test_token_labelling.py @@ -160,9 +160,7 @@ def test_label_tokens_from_tokenizer(): @pytest.mark.dependency(depends=["test_label_tokens_from_tokenizer"]) -@pytest.mark.parametrize( - "path", [Path("temp/token_labels.csv"), Path("temp/token_labels.pkl")] -) +@pytest.mark.parametrize("path", [Path("temp/token_labels.csv")]) def test_import_token_labels(path: Path): global labelled_token_ids_dict assert ( @@ -174,14 +172,8 @@ def test_import_token_labels(path: Path): # create the path path.parent.mkdir(parents=True, exist_ok=True) # save the file - if path.suffix == ".pkl": - with open(path, "wb") as file: - pickle.dump(labelled_token_ids_dict, file) - elif path.suffix == ".csv": - df = tl.convert_label_dict_to_df(labelled_token_ids_dict) - df.to_csv(path, 
index=False) - else: - raise ValueError("The file ending is incorrect.") + df = tl.convert_label_dict_to_df(labelled_token_ids_dict) + df.to_csv(path, index=False) # load the file with our function to be tested loaded_dict = tl.import_token_labels(path)