-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathFinetune_SFTTrainer_OneChatbotGPT2Vi.py
126 lines (95 loc) · 3.7 KB
/
Finetune_SFTTrainer_OneChatbotGPT2Vi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# -*- coding: utf-8 -*-
# Author: Mr.Jack _ www.BICweb.vn
# Date: 25 May 2024
# https://huggingface.co/docs/trl/en/sft_trainer
import os, torch
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
# Disable the Hugging Face `tokenizers` library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "False"
# device = torch.device('mps')
# Prefer CUDA when present. NOTE(review): TrainingArguments below sets use_cpu=True,
# so training itself runs on the CPU regardless of this choice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face model id (or local checkpoint directory) to fine-tune.
MODEL_NAME = 'OneChatbotGPT2Vi'
# MODEL_NAME = './test_trainer'
# MODEL_NAME = './test_trainer/checkpoint-10'
print("MODEL_NAME:",MODEL_NAME)

# Step 1: Load the pretrained causal language model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Step 2: (optional) custom optimizer and learning-rate scheduler; the Trainer's
# defaults are used while these stay commented out.
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

RANDOM_SEED = 42  # 3407
# Turn off the key/value cache while training.
model.config.use_cache = False

# Make sure the tokenizer can pad batches.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # A newly added token gets an id at the end of the vocabulary, which is
    # typically outside the pretrained embedding matrix. Grow the embeddings to
    # match the tokenizer so the pad id is a valid index.
    model.resize_token_embeddings(len(tokenizer))
    print("Add new pad_token: [PAD]")

# The single training example, in the "Question: ... Answer: ..." prompt format.
# text = 'Question: Xin chào\n Answer: Công ty BICweb kính chào quý khách!.'
text = "Question: Xin chào Answer: Dạ, em chào anh ạ!."
print("text:",text)
# data = [{"text": text}]
# Pre-tokenize to a flat list of ids so that one dataset row is one 1-D sequence.
# With return_tensors='pt', encode() returns a (1, seq_len) batch, which nests
# every row one level too deep.
data = [{"input_ids": tokenizer.encode(text=text, add_special_tokens=True)}]
from datasets import Dataset

# Wrap the pre-tokenized example(s) in a Hugging Face Dataset whose columns are
# returned as torch tensors.
dataset = Dataset.from_list(data)
# print(dataset)
# print(dataset[0])
dataset.set_format("torch")

# Completion-only loss (train only on the answer part); currently disabled.
# instruction_template = "Question:"
# response_template = " Answer:"
# collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer, mlm=False) # instruction_template=instruction_template,

EPOCHS = 15
LEARNING_RATE = 3e-4
OUTPUT_DIR = "test_trainer"

# All TrainingArguments options, gathered in one mapping.
training_options = {
    # Output location; an existing directory is overwritten.
    "output_dir": OUTPUT_DIR,
    "overwrite_output_dir": True,
    # Optimisation schedule.
    "num_train_epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "warmup_steps": 0,
    "weight_decay": 0.01,
    "seed": RANDOM_SEED,
    # Log after every optimizer step.
    "logging_steps": 1,
    # Write a checkpoint every EPOCHS optimizer steps and keep only the newest one.
    # (save_steps is a step count, or a ratio of total steps when < 1.)
    "save_strategy": 'steps',
    "save_steps": EPOCHS,
    "save_total_limit": 1,
    # Train on the CPU even when a CUDA/MPS device is available.
    "use_cpu": True,
    # Other options that have been tried:
    # "max_grad_norm": 9.9,
    # "resume_from_checkpoint": 'checkpoint-{0}'.format(EPOCHS),
    # "save_only_model": True,
    # "no_cuda": True,
    # "use_mps_device": True,
}
args_config = TrainingArguments(**training_options)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args_config,
    train_dataset=dataset,
    # The rows already hold token ids, so the "text" field names the id column.
    dataset_text_field="input_ids",
    max_seq_length=256,
    # optimizers=(optimizer, scheduler),
    # data_collator=collator,
)
trainer.train()
# trainer.save_model(OUTPUT_DIR)
# trainer.model.save_pretrained(OUTPUT_DIR)
# trainer.tokenizer.save_pretrained(OUTPUT_DIR)

# Put the model in evaluation mode before generating.
model.eval()
def generate_answer(question):
    """Sample a reply from the fine-tuned model and return its first sentence.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        question: Prompt text, e.g. ``'Question: Xin chào'``. It is encoded
            without adding special tokens.

    Returns:
        The decoded generation (prompt plus continuation, special tokens
        skipped), cut at the first ``'.'``.
    """
    # Send the prompt to the device the model actually lives on. The Trainer runs
    # with use_cpu=True and places the model on the CPU, so after training the
    # global `device` (CUDA when available) no longer matches the model.
    input_ids = tokenizer.encode(
        question, add_special_tokens=False, return_tensors='pt'
    ).to(model.device)

    # Sample at most 256 tokens in total (the prompt included).
    # NOTE(review): pad_token_id=2 and eos_token_id=50256 are hard-coded; confirm
    # they match this checkpoint's tokenizer.
    sample_output = model.generate(
        input_ids,
        pad_token_id=2,
        eos_token_id=50256,
        max_length=256,
        do_sample=True,
        top_k=100,
        top_p=1.0,
        temperature=0.6,
    )

    # Decode the first (only) sequence and keep the text before the first '.'.
    answer = tokenizer.decode(sample_output[0], skip_special_tokens=True)
    return answer.split('.')[0]
# Example usage: ask the freshly fine-tuned model the greeting it was trained on.
question = 'Question: Xin chào'
response = generate_answer(question)
print("\n" + response + "\n")

# Optional memory cleanup once finished (uncomment as needed):
# import gc
# torch.cuda.empty_cache()
# torch.mps.empty_cache()
# gc.collect()