from typing import List, Optional, Mapping, Any
from functools import partial
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import os
import torch


class ChatGLM3(LLM):
    """LangChain LLM wrapper around a locally loaded ChatGLM3-6B model."""

    model_path: str
    max_length: int = 8192
    temperature: float = 0.1
    top_p: float = 0.7
    history: List = []
    streaming: bool = True
    model: object = None
    tokenizer: object = None

    """
    def __init__(self, model_path: str, max_length: int = 8192, temperature: float = 0.1, top_p: float = 0.7, history: List = None, streaming: bool = True):
        self.model_path = model_path
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p
        self.history = [] if history is None else history
        self.streaming = streaming
        self.model = None
        self.tokenizer = None
    """
    @property
    def _llm_type(self) -> str:
        return "chatglm3-6B"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        add_history: bool = False
    ) -> str:
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Must call `load_model()` to load model and tokenizer!")
        if self.streaming:
            # Stream tokens to stdout as they are generated.
            text_callback = partial(StreamingStdOutCallbackHandler().on_llm_new_token, verbose=True)
            resp = self.generate_resp(prompt, text_callback, add_history=add_history)
        else:
            resp = self.generate_resp(prompt, add_history=add_history)
        return resp
    def generate_resp(self, prompt, text_callback=None, add_history=True):
        resp = ""
        index = 0
        if text_callback:
            # Streaming path: emit only the newly generated suffix on each step.
            for i, (resp, _) in enumerate(self.model.stream_chat(
                self.tokenizer,
                prompt,
                self.history,
                max_length=self.max_length,
                top_p=self.top_p,
                temperature=self.temperature
            )):
                if add_history:
                    if i == 0:
                        self.history += [[prompt, resp]]
                    else:
                        self.history[-1] = [prompt, resp]
                text_callback(resp[index:])
                index = len(resp)
        else:
            # Non-streaming path: return the full response in one call.
            resp, _ = self.model.chat(
                self.tokenizer,
                prompt,
                self.history,
                max_length=self.max_length,
                top_p=self.top_p,
                temperature=self.temperature
            )
            if add_history:
                self.history += [[prompt, resp]]
        return resp
    def load_model(self):
        """Load the tokenizer and the base ChatGLM3 model in fp16 on the GPU."""
        if self.model is not None or self.tokenizer is not None:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda().eval()

    def load_model_from_checkpoint(self, checkpoint=None):
        """Load the base model, wrap it with a LoRA adapter, and optionally restore fine-tuned weights."""
        if self.model is not None or self.tokenizer is not None:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half()
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8,
            target_modules=['query_key_value'],
            lora_alpha=32,
            lora_dropout=0.1,
        )
        self.model = get_peft_model(self.model, peft_config).to("cuda")
        if checkpoint == "text_classification":
            model_dir = "./output/checkpoint-3000/"
            peft_path = "{}/chatglm-lora.pt".format(model_dir)
            if os.path.exists(peft_path):
                self.model.load_state_dict(torch.load(peft_path), strict=False)
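

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how this wrapper might be driven. The model path
# "THUDM/chatglm3-6b" and the prompt below are assumptions for illustration;
# load_model() moves the fp16 weights onto the GPU, so a CUDA device is required.
if __name__ == "__main__":
    llm = ChatGLM3(model_path="THUDM/chatglm3-6b")
    llm.load_model()                    # load tokenizer + fp16 weights onto the GPU
    answer = llm("What is LangChain?")  # streams tokens to stdout since streaming=True
    print(answer)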