Skip to content

Commit

Permalink
corrected version after comments
Browse files Browse the repository at this point in the history
  • Loading branch information
transcendingvictor committed Feb 17, 2024
1 parent 25b2228 commit 77033f8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 62 deletions.
10 changes: 5 additions & 5 deletions scripts/generate_logprobs.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

# Define the dataset split
DATASET_SPLIT="validation" # Change this to your desired dataset split

# Define the batch size (forwarded to the Python script as --batch_size)
BATCH_SIZE=80 # This worked well in my CPU, but 200 was too much

# Tokenized dataset to run the models on
DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized"

# Hugging Face API token. Read from the environment so that a real token is
# never committed with this script; abort early with a clear message if it is
# missing instead of failing later, after inference has already run.
TOKEN="${HF_TOKEN:?Set HF_TOKEN to your Hugging Face API token}"


# List of models
declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
Expand All @@ -21,7 +21,7 @@ declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k"
# Run the inference script once per model, forwarding the batch size,
# dataset name and Hugging Face token defined at the top of this script.
for MODEL_NAME in "${MODEL_NAMES[@]}"
do
    echo "Processing $MODEL_NAME"
    python scripts/inference.py "$MODEL_NAME" \
        --batch_size "$BATCH_SIZE" \
        --dataset_name "$DATASET_NAME" \
        --token "$TOKEN"
done

echo "All models processed."
86 changes: 29 additions & 57 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,67 +3,34 @@

import pandas as pd
import torch
from datasets import load_dataset
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM

torch.set_grad_enabled(False)


def get_correct_logprobs(model, samples_tok):
    """Return, per sample, the log-probability the model gave each actual next token.

    Args:
        model: Callable taking ``samples_tok`` and returning an object with a
            ``.logits`` attribute of shape [batch_size, seq_length, vocab_size].
        samples_tok: 2D integer tensor of token ids, [batch_size, seq_length].

    Returns:
        A list with one 1-D tensor per sample, each of length
        ``seq_length - 1``: entry ``p`` is the log-probability assigned at
        position ``p`` to the token found at position ``p + 1``.
    """
    # logits: [batch_size, seq_length, vocab_size]
    logits = model(samples_tok).logits
    # logprobs: [batch_size, seq_length, vocab_size]
    logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

    # The last position has no next token, so it is dropped from the
    # predictions; the first token is never a prediction target.
    next_tokens = samples_tok[:, 1:].long()
    correct_logprobs = (
        logprobs[:, :-1, :].gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)
    )

    # Keep the original return shape: a list of per-sample 1-D tensors.
    return list(correct_logprobs.unbind(0))
from delphi.eval.utils import get_all_and_next_logprobs, load_validation_dataset

# outputs a list of lists of correct token LOG probabilities.
# correct_logprobs = get_correct_logprobs(model, val_sequences[:10])
torch.set_grad_enabled(False)


def main(model_name, batch_size, dataset_name, token):
    """Compute next-token log-probs on the validation set and push them to the Hub.

    Args:
        model_name: Model id passed to ``AutoModelForCausalLM.from_pretrained``.
        batch_size: Number of sequences sent through the model per call.
        dataset_name: Dataset id passed to ``load_validation_dataset``.
        token: Hugging Face API token used by ``push_to_hub``.
    """
    val_ds = load_validation_dataset(dataset_name)

    # model accepts 2D tensors (batch_size, seq_len)
    val_sequences = torch.tensor([s["tokens"] for s in val_ds])

    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Feed the model one batch at a time so memory use is bounded by
    # batch_size instead of by the size of the whole validation set.
    per_batch_next_logprobs = []
    for start in tqdm(range(0, len(val_sequences), batch_size)):
        batch_sequences = val_sequences[start : start + batch_size]
        # Only the log-probs of the actual next tokens are uploaded.
        # NOTE(review): assumes the returned tensor is batch-first — confirm
        # against delphi.eval.utils.get_all_and_next_logprobs.
        _, batch_next_logprobs = get_all_and_next_logprobs(model, batch_sequences)
        per_batch_next_logprobs.append(batch_next_logprobs)
    next_logprobs = torch.cat(per_batch_next_logprobs, dim=0)

    df_dataset = pd.DataFrame({"logprobs": next_logprobs.tolist()})
    hf_dataset = Dataset.from_pandas(df_dataset)

    # change the repo_id to your hf username
    # the token is supplied by generate_logprobs.sh through --token
    hf_dataset.push_to_hub(
        repo_id=f"transcendingvictor/{model_name.rsplit('/', 1)[-1]}-validation-logprobs",
        split="validation",
        private=False,
        token=token,
    )


if __name__ == "__main__":
Expand All @@ -73,20 +40,25 @@ def main(model_name, dataset_split, batch_size):
parser.add_argument(
"model_name", type=str, help="Model name with or without delphi-suite/ prefix"
)
parser.add_argument(
"dataset_split", type=str, help="Dataset split (e.g., train, validation, test)"
)
parser.add_argument(
"--batch_size",
type=int,
default=80,
help="Batch size for processing (default: 80)",
)

parser.add_argument(
"--dataset_name",
type=str,
help="Dataset name with or without delphi-suite/ prefix",
)
parser.add_argument(
"--token",
type=str,
help="Hugging Face API token",
)
args = parser.parse_args()

# Default prefix handling
if "/" not in args.model_name:
args.model_name = "delphi-suite/" + args.model_name

main(args.model_name, args.dataset_split, args.batch_size)
main(args.model_name, args.batch_size, args.dataset_name, args.token)

0 comments on commit 77033f8

Please sign in to comment.