diff --git a/scripts/generate_logprobs.sh b/scripts/generate_logprobs.sh index 78087e49..14403730 100644 --- a/scripts/generate_logprobs.sh +++ b/scripts/generate_logprobs.sh @@ -1,10 +1,10 @@ #!/bin/bash -# Define the dataset split -DATASET_SPLIT="validation" # Change this to your desired dataset split - # Define the batch size -BATCH_SIZE=80 # Change this if you want to use a different batch size +BATCH_SIZE=80 # This worked well in my CPU, but 200 was too much +DATASET_NAME="delphi-suite/tinystories-v2-clean-tokenized" +TOKEN="hf_aaaaaaaaaaaaaaaaaaaaaaaaa" # your Hugging Face API token + # List of models declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k" @@ -21,7 +21,7 @@ declare -a MODEL_NAMES=("delphi-suite/delphi-llama2-100k" for MODEL_NAME in "${MODEL_NAMES[@]}" do echo "Processing $MODEL_NAME" - python inference.py "$MODEL_NAME" "$DATASET_SPLIT" --batch_size "$BATCH_SIZE" + python scripts/inference_delete.py "$MODEL_NAME" --batch_size "$BATCH_SIZE" --token "$TOKEN" done echo "All models processed." diff --git a/scripts/inference.py b/scripts/inference.py index fe4a42de..8067c258 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -3,67 +3,34 @@ import pandas as pd import torch -from datasets import load_dataset +from datasets import Dataset, load_dataset from tqdm.auto import tqdm from transformers import AutoModelForCausalLM -torch.set_grad_enabled(False) - - -def get_correct_logprobs(model, samples_tok): - # logits: seq, pos, d_vocab - logits = model(samples_tok).logits - # logprobs: [batch_size, seq_length, vocab_size] - logprobs = torch.nn.functional.log_softmax(logits, dim=-1) - - # make probs a list of lists of correct token LOG probabilities. - list_logprob = [] - for i, sample in enumerate(samples_tok): - valid_length = len(sample) - 1 # Last token doesn't have a next token - sample_logprobs = logprobs[i, :valid_length, :] # [valid_length, vocab_size] - - # Extract the probabilities of the actual next tokens - next_tokens = sample[ - 1 : valid_length + 1 - ] # Tokens that follow each token in the sequence - correct_logprobs = sample_logprobs[torch.arange(valid_length), next_tokens] - - list_logprob.append(correct_logprobs) - return list_logprob +from delphi.eval.utils import get_all_and_next_logprobs, load_validation_dataset - # outputs a list of lists of correct token LOG probabilities. - # correct_logprobs = get_correct_logprobs(model, val_sequences[:10]) +torch.set_grad_enabled(False) -def main(model_name, dataset_split, batch_size): - val_ds = load_dataset( - "delphi-suite/tinystories-v2-clean-tokenized", split=dataset_split - ) - # val_ds[0]["tokens"] # access first sample +def main(model_name, batch_size, dataset_name, token): + val_ds = load_validation_dataset(dataset_name) # model accepts 2D tensors (batch_size, seq_len) val_sequences = torch.tensor([s["tokens"] for s in val_ds]) - - output_folder = "Correct_logprobs" - os.makedirs(output_folder, exist_ok=True) - - # Initialize an empty DataFrame to accumulate log probabilities - accumulated_df = pd.DataFrame() - model = AutoModelForCausalLM.from_pretrained(model_name) - - # Loop over the validation dataset in batches - for i in tqdm(range(0, len(val_sequences), batch_size)): - batch_sequences = val_sequences[i : i + batch_size] - batch_logprobs = get_correct_logprobs(model, batch_sequences) - # Convert batch log probabilities to a DataFrame - batch_df = pd.DataFrame([logprob.tolist() for logprob in batch_logprobs]) - # Append the batch DataFrame to the accumulated DataFrame - accumulated_df = pd.concat([accumulated_df, batch_df], ignore_index=True) - - # Save the accumulated DataFrame to a Parquet file - output_file = os.path.join(output_folder, f'{model_name.replace("/", "-")}.parquet') - accumulated_df.to_parquet(output_file) + logprobs, next_logprobs = get_all_and_next_logprobs(model, val_sequences) + + df_dataset = pd.DataFrame({"logprobs": next_logprobs.tolist()}) + hf_dataset = Dataset.from_pandas(df_dataset) + + # change the repo_id to your hf username + # change the token in generate_logprobs.sh + hf_dataset.push_to_hub( + repo_id=f"transcendingvictor/{model_name.rsplit('/', 1)[-1]}-validation-logprobs", + split="validation", + private=False, + token=token, + ) if __name__ == "__main__": @@ -73,20 +40,25 @@ def main(model_name, dataset_split, batch_size): parser.add_argument( "model_name", type=str, help="Model name with or without delphi-suite/ prefix" ) - parser.add_argument( - "dataset_split", type=str, help="Dataset split (e.g., train, validation, test)" - ) parser.add_argument( "--batch_size", type=int, default=80, help="Batch size for processing (default: 80)", ) - + parser.add_argument( + "--dataset_name", + type=str, + help="Dataset name with or without delphi-suite/ prefix", + ) + parser.add_argument( + "--token", + type=str, + help="Hugging Face API token", + ) args = parser.parse_args() - # Default prefix handling if "/" not in args.model_name: args.model_name = "delphi-suite/" + args.model_name - main(args.model_name, args.dataset_split, args.batch_size) + main(args.model_name, args.batch_size, args.dataset_name, args.token)