-
Notifications
You must be signed in to change notification settings - Fork 0
/
backend_api.py
72 lines (65 loc) · 2.44 KB
/
backend_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import argparse
from mosec import Server
from collections import OrderedDict
from sniffer_model import (SnifferGPTNeoModel,
SnifferGPT2Model,
SnifferGPTJModel,
SnifferWenZhongModel,
SnifferSkyWorkModel,
SnifferDaMoModel,
SnifferLlamaModel,
SnifferChatGLMModel,
SnifferAlpacaModel,
SnifferDollyModel,
SnifferStableLMRawModel,
SnifferStableLMTunedModel)
from backend_t5 import T5
# Registry of sniffer workers, keyed by the value accepted on the ``--model``
# command-line option. The entry point looks the chosen key up here and
# registers the resulting class as a mosec worker ("t5" is handled separately
# by the entry point and is intentionally not listed). Keyword arguments keep
# their insertion order, so iteration order matches the order written below.
MODEL_MAPPING_NAMES = OrderedDict(
    gpt2=SnifferGPT2Model,
    gptneo=SnifferGPTNeoModel,
    gptj=SnifferGPTJModel,
    llama=SnifferLlamaModel,
    wenzhong=SnifferWenZhongModel,
    skywork=SnifferSkyWorkModel,
    damo=SnifferDaMoModel,
    chatglm=SnifferChatGLMModel,
    alpaca=SnifferAlpacaModel,
    dolly=SnifferDollyModel,
    stablelm_raw=SnifferStableLMRawModel,
    stablelm_tuned=SnifferStableLMTunedModel,
)
def parse_args(argv=None):
    """Parse the command-line options for this backend server.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` makes argparse read the
            real command line, which is what the original behavior did.

    Returns:
        argparse.Namespace with the attributes ``model``, ``gpu``, ``port``,
        ``timeout`` and ``debug``.

    Exits (via argparse) when ``--model`` is missing or names a model that
    this server cannot serve.
    """
    # Valid models are every key in the sniffer registry plus "t5", which the
    # entry point serves with its dedicated T5 worker. Deriving the list keeps
    # the help text and validation in sync with the registry.
    model_choices = [*MODEL_MAPPING_NAMES, "t5"]

    parser = argparse.ArgumentParser()
    # --model is mandatory, so no default is given (one would never be used).
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=model_choices,
        help=f"The model to use. One of: {', '.join(model_choices)}.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        required=False,
        default="0",
        help="Set os.environ['CUDA_VISIBLE_DEVICES'].",
    )
    # The options below are not used by this script; they are declared so
    # argparse accepts them on the shared command line.
    # NOTE(review): presumably mosec's Server reads them itself — confirm.
    parser.add_argument("--port", help="mosec args.")
    parser.add_argument("--timeout", help="mosec args.")
    parser.add_argument("--debug", action="store_true", help="mosec args.")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Example:
    #   python backend_api.py --port 6006 --timeout 30000 --debug --model=damo --gpu=3
    # --model accepts any key of MODEL_MAPPING_NAMES, or "t5".
    args = parse_args()

    # Forces synchronous CUDA kernel launches so errors are reported at the
    # failing call. NOTE(review): this is a debugging aid that can slow
    # inference; consider enabling it only with --debug.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # Restrict visible GPUs before the server (and, presumably, the worker
    # processes it starts) initialise CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # "t5" uses its dedicated worker; every other model comes from the
    # sniffer registry. Resolve the worker first so the server setup below
    # is written once.
    if args.model == 't5':
        worker = T5
    else:
        try:
            worker = MODEL_MAPPING_NAMES[args.model]
        except KeyError:
            valid = ', '.join([*MODEL_MAPPING_NAMES, 't5'])
            raise SystemExit(
                f"Unknown model {args.model!r}; choose one of: {valid}"
            ) from None

    server = Server()
    server.append_worker(worker)
    server.run()