torchchat.py (forked from pytorch/torchchat)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import subprocess
import sys

# MPS ops missing with Multimodal torchtune
# https://github.com/pytorch/torchtune/issues/1723
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from torchchat.cli.cli import (
    add_arguments_for_verb,
    arg_init,
    check_args,
    INVENTORY_VERBS,
    KNOWN_VERBS,
)

default_device = "cpu"


if __name__ == "__main__":
    # Initialize the top-level parser
    parser = argparse.ArgumentParser(
        prog="torchchat",
        add_help=True,
    )

    subparsers = parser.add_subparsers(
        dest="command",
        help="The specific command to run",
    )
    subparsers.required = True

    VERB_HELP = {
        "chat": "Chat interactively with a model via the CLI",
        "generate": "Generate responses from a model given a prompt",
        "browser": "Chat interactively with a model in a locally hosted browser",
        "export": "Export a model artifact to AOT Inductor or ExecuTorch",
        "download": "Download model artifacts",
        "list": "List all supported models",
        "remove": "Remove downloaded model artifacts",
        "where": "Return directory containing downloaded model artifacts",
        "server": "[WIP] Starts a locally hosted REST server for model interaction",
        "eval": "Evaluate a model via lm-eval",
    }
    for verb, description in VERB_HELP.items():
        subparser = subparsers.add_parser(verb, help=description)
        add_arguments_for_verb(subparser, verb)

    # Now parse the arguments
    args = parser.parse_args()

    # Don't initialize for Inventory management subcommands
    # TODO: Remove when arg_init is refactored
    if args.command not in INVENTORY_VERBS:
        args = arg_init(args)

    logging.basicConfig(
        format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO
    )

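    # Dispatch on the parsed verb. Each branch imports its implementation
    # lazily, so lightweight inventory commands (list, where, remove) never
    # pull in the generation or export stacks.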
    if args.command == "chat":
        # enable "chat"
        args.chat = True
        check_args(args, "chat")

        from torchchat.generate import main as generate_main

        generate_main(args)
    elif args.command == "browser":
        print(
            "\nTo test out the browser please use: streamlit run torchchat/usages/browser.py <args>\n"
        )
    elif args.command == "server":
        check_args(args, "server")

        from torchchat.usages.server import main as server_main

        server_main(args)
    elif args.command == "generate":
        check_args(args, "generate")

        from torchchat.generate import main as generate_main

        generate_main(args)
    elif args.command == "eval":
        from torchchat.usages.eval import main as eval_main

        eval_main(args)
    elif args.command == "export":
        check_args(args, "export")

        from torchchat.export import main as export_main

        export_main(args)
    elif args.command == "download":
        check_args(args, "download")

        from torchchat.cli.download import download_main

        download_main(args)
    elif args.command == "list":
        check_args(args, "list")

        from torchchat.cli.download import list_main

        list_main(args)
    elif args.command == "where":
        check_args(args, "where")

        from torchchat.cli.download import where_main

        where_main(args)
    elif args.command == "remove":
        check_args(args, "remove")

        from torchchat.cli.download import remove_main

        remove_main(args)
    else:
        parser.print_help()
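
# Example invocations (a sketch, not part of the upstream file): the verb names
# come from VERB_HELP above; the "llama3.1" model alias is an assumption -- run
# `python3 torchchat.py list` to see the aliases supported by your checkout.
#
#   python3 torchchat.py list
#   python3 torchchat.py download llama3.1
#   python3 torchchat.py chat llama3.1
#   python3 torchchat.py generate llama3.1 --prompt "Tell me a short story"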