eval_passkey.py
import os
import datasets
import time
import torch
from datetime import timedelta
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from accelerate import Accelerator, InitProcessGroupKwargs
from transformers import HfArgumentParser
from torch.utils.data import DataLoader
from src import ModelArgs, DatasetProcessFn, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, split_file_dir_name_ext, evaluate_perplexity
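
# Command-line arguments for the pass-key evaluation; ModelArgs (defined in src) is assumed
# to carry the shared model/tokenizer options such as the model path and maximum context length.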
@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="activation-beacon:lm/pg19.json",
        metadata={'help': 'Path to the evaluation json data.'}
    )
    output_dir: str = field(
        default="data/results/lm/",
        metadata={'help': 'Output directory for results and logs.'}
    )
    retokenize: bool = field(
        default=False,
        metadata={'help': 'Whether to re-tokenize the corpus.'}
    )
    tokenize_max_char: Optional[int] = field(
        default=None,
        metadata={'help': 'Truncate the raw text to this many characters before tokenization.'}
    )
    batch_size: int = field(
        default=1,
        metadata={'help': 'Evaluation batch size.'}
    )
    padding_side: str = field(
        default="right",
        metadata={'help': 'Which side to pad on.'}
    )
    stride: int = field(
        default=2048,
        metadata={'help': 'Streaming stride when evaluating perplexity.'}
    )
    max_sample_num: int = field(
        default=100,
        metadata={'help': 'How many samples from eval_data to evaluate.'}
    )
    min_length: Optional[int] = field(
        default=None,
        metadata={'help': 'Minimum length for input_ids.'}
    )
parser = HfArgumentParser([Args])
args: Args = parser.parse_args_into_dataclasses()[0]
# increase the init process-group timeout so long-running evaluations do not time out
accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])
model, tokenizer = get_model_and_tokenizer(args, accelerator=accelerator)
device = torch.device("cuda")
model.eval()
from numpy import random
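
# Prompt construction: a random pass key (1-50000) is hidden at a seed-determined position
# inside n_garbage characters of repeated filler sentences, followed by the retrieval question.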
def generate_prompt_landmark(n_garbage, seed):
    """Generate a prompt of filler text with a pass key inserted at a random position."""
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 50000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    lines = [
        garbage_prefix,
        information_line,
        garbage_suffix,
        task_description,
        final_question,
    ]
    # restore the global numpy random state so other code is unaffected
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)
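
# Single retrieval test: tokenize the prompt, greedily generate exactly as many new tokens
# as the reference answer has, then compare the decoded strings.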
def passkey_retrieval_test(n_garbage=60000, seed=555):
    # n_garbage=60000 results in ~16k tokens
    prompt, answer = generate_prompt_landmark(n_garbage, seed)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)
    print(f"Prompt has {input_ids.shape[-1]} tokens")

    answer_ids = tokenizer(answer, return_tensors="pt").input_ids[:, 1:]  # drop BOS

    # reset the model's memory (if any) between samples so context does not leak across tests
    if hasattr(model, "memory") and model.memory is not None:
        model.memory.reset()

    outputs = model.generate(
        input_ids,
        max_new_tokens=answer_ids.shape[-1],
        num_beams=1,
        do_sample=False,
    )
    # compare decoded strings (rather than raw token ids) to tolerate tokenization differences
    model_answer = outputs[0, input_ids.shape[1]:].cpu()
    is_correct = tokenizer.decode(answer_ids[0].cpu()) == tokenizer.decode(model_answer)

    print(f"The correct answer is {tokenizer.decode(answer_ids[0].cpu())}", flush=True)
    print(f"The model answer is: {tokenizer.decode(model_answer)}, is_correct: {is_correct}", flush=True)
    return is_correct
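
# Sweep over filler lengths (in characters); 60k characters is roughly 16k tokens, so the
# largest setting (1.5M characters) corresponds to roughly 400k tokens. Each length is run
# with 10 different seeds and the accuracy over those runs is reported.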
num_tests = 10
lengths = [15000, 30000, 60000, 375000, 960000, 1500000]
for length in lengths:
    passed_tests = 0
    for i in range(num_tests):
        passed_tests += passkey_retrieval_test(n_garbage=length, seed=i)
    print(f"Accuracy is {passed_tests / num_tests}", flush=True)