From a19a2eccb4b1d4c5cb3869cda9b39f1591cd54d0 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Thu, 22 Feb 2024 14:44:27 +0100
Subject: [PATCH] Add option to force BOS for ppl test

---
 test_inference.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test_inference.py b/test_inference.py
index c445f259..302a60ee 100644
--- a/test_inference.py
+++ b/test_inference.py
@@ -41,6 +41,7 @@
 parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
 parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
 parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit cache")
+parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
 parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
 parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
 parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
@@ -257,6 +258,10 @@
     eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
     eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]
 
+    if args.eval_bos:
+        boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
+        eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)
+
     logprob_sum = 0.0
     logprob_count = 0
 
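
Note (not part of the patch): the new --eval_bos flag prepends the tokenizer's BOS token to every evaluation row and drops the last token of each row, so each row keeps its original --eval_length. A minimal standalone sketch of that tensor operation follows, using a dummy batch and an assumed bos_token_id of 2 in place of a real tokenizer:

    import torch

    bos_token_id = 2                                                       # assumption: stand-in for tokenizer.bos_token_id
    eval_tokens = torch.randint(3, 32000, (4, 2048), dtype = torch.long)   # dummy batch: 4 rows of 2048 token ids

    # Build a column of BOS ids and shift each row right by one, discarding the last token
    bos_col = torch.full((eval_tokens.shape[0], 1), bos_token_id, dtype = torch.long)
    eval_tokens = torch.cat((bos_col, eval_tokens[:, :-1]), dim = 1)

    assert eval_tokens.shape == (4, 2048)                                   # row length is unchanged
    assert (eval_tokens[:, 0] == bos_token_id).all()                        # every row now starts with BOS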