Commit

support lagent chat
LZHgrla committed Oct 9, 2023
1 parent 8299098 commit ee89641
Showing 1 changed file with 196 additions and 158 deletions.
354 changes: 196 additions & 158 deletions xtuner/tools/chat.py
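The gist of the change: chat.py gains a --lagent flag that bypasses the plain transformers generation loop and instead drives a lagent ReAct agent equipped with Google search via serper.dev. Below is a minimal, self-contained sketch of that flow, using only the lagent calls that appear in the diff; the model path is a hypothetical placeholder and SERPER_API_KEY must be set in the environment.

import os

from lagent.actions import ActionExecutor, GoogleSearch
from lagent.agents import ReAct
from lagent.llms import HFTransformerCasualLM  # lagent's historical spelling

# Build the agent the same way the new branch in main() does.
llm = HFTransformerCasualLM('internlm/internlm-chat-7b')  # hypothetical path
search_tool = GoogleSearch(api_key=os.environ['SERPER_API_KEY'])
chatbot = ReAct(
    llm=llm,
    action_executor=ActionExecutor(actions=[search_tool]),
)
# One agent turn: the LLM may call the search action before answering.
print(chatbot.chat('What is XTuner?').response)

Because this path loads the model itself, the diff asserts that --lagent is not combined with an external --adapter or with 4-/8-bit loading.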
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import re
import sys

import torch
from peft import PeftModel
@@ -47,6 +49,8 @@ def parse_args():
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument('--command-stop-word', default=None, help='Stop key')
    parser.add_argument('--answer-stop-word', default=None, help='Stop key')
    parser.add_argument(
@@ -103,174 +107,208 @@ def get_input():

def main():
    args = parse_args()
    if args.lagent:
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import ReAct
        from lagent.llms import HFTransformerCasualLM

        torch.manual_seed(args.seed)
        assert args.adapter is None, ('lagent does not support the external '
                                      'adapter, please merge the model first!')
        assert args.bits is None, 'lagent does not support quantized LLM'
        try:
            SERPER_API_KEY = os.environ['SERPER_API_KEY']
        except Exception:
            print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
                  'and set it using `export SERPER_API_KEY=xxx`.')
            sys.exit(1)

        llm = HFTransformerCasualLM(args.model_name_or_path)
        search_tool = GoogleSearch(api_key=SERPER_API_KEY)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
        )
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                chatbot._session_history = []
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert args.prompt_template == 'moss_sft'
            from plugins import plugins_api
            inner_thoughts_open = True
            calculate_open = 'calculate' in args.with_plugins
            solve_open = 'solve' in args.with_plugins
            search_open = 'search' in args.with_plugins
            # pre-import for api and model preparation
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401
        # build model
        quantization_config = None
        load_in_8bit = False
        if args.bits == 4:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')
        elif args.bits == 8:
            load_in_8bit = True
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=quantization_config,
            load_in_8bit=load_in_8bit,
            device_map='auto',
            offload_folder=args.offload_folder,
            trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, trust_remote_code=True)
        if args.adapter is not None:
            model = PeftModel.from_pretrained(
                model, args.adapter, offload_folder=args.offload_folder)
            print(f'Load adapter from {args.adapter}')
        model.eval()

        Streamer, stop_criteria = get_chat_utils(model)
        if args.no_streamer:
            Streamer = None

        command_stop_cr, answer_stop_cr = update_stop_criteria(
            base=stop_criteria,
            tokenizer=tokenizer,
            command_stop_word=args.command_stop_word,
            answer_stop_word=args.answer_stop_word)

        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

        n_turn = 0
        inputs = ''
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                n_turn = 0
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                exit(0)
            if args.prompt_template is not None:
                template = PROMPT_TEMPLATE[args.prompt_template]
                if 'INSTRUCTION_START' in template and n_turn == 0:
                    prompt_text = template['INSTRUCTION_START'].format(
                        input=text, round=n_turn + 1, bot_name=args.bot_name)
                else:
                    prompt_text = template['INSTRUCTION'].format(
                        input=text, round=n_turn + 1, bot_name=args.bot_name)
                if args.prompt_template == 'moss_sft':
                    if not inner_thoughts_open:
                        prompt_text.replace('- Inner thoughts: enabled.',
                                            '- Inner thoughts: disabled.')
                    if not calculate_open:
                        prompt_text.replace(('- Calculator: enabled. '
                                             'API: Calculate(expression)'),
                                            '- Calculator: disabled.')
                    if not solve_open:
                        prompt_text.replace(
                            '- Equation solver: enabled. API: Solve(equation)',
                            '- Equation solver: disabled.')
                    if not search_open:
                        prompt_text.replace(
                            '- Web search: enabled. API: Search(query)',
                            '- Web search: disabled.')

                inputs += prompt_text
            else:
                inputs += text
            ids = tokenizer.encode(inputs, return_tensors='pt')
            streamer = Streamer(tokenizer) if Streamer is not None else None
            if args.with_plugins is not None:
                generate_output = model.generate(
                    inputs=ids.cuda(),
                    generation_config=gen_config,
                    streamer=streamer,
                    stopping_criteria=command_stop_cr).cpu()
                generate_output_text = tokenizer.decode(
                    generate_output[0][len(ids[0]):])
                if streamer is None:
                    end = '' if generate_output_text[-1] == '\n' else '\n'
                    print(generate_output_text, end=end)
                pattern = r'<\|Commands\|>:(.*?)<eoc>'
                command_text = ', '.join(
                    re.findall(pattern, generate_output_text))
                extent_text = plugins_api(
                    command_text,
                    calculate_open=calculate_open,
                    solve_open=solve_open,
                    search_open=search_open)
                end = '' if extent_text[-1] == '\n' else '\n'
                print(extent_text, end=end)
                extent_text_ids = tokenizer.encode(
                    extent_text, return_tensors='pt', add_special_tokens=False)
                new_ids = torch.cat((generate_output, extent_text_ids), dim=1)
                new_streamer = Streamer(
                    tokenizer) if Streamer is not None else None
                generate_output = model.generate(
                    inputs=new_ids.cuda(),
                    generation_config=gen_config,
                    streamer=new_streamer,
                    stopping_criteria=answer_stop_cr)
                if streamer is None:
                    output_text = tokenizer.decode(
                        generate_output[0][len(new_ids[0]):])
                    end = '' if output_text[-1] == '\n' else '\n'
                    print(output_text, end=end)
            else:
                generate_output = model.generate(
                    inputs=ids.cuda(),
                    generation_config=gen_config,
                    streamer=streamer,
                    stopping_criteria=answer_stop_cr)
                if streamer is None:
                    output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]):])
                    end = '' if output_text[-1] == '\n' else '\n'
                    print(output_text, end=end)
            inputs = tokenizer.decode(generate_output[0])
            n_turn += 1
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    'Remove the memory of history responses, since '
                    f'it exceeds the length limitation {args.max_new_tokens}.')
                n_turn = 0
                inputs = ''


if __name__ == '__main__':
    main()

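With a merged model (no --adapter) and a serper.dev key exported (export SERPER_API_KEY=xxx), the new mode would be launched along the lines of: python xtuner/tools/chat.py MODEL_NAME_OR_PATH --lagent. Treat that exact invocation as an assumption; parse_args in the file above defines the authoritative CLI. In both modes, typing RESET clears the chat history and EXIT quits.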