Skip to content

Commit

Permalink
add command line arguments save_dir and output_format
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Feb 22, 2024
1 parent d970d8f commit 57606bd
Showing 1 changed file with 63 additions and 42 deletions.
105 changes: 63 additions & 42 deletions scripts/label_all_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,28 @@ def main():
default="delphi-suite/delphi-llama2-100k",
required=False,
)
parser.add_argument(
"--save_dir", type=str, help="Directory to save the results.", required=True
)
parser.add_argument(
"--output_format",
type=str,
help="Format to save the results in. Options: csv, pkl. Default: csv.",
default="csv",
required=False,
)
args = parser.parse_args()

# Access command-line arguments
# Directory to save the results
SAVE_DIR = Path("src/delphi/eval/")
SAVE_DIR = Path(args.save_dir)
SAVE_DIR.mkdir(parents=True, exist_ok=True) # create directory if it does not exist
model_name = args.model_name
output_format = args.output_format
assert output_format in [
"csv",
"pkl",
], f"Invalid output format. Allowed: csv, pkl. Got: {output_format}"

print("\n", " LABEL ALL TOKENS ".center(50, "="), "\n")
print(f"You chose the model: {model_name}\n")
Expand All @@ -61,7 +77,8 @@ def main():

# Save the list of all tokens to a file
filename = "all_tokens_list.txt"
filepath = SAVE_DIR / filename
# filepath = SAVE_DIR / filename # TODO: use the static files of python module
filepath = Path("src/delphi/eval/") / filename
with open(filepath, "w", encoding="utf-8") as f:
f.write(tokens_str)

Expand All @@ -87,50 +104,54 @@ def main():
# update the labelled_token_ids_dict with the new dict
labelled_token_ids_dict[token_id] = labels[0][0]

# Save the labelled tokens to a file
filename = "labelled_token_ids_dict.pkl"
filepath = SAVE_DIR / filename
with open(filepath, "wb") as f:
pickle.dump(labelled_token_ids_dict, f)
if output_format == "pkl":
# Save the labelled tokens to a file
filename = "labelled_token_ids_dict.pkl"
filepath = SAVE_DIR / filename
with open(filepath, "wb") as f:
pickle.dump(labelled_token_ids_dict, f)

print(f"Saved the labelled tokens to:\n\t{filepath}\n")
print(f"Saved the labelled tokens to:\n\t{filepath}\n")

# sanity check that The pickled and the original dict are the same
print("Sanity check ...", end="")
# load pickle
with open(filepath, "rb") as f:
pickled = pickle.load(f)
# compare
assert labelled_token_ids_dict == pickled
print(" completed.")
# sanity check that The pickled and the original dict are the same
print("Sanity check ...", end="")
# load pickle
with open(filepath, "rb") as f:
pickled = pickle.load(f)
# compare
assert labelled_token_ids_dict == pickled
print(" completed.")

# ----------- CSV ------------------------
print("\nCreating the CSV ...")
# Create a pandas dataframe / CSV from the label dict
df = pd.DataFrame(labelled_token_ids_dict.items(), columns=["token_id", "label"])
# split the label column into multiple columns
df = df.join(pd.DataFrame(df.pop("label").tolist()))
# Change datatype of columns to float
df = df.astype(int)

print("Sanity check pandas csv ...", end="")
# Perform sanity check, that the table was created correctly
for row_index, row_values in df.iterrows():
token_id = row_values.iloc[0]
label_pandas = list(
row_values.iloc[1:]
) # we exclude the token_id from the colum
label_dict = list(labelled_token_ids_dict[token_id].values())[:]
assert (
label_pandas == label_dict
), f"The dataframes are not equal for row {token_id}\n{label_pandas}\n{label_dict}"
print(" completed.")

# save the dataframe to a csv
filename = "labelled_token_ids_df.csv"
filepath = SAVE_DIR / filename
df.to_csv(filepath, index=False)
print(f"Saved the labelled tokens as CSV to:\n\t{filepath}\n")
if output_format == "csv":
print("\nCreating the CSV ...")
# Create a pandas dataframe / CSV from the label dict
df = pd.DataFrame(
labelled_token_ids_dict.items(), columns=["token_id", "label"]
)
# split the label column into multiple columns
df = df.join(pd.DataFrame(df.pop("label").tolist()))
# Change datatype of columns to float
df = df.astype(int)

print("Sanity check pandas csv ...", end="")
# Perform sanity check, that the table was created correctly
for row_index, row_values in df.iterrows():
token_id = row_values.iloc[0]
label_pandas = list(
row_values.iloc[1:]
) # we exclude the token_id from the colum
label_dict = list(labelled_token_ids_dict[token_id].values())[:]
assert (
label_pandas == label_dict
), f"The dataframes are not equal for row {token_id}\n{label_pandas}\n{label_dict}"
print(" completed.")

# save the dataframe to a csv
filename = "labelled_token_ids_df.csv"
filepath = SAVE_DIR / filename
df.to_csv(filepath, index=False)
print(f"Saved the labelled tokens as CSV to:\n\t{filepath}\n")

print(" END ".center(50, "="))

Expand Down

0 comments on commit 57606bd

Please sign in to comment.