Inference.py
import json
import time
import logging
import csv
import gzip
from geobleu.Report import report_geobleu_dtw_gpt
from unsloth import FastLanguageModel
from datasets import Dataset
from unsloth.chat_templates import get_chat_template
import torch
import wandb # Import wandb
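# Note (an assumption inferred from how it is called below, not from package docs):
# report_geobleu_dtw_gpt is expected to take a predicted and a reference trajectory
# (lists of [d, t, x, y] records) and return the GEO-BLEU and DTW scores for the pair.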
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
# Initialize wandb
wandb.init(project='Inference')  # Set your own project/run names; pass mode='offline' to run without syncing
# Initialize model and tokenizer
max_seq_length = 50000 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
logging.info("Initializing model and tokenizer...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="tangera/Llama3-8B-Mob",  # The fine-tuned model used for training
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
# Set up the tokenizer's chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},
)
logging.info("Model and tokenizer initialized successfully.")
def load_custom_dataset(file_path):
    logging.info(f"Loading dataset from {file_path}...")
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Create a dataset structure suitable for the datasets library
    formatted_data = {
        "conversations": [],
        "uids": []
    }
    for convo in data:
        formatted_data["conversations"].append(convo["messages"])
        formatted_data["uids"].append(convo["uid"])
    # Convert the dataset into the Hugging Face datasets format
    dataset = Dataset.from_dict(formatted_data)
    logging.info(f"Dataset loaded with {len(dataset)} conversations.")
    return dataset
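# A minimal sketch of the input layout this loader assumes, inferred from the keys
# accessed above (not taken from any dataset documentation): a JSON list in which each
# element carries a "uid" and a "messages" list of {"role": ..., "content": ...} turns, e.g.
#
# [
#   {
#     "uid": 1,
#     "messages": [
#       {"role": "user", "content": "..."},
#       {"role": "assistant", "content": "..."}
#     ]
#   }
# ]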
def format_conversations(conversation):
    formatted_convo = []
    for message in conversation:
        formatted_message = {
            "from": message["role"],
            "value": message["content"]
        }
        formatted_convo.append(formatted_message)
    return formatted_convo
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = []
    for convo in convos:
        formatted_convo = format_conversations(convo)
        try:
            text = tokenizer.apply_chat_template(formatted_convo, tokenize=False, add_generation_prompt=False)
            texts.append(text)
        except Exception as e:
            logging.error(f"Error processing conversation: {convo}")
            raise e
    return {"text": texts}
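# With the "llama-3" chat template configured above, apply_chat_template() wraps each turn
# in <|start_header_id|>{role}<|end_header_id|> ... <|eot_id|> markers; run_inference() below
# relies on the assistant header marker to split the model's reply out of the decoded output.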
def run_inference(l_idx, r_idx, city):
    logging.info(f"Starting inference from index {l_idx} to {r_idx - 1} on dataset {city}")
    # Load the dataset
    test_dataset = load_custom_dataset(city)
    logging.info(f"Loaded dataset {city} with {len(test_dataset)} conversations")
    # Select the subset
    test_dataset = test_dataset.select(range(l_idx, r_idx))
    logging.info(f"Selected subset of dataset from index {l_idx} to {r_idx}")
    # Map formatting function
    test_dataset = test_dataset.map(formatting_prompts_func, batched=True)
    logging.info("Applied formatting prompts function to dataset")
    # Initialize lists to store results
    results = []
    failed = []
    total_conversations = len(test_dataset)
    logging.info(f"Total conversations to process: {total_conversations}")
    # Start a wandb Table to log results
    wandb_table = wandb.Table(columns=["user_id", "status", "assistant_json"])
    # Process each conversation
    for i, conversation in enumerate(test_dataset):
        start_time = time.time()
        user_id = conversation["uids"]
        logging.info(f"Processing conversation {l_idx + i}/{r_idx - 1}")
        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                # Keep every turn except the assistant's reply, which is what we want to generate
                messages = [
                    {"from": message["role"], "value": message["content"]}
                    for message in conversation["conversations"]
                    if message["role"] != 'assistant'
                ]
                inputs = tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,  # Must add for generation
                    return_tensors="pt",
                ).to("cuda")
                # The held-out assistant turn serves as the reference response
                reference_responses = [
                    message["content"]
                    for message in conversation["conversations"]
                    if message["role"] == 'assistant'
                ]
                logging.info(f"Input sequence length: {inputs.size()}")
                # if (inputs.size(1) > 20000):
                #     wandb.log({
                #         "user_id": user_id,
                #         "status": "failed"
                #     })
                #     continue
                outputs = model.generate(input_ids=inputs, max_new_tokens=16400, use_cache=True)
                generated_text = tokenizer.batch_decode(outputs)
                logging.debug(f"Generated text: {generated_text}")
                assistant_json_str = None  # Initialize assistant_json_str for logging
                for generated, reference in zip(generated_text, reference_responses):
                    # Keep only the assistant turn of the decoded output
                    split_text = generated.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
                    # Drop the EOS token and the surrounding ```json ... ``` fence (7 leading, 3 trailing chars)
                    clean_text = split_text.replace(tokenizer.eos_token, "").strip()[7:-3]
                    assistant_json = json.loads(clean_text)
                    reference_json = json.loads(reference.strip()[7:-3])
                    # Check the format by scoring the prediction against the reference
                    geobleu_val, dtw_val = report_geobleu_dtw_gpt(assistant_json['prediction'], reference_json['prediction'])
                    assistant_json_str = json.dumps(assistant_json)  # Convert to string for logging
                    logging.debug(f"Assistant JSON: {assistant_json}")
                    for record in assistant_json['prediction']:
                        d, t, x, y = record  # Unpack the list
                        results.append([user_id, d, t, x, y])
                end_time = time.time()
                elapsed_time = end_time - start_time
                logging.info(f"User {user_id} processed in {elapsed_time:.2f}s")
                wandb.log({"user_id": user_id, "processing_time": elapsed_time})
                # Log success to wandb
                wandb.log({
                    "user_id": user_id,
                    "status": "success"
                })
                # Add to wandb table
                wandb_table.add_data(user_id, "success", assistant_json_str)
                break  # Break out of retry loop if successful
            except Exception as e:
                logging.error(f"Exception in conversation user {user_id}: {e}")
                logging.error(f"Attempt {attempt}/{max_retries} failed")
                if attempt == max_retries:
                    failed.append(user_id)
                    logging.error(f"Failed to process conversation user {user_id} after {max_retries} attempts")
                    # Log failure to wandb
                    wandb.log({
                        "user_id": user_id,
                        "status": "failed"
                    })
                    # Add to wandb table
                    wandb_table.add_data(user_id, "failed", None)
                else:
                    time.sleep(1)  # Wait a bit before retrying
        # Optionally, log progress after each conversation
        wandb.log({"progress": (i + 1) / total_conversations})
logging.info(f"Failed conversations: {failed}")
# Log final metrics to wandb
wandb.summary['failed_conversations'] = failed
# Log the wandb table
wandb.log({"results_table": wandb_table})
# Save results to CSV.GZ file
output_file = f'generated_{city[:-5]}_range({l_idx}, {r_idx}).csv.gz'
logging.info(f"Saving results to {output_file}")
with gzip.open(output_file, 'wt', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# Write header
writer.writerow(['user_id', 'd', 't', 'x', 'y'])
# Write rows
for row in results:
writer.writerow(row)
logging.info(f"Results saved to {output_file}")
# Finish the wandb run
wandb.finish()
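# Usage sketch (an assumption, not part of the original script): the index range and the
# dataset path below are placeholders for an actual evaluation split stored in the JSON
# layout documented above; adjust them before running.
if __name__ == "__main__":
    run_inference(0, 10, "cityD_test.json")  # hypothetical file name and index range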