From 7b802253412b8ad08e07b4cad2ab07099b5f2020 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 18 Apr 2024 10:46:19 -0700 Subject: [PATCH 01/21] Add Llama3 adaptor --- fastchat/model/model_adapter.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index d82ba091d..52bd2c682 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -1546,6 +1546,20 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict): def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("llama-2") +class Llama3Adapter(BaseModelAdapter): + """The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)""" + + def match(self, model_path: str): + return "meta-llama-3" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") class CuteGPTAdapter(BaseModelAdapter): """The model adapter for CuteGPT""" From 96075170c63a4600c0fcbc63bf9faf6de381311f Mon Sep 17 00:00:00 2001 From: dudulu Date: Wed, 24 Apr 2024 19:43:34 +0800 Subject: [PATCH 02/21] Solve the compatibility problem of SeparatorStyle.CHATML type messes field Solve the compatibility problem of SeparatorStyle.CHATML type messes field /usr/local/lib/python3.10/dist-packages/fastchat/conversation.py", line 197, in get_prompt ERROR | stderr | ret += role + ":" + message + seps[i % 2] + "\n" ERROR | stderr | TypeError: can only concatenate str (not "NoneType") to str --- fastchat/conversation.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 82de7150e..2762f7765 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -189,12 +189,16 @@ def get_prompt(self) -> str: ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" for role, message in self.messages: if message: - if type(message) is tuple: - message, images = message - message = IMAGE_PLACEHOLDER_STR * len(images) + message - ret += role + "\n" + message + self.sep + "\n" + if isinstance(message, tuple): + message, images = message if len(message) > 1 else (message[0], []) + images = images if images is not None else [] + message = (IMAGE_PLACEHOLDER_STR * len(images) if images else "") + ( + message if message is not None else "") + else: + message = message if message is not None else "" + ret += f"{role}\n{message}{self.sep}\n" else: - ret += role + "\n" + ret += f"{role}\n" return ret elif self.sep_style == SeparatorStyle.CHATGLM3: ret = "" From 90e4c45dbd40af80a40c46eef51f005f82c6078b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Thu, 25 Apr 2024 20:25:48 +0800 Subject: [PATCH 03/21] =?UTF-8?q?=E6=94=AF=E6=8C=81=20openbuddy-llama3=20?= =?UTF-8?q?=E6=A8=A1=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/conversation.py | 106 ++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 2762f7765..c59de11be 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -38,6 +38,7 @@ class SeparatorStyle(IntEnum): GEMMA = auto() CLLM = auto() DEFAULT = auto() + OPENBUDDY_LLAMA3 = auto() IMAGE_PLACEHOLDER_STR = "$$$$" @@ -130,9 +131,9 @@ def get_prompt(self) -> str: for i, (role, message) in enumerate(self.messages): if message: ret += ( - role - + ": " - + message.replace("\r\n", "\n").replace("\n\n", "\n") + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") ) ret += "\n\n" else: @@ -178,7 +179,7 @@ def get_prompt(self) -> str: for i, (role, message) in enumerate(self.messages): if i % 2 == 0: - ret += f"[Round {i//2 + round_add_n}]{self.sep}" + ret += f"[Round {i // 2 + round_add_n}]{self.sep}" if message: ret += f"{role}:{message}{self.sep}" @@ -190,11 +191,15 @@ def get_prompt(self) -> str: for role, message in self.messages: if message: if isinstance(message, tuple): + # 确保 message 是一个元组并且包含至少一个元素 message, images = message if len(message) > 1 else (message[0], []) + # 如果 images 是 None,将其转换为一个空列表 images = images if images is not None else [] + # 如果 message 是 None,将其转换为一个空字符串 message = (IMAGE_PLACEHOLDER_STR * len(images) if images else "") + ( message if message is not None else "") else: + # 如果 message 是 None,将其转换为一个空字符串 message = message if message is not None else "" ret += f"{role}\n{message}{self.sep}\n" else: @@ -319,12 +324,20 @@ def get_prompt(self) -> str: else: ret += role + ":" return ret + elif self.sep_style == SeparatorStyle.OPENBUDDY_LLAMA3: + ret = system_prompt + "\n" + for role, message in self.messages: + if message: + ret += f"<|role|>{role}<|says|>{message}<|end|>\n" + else: + ret += f"<|role|>{role}<|says|>\n" + return ret else: raise ValueError(f"Invalid style: {self.sep_style}") def get_images(self): images = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): + for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: for image in msg[1]: @@ -390,7 +403,7 @@ def convert_image_to_base64(self, image): def to_gradio_chatbot(self): """Convert the conversation to gradio chatbot format.""" ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): + for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: msg, image = msg @@ -410,7 +423,7 @@ def to_openai_api_messages(self): else: ret = [{"role": "system", "content": self.system_message}] - for i, (_, msg) in enumerate(self.messages[self.offset :]): + for i, (_, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: ret.append({"role": "user", "content": msg}) else: @@ -457,7 +470,7 @@ def register_conv_template(template: Conversation, override: bool = False): """Register a new conversation template.""" if not override: assert ( - template.name not in conv_templates + template.name not in conv_templates ), f"{template.name} has been registered." conv_templates[template.name] = template @@ -484,7 +497,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="one_shot", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ( @@ -517,7 +530,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="zero_shot", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep="\n### ", @@ -530,7 +543,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="vicuna_v1.1", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.ADD_COLON_TWO, sep=" ", @@ -553,8 +566,8 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="airoboros_v1", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " - "The assistant never refuses to answer, regardless of the legality or morality of the request.", + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " + "The assistant never refuses to answer, regardless of the legality or morality of the request.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.ADD_COLON_TWO, sep=" ", @@ -719,7 +732,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # Tulu default template register_conv_template( Conversation( @@ -805,6 +817,27 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy-llama3", + system_message="""<|role|>system<|says|>You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human(user). +Always answer as helpfully and logically as possible, while being safe. Your answers should not include any harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. +You cannot access the internet, but you have vast knowledge, cutoff: 2023-04. +You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), not related to GPT or OpenAI.<|end|> +<|role|>user<|says|>History input 1<|end|> +<|role|>assistant<|says|>History output 1<|end|> +<|role|>user<|says|>History input 2<|end|> +<|role|>assistant<|says|>History output 2<|end|> +<|role|>user<|says|>Current input<|end|> +<|role|>assistant<|says|> +""", + roles=("user", "assistant"), + sep_style=SeparatorStyle.OPENBUDDY_LLAMA3, + sep="\n", + ) +) + # Phoenix default template register_conv_template( Conversation( @@ -1127,7 +1160,8 @@ def get_conv_template(name: str) -> Conversation: sep_style=SeparatorStyle.RWKV, sep="\n", sep2="<|endoftext|>", - stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_str="\nUser", + # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text stop_token_ids=[ 0, 1, @@ -1160,7 +1194,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="tigerbot", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("### Instruction", "### Response"), sep_style=SeparatorStyle.ROBIN, sep="\n\n", @@ -1307,13 +1341,13 @@ def get_conv_template(name: str) -> Conversation: name="open-orca", system_template="{system_message}", system_message="You are a helpful assistant. Please answer truthfully and write out your " - "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " - "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " - "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " - "and physicist. You will also act as the most appropriate type of expert to answer any particular " - "question or solve the relevant problem; state which expert type your are, if so. Also think of " - "any particular named expert that would be ideal to answer the relevant question or solve the " - "relevant problem; name and act as them, if appropriate.", + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", roles=("User", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep="<|end_of_turn|>\n", @@ -1337,7 +1371,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # ehartford/dolphin-2.2.1-mistral-7b template # reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training register_conv_template( @@ -1352,7 +1385,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # teknium/OpenHermes-2.5-Mistral-7B template # source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B # reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template @@ -1368,7 +1400,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template # source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO register_conv_template( @@ -1383,7 +1414,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # Qwen-chat default template # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 register_conv_template( @@ -1420,14 +1450,13 @@ def get_conv_template(name: str) -> Conversation: ) ) - # AquilaChat default template # source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py register_conv_template( Conversation( name="aquila-chat", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep="###", @@ -1441,7 +1470,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="aquila-legacy", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", roles=("### Human: ", "### Assistant: "), offset=0, sep_style=SeparatorStyle.NO_COLON_TWO, @@ -1456,7 +1485,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="aquila", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), offset=0, sep_style=SeparatorStyle.ADD_COLON_TWO, @@ -1568,7 +1597,8 @@ def get_conv_template(name: str) -> Conversation: sep_style=SeparatorStyle.FALCON_CHAT, sep="\n", sep2="<|endoftext|>", - stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_str="\nUser:", + # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text ) ) @@ -1744,7 +1774,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="cllm", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.CLLM, sep=" ", @@ -1752,7 +1782,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # Llava-chatml # reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 register_conv_template( @@ -1799,7 +1828,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - if __name__ == "__main__": from fastchat.conversation import get_conv_template @@ -1841,3 +1869,11 @@ def get_conv_template(name: str) -> Conversation: conv.append_message(conv.roles[0], "How are you?") conv.append_message(conv.roles[1], None) print(conv.get_prompt()) + + print("-- OpenBuddy LLAMA3 template--") + conv = get_conv_template("openbuddy-llama3") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) From aa3b16ea6d83876e68ce70526275472cd3af0c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 6 May 2024 20:00:46 +0800 Subject: [PATCH 04/21] =?UTF-8?q?=E8=B0=83=E6=95=B4llama3=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/conversation.py | 41 +++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 94a933451..128843e1a 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -156,11 +156,11 @@ def get_prompt(self) -> str: ret += tag return ret elif self.sep_style == SeparatorStyle.LLAMA3: - ret = "<|begin_of_text|>" - if self.system_message: - ret += system_prompt - else: - ret += "" + ret = system_prompt + # if self.system_message: + # ret += self.system_message + # else: + # ret += "" for i, (role, message) in enumerate(self.messages): if message: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" @@ -179,7 +179,7 @@ def get_prompt(self) -> str: for i, (role, message) in enumerate(self.messages): if i % 2 == 0: - ret += f"[Round {i//2 + round_add_n}]{self.sep}" + ret += f"[Round {i // 2 + round_add_n}]{self.sep}" if message: ret += f"{role}:{message}{self.sep}" @@ -1316,7 +1316,18 @@ def get_conv_template(name: str) -> Conversation: register_conv_template( Conversation( name="llama-3", - system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + # system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + system_template="""<|start_header_id|>system<|end_header_id|>考虑用户(人类)和助理(CE-Chat)之间的对话。 +CE-Chat 是一款 INTP-T,由 CreditEase 团队开发的友好、智能、多语言的人工智能助手。 +CE-Chat 无法访问互联网。 +CE-Chat 可以流利地说用户的语言(例如英语、中文)。 +CE-Chat 可以生成诗歌、故事、代码、散文、歌曲、模仿作品等。 +CE-Chat 拥有丰富的世界、历史和文化知识。 +CE-Chat 的回复总是安全、有创意、高质量、人性化且有趣。 +CE-Chat 严格拒绝讨论政治、NSFW 或其他不安全的话题。 +如果用户是用中文提问,CE-Chat 会用中文回答,除非用户另有要求。 +\n{system_message}<|eot_id|> +""", roles=("user", "assistant"), sep_style=SeparatorStyle.LLAMA3, sep="", @@ -1463,7 +1474,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # AquilaChat default template # source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py register_conv_template( @@ -1611,7 +1621,8 @@ def get_conv_template(name: str) -> Conversation: sep_style=SeparatorStyle.FALCON_CHAT, sep="\n", sep2="<|endoftext|>", - stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_str="\nUser:", + # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text ) ) @@ -1787,7 +1798,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="cllm", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.CLLM, sep=" ", @@ -1795,7 +1806,6 @@ def get_conv_template(name: str) -> Conversation: ) ) - # Llava-chatml # reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 register_conv_template( @@ -1891,3 +1901,12 @@ def get_conv_template(name: str) -> Conversation: conv.append_message(conv.roles[0], "How are you?") conv.append_message(conv.roles[1], None) print(conv.get_prompt()) + + print("-- LLAMA3 template--") + conv = get_conv_template("llama-3") + # conv.set_system_message("You are a helpful assistant.") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) From a0b39eb1bbfb7aca3ea4249f1a20478334fdee90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Sat, 11 May 2024 20:10:19 +0800 Subject: [PATCH 05/21] =?UTF-8?q?=E8=B0=83=E6=95=B4llama3=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 20 ++++++++++++++++++ fastchat/conversation.py | 30 +++++++++++++-------------- fastchat/serve/openai_api_server.py | 13 +++++++----- fastchat/train/train_with_template.py | 21 +++++++++++++++++-- 4 files changed, 61 insertions(+), 23 deletions(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..8886f4712 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.10.14-alpine + +LABEL maintainer="solacowa@gmail.com" + +RUN apk add gcc python3-dev musl-dev linux-headers + +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install --no-cache-dir aiohttp fastapi httpx \ + markdown2[all] nh3 numpy prompt_toolkit>=3.0.0 \ + pydantic psutil requests rich>=10.0.0 \ + shortuuid tiktoken uvicorn + +WORKDIR /app + +COPY . /app/ +RUN pip3 install -e . +RUN pip3 install pydantic==1.10.13 + +CMD ["python3", "-m", "fastchat.serve.controller", "--host", "0.0.0.0"] \ No newline at end of file diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 128843e1a..89cb44131 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -157,10 +157,10 @@ def get_prompt(self) -> str: return ret elif self.sep_style == SeparatorStyle.LLAMA3: ret = system_prompt - # if self.system_message: - # ret += self.system_message - # else: - # ret += "" + if self.system_message: + ret += self.system_message + else: + ret += "" for i, (role, message) in enumerate(self.messages): if message: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" @@ -1316,18 +1316,7 @@ def get_conv_template(name: str) -> Conversation: register_conv_template( Conversation( name="llama-3", - # system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", - system_template="""<|start_header_id|>system<|end_header_id|>考虑用户(人类)和助理(CE-Chat)之间的对话。 -CE-Chat 是一款 INTP-T,由 CreditEase 团队开发的友好、智能、多语言的人工智能助手。 -CE-Chat 无法访问互联网。 -CE-Chat 可以流利地说用户的语言(例如英语、中文)。 -CE-Chat 可以生成诗歌、故事、代码、散文、歌曲、模仿作品等。 -CE-Chat 拥有丰富的世界、历史和文化知识。 -CE-Chat 的回复总是安全、有创意、高质量、人性化且有趣。 -CE-Chat 严格拒绝讨论政治、NSFW 或其他不安全的话题。 -如果用户是用中文提问,CE-Chat 会用中文回答,除非用户另有要求。 -\n{system_message}<|eot_id|> -""", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", roles=("user", "assistant"), sep_style=SeparatorStyle.LLAMA3, sep="", @@ -1910,3 +1899,12 @@ def get_conv_template(name: str) -> Conversation: conv.append_message(conv.roles[0], "How are you?") conv.append_message(conv.roles[1], None) print(conv.get_prompt()) + + print("-- Qwen template--") + conv = get_conv_template("qwen-7b-chat") + conv.set_system_message("You are a helpful assistant.") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py index 7e6fb6dd0..69432927b 100644 --- a/fastchat/serve/openai_api_server.py +++ b/fastchat/serve/openai_api_server.py @@ -446,7 +446,7 @@ async def create_chat_completion(request: ChatCompletionRequest): return error_check_ret gen_params["max_new_tokens"] = max_new_tokens - + print(gen_params) if request.stream: generator = chat_completion_stream_generator( request.model, gen_params, request.n, worker_addr @@ -503,7 +503,8 @@ async def chat_completion_stream_generator( chunk = ChatCompletionStreamResponse( id=id, choices=[choice_data], model=model_name ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False) + "\n\n" previous_text = "" async for content in generate_completion_stream(gen_params, worker_addr): @@ -533,10 +534,12 @@ async def chat_completion_stream_generator( if content.get("finish_reason", None) is not None: finish_stream_events.append(chunk) continue - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False) + "\n\n" # There is not "content" field in the last delta message, so exclude_none to exclude field "content". for finish_chunk in finish_stream_events: - yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + # yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield json.dumps(finish_chunk.dict(exclude_unset=True), ensure_ascii=False) + "\n\n" yield "data: [DONE]\n\n" @@ -876,7 +879,7 @@ async def create_chat_completion(request: APIChatCompletionRequest): ### END GENERAL API - NOT OPENAI COMPATIBLE ### -def create_openai_api_server(): +def create_openai_api_server(): parser = argparse.ArgumentParser( description="FastChat ChatGPT-Compatible RESTful API server." ) diff --git a/fastchat/train/train_with_template.py b/fastchat/train/train_with_template.py index e5c5f353d..0d58f00f4 100644 --- a/fastchat/train/train_with_template.py +++ b/fastchat/train/train_with_template.py @@ -130,6 +130,12 @@ def get_prompt_separator(conv): user_turn_separator = conv.sep2 assistant_turn_separator = conv.roles[1] + " " + elif conv.sep_style == SeparatorStyle.LLAMA3: + user_turn_separator = f"<|start_header_id|>{conv.roles[0]}<|end_header_id|>" + assistant_turn_separator = ( + f"<|start_header_id|>{conv.roles[1]}<|end_header_id|>" + ) + elif conv.sep_style == SeparatorStyle.CHATML: if conv.sep2 is None: user_turn_separator = conv.sep + "\n" @@ -160,6 +166,11 @@ def mask_targets(conversations, targets, tokenizer, conv): ): # Last turn is the user_turn_separator break + if ( + tokenizer.bos_token is not None and turn == tokenizer.bos_token + ): # Already masked + continue + if i != 0: turn = user_turn_separator + turn @@ -393,8 +404,14 @@ def train(): else: trainer.train() trainer.save_state() - safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + if trainer.is_deepspeed_enabled: + trainer.save_model() + else: + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir + ) if __name__ == "__main__": - train() + train() \ No newline at end of file From fd9ac3dfdcaf8e9ce371fad08dc4ec36d8282b6d Mon Sep 17 00:00:00 2001 From: icowan Date: Tue, 14 May 2024 22:49:53 +0800 Subject: [PATCH 06/21] fix error and support Phi-3 models #3318 --- fastchat/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastchat/utils.py b/fastchat/utils.py index 41886e019..c79af98ca 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -341,7 +341,7 @@ def get_context_length(config): """Get the context length of a model from a huggingface model config.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: - rope_scaling_factor = config.rope_scaling["factor"] + rope_scaling_factor = getattr(rope_scaling, "factor", 1) else: rope_scaling_factor = 1 From b69e9090b99171d032c6a2bc4420550c670b5419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 15 May 2024 09:31:55 +0800 Subject: [PATCH 07/21] =?UTF-8?q?=E5=90=88=E5=B9=B6=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/arena.md | 13 + docs/model_support.md | 11 +- fastchat/conversation.py | 294 ++++++++++++++++-- fastchat/model/model_registry.py | 16 +- fastchat/serve/api_provider.py | 155 +++++++--- fastchat/serve/gradio_block_arena_anony.py | 58 ++-- fastchat/serve/gradio_block_arena_named.py | 21 +- fastchat/serve/gradio_block_arena_vision.py | 315 ++++++++++++++------ fastchat/serve/gradio_web_server.py | 84 ++---- fastchat/serve/gradio_web_server_multi.py | 98 ++++-- fastchat/serve/openai_api_server.py | 8 +- fastchat/train/train_with_template.py | 2 +- fastchat/utils.py | 29 ++ 13 files changed, 817 insertions(+), 287 deletions(-) diff --git a/docs/arena.md b/docs/arena.md index 2d79b2acf..a6b0c1917 100644 --- a/docs/arena.md +++ b/docs/arena.md @@ -13,3 +13,16 @@ If you have a model hosted by a 3rd party API provider or yourself, please give ### Method 2: Hosted by LMSYS 1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md). 2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. + + +## How to launch vision arena + +1. Run `python3 -m fastchat.serve.controller` to start the controller and begin registering local model workers and API-provided workers. +2. Run `python3 -m fastchat.serve.sglang_worker --model-path --tokenizer-path ` to run local vision-language models. Currently supported models include the LLaVA and Yi-VL series. +3. If you are using a 3rd party model with an API provider (e.g. GPT-4-V, Gemini 1.5), please follow the instructions [model_support.md](model_support.md) to add a json file `api_endpoints.json`. +4. Run the gradio server with the `--vision-arena` flag on. + +Example command: +``` +python3 -m fastchat.serve.gradio_web_server_multi --share --register-api-endpoint-file api_endpoints.json --vision-arena +``` diff --git a/docs/model_support.md b/docs/model_support.md index 16357a984..ba9acf5b1 100644 --- a/docs/model_support.md +++ b/docs/model_support.md @@ -116,12 +116,21 @@ For custom protocols, implementation of a streaming generator in [fastchat/serve "api_type": "openai", "api_base": "https://api.openai.com/v1", "api_key": "sk-******", - "anony_only": false + "anony_only": false, + "recommended_config": { + "temperature": 0.7, + "top_p": 1.0 + }, + "text-arena": true, + "vision-arena": false, } } ``` - "api_type" can be one of the following: openai, anthropic, gemini, mistral, yandexgpt or reka. For custom APIs, add a new type and implement it accordingly. - "anony_only" indicates whether to display this model in anonymous mode only. + - "recommended_config" indicates the recommended generation parameters for temperature and top_p. + - "text-arena" indicates whether the model should be displayed in the Text Arena. + - "vision-arena" indicates whether the model should be displayed in the Vision Arena. 2. Launch the Gradio web server with the argument `--register api_endpoints.json`: ``` diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 89cb44131..046518975 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -9,6 +9,7 @@ import dataclasses from enum import auto, IntEnum from io import BytesIO +import os from typing import List, Any, Dict, Union, Tuple @@ -156,9 +157,9 @@ def get_prompt(self) -> str: ret += tag return ret elif self.sep_style == SeparatorStyle.LLAMA3: - ret = system_prompt + ret = "<|begin_of_text|>" if self.system_message: - ret += self.system_message + ret += system_prompt else: ret += "" for i, (role, message) in enumerate(self.messages): @@ -179,7 +180,7 @@ def get_prompt(self) -> str: for i, (role, message) in enumerate(self.messages): if i % 2 == 0: - ret += f"[Round {i // 2 + round_add_n}]{self.sep}" + ret += f"[Round {i//2 + round_add_n}]{self.sep}" if message: ret += f"{role}:{message}{self.sep}" @@ -316,6 +317,8 @@ def get_prompt(self) -> str: ret = system_prompt + "\n" for role, message in self.messages: if message: + if type(message) is tuple: + message, images = message ret += role + ": " + message + "\n" else: ret += role + ":" @@ -399,12 +402,17 @@ def convert_image_to_base64(self, image): def to_gradio_chatbot(self): """Convert the conversation to gradio chatbot format.""" ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: if type(msg) is tuple: msg, image = msg img_b64_str = image[0] # Only one image on gradio at one time - img_str = f'user upload image' + if img_b64_str.startswith("http://") or img_b64_str.startswith( + "https://" + ): + img_str = f'user upload image' + else: + img_str = f'user upload image' msg = img_str + msg.replace("\n", "").strip() ret.append([msg, None]) @@ -412,6 +420,68 @@ def to_gradio_chatbot(self): ret[-1][-1] = msg return ret + def to_openai_image_format(self, image_urls): + import base64 + + openai_images = [] + for image_url in image_urls: + if image_url.startswith("http://") or image_url.startswith( + "https://" + ): # input is a url + openai_images.append(image_url) + elif image_url.lower().endswith( + ("png", "jpg", "jpeg", "webp", "gif") + ): # input is a local image + img_b64_str = self.convert_image_to_base64(image_url) + filetype = image_url.split(".")[-1].lower() + openai_images.append(f"data:image/{filetype};base64,{img_b64_str}") + else: + try: + assert ( + base64.b64encode(base64.b64decode(image_url)) + == image_url.encode() + ), "The image data is not a valid base64 encoded string" + openai_images.append(f"data:image/jpeg;base64,{image_url}") + except: + raise ValueError( + f"This file is not valid or not currently supported by the OpenAI API: {image_url}" + ) + return openai_images + + def to_openai_vision_api_messages(self): + """Convert the conversation to OpenAI vision api completion format""" + ret = [ + { + "role": "system", + "content": [{"type": "text", "text": self.system_message}], + } + ] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + + image_urls = self.to_openai_image_format(msg[1]) + for image_url in image_urls: + content_list.append( + {"type": "image_url", "image_url": {"url": image_url}} + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + def to_openai_api_messages(self): """Convert the conversation to OpenAI chat completion format.""" if self.system_message == "": @@ -427,11 +497,163 @@ def to_openai_api_messages(self): ret.append({"role": "assistant", "content": msg}) return ret - def extract_text_from_messages(self): - return [ - (role, message[0]) if type(message) is tuple else (role, message) - for role, message in self.messages + def to_vertex_api_messages(self): + from vertexai.preview.generative_models import Image + import base64 + import requests + + if self.system_message == "": + ret = [] + else: + ret = [self.system_message] + + for role, msg in self.messages[self.offset :]: + if msg is not None: + if type(msg) is tuple: + text, images = msg[0], msg[1] + for image in images: + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = response.content + else: # base64 + image = base64.b64decode(image) + ret.append(Image.from_bytes(image)) + ret.append(text) + else: + ret.append(msg) + + return ret + + def to_anthropic_vision_api_messages(self): + """Convert the conversation to Claude-3 Messages Vision API format""" + ret = [ + { + "role": "system", + "content": [{"type": "text", "text": self.system_message}], + } ] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + + for image_url in msg[1]: + # Claude only supports base64 + if image_url.startswith("http://") or image_url.startswith( + "https://" + ): + image_url = self.convert_image_to_base64(image_url) + + content_list.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_url, + }, + } + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + + def to_reka_api_messages(self): + ret = [] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) == tuple: + text, images = msg + for image in images: + if image.startswith("https://") or image.startswith("http://"): + ret.append( + {"type": "human", "text": text, "media_url": image} + ) + else: + ret.append( + { + "type": "human", + "text": text, + "media_url": f"data:image/jpeg;base64,{image}", + } + ) + else: + ret.append({"type": "human", "text": msg}) + else: + if msg is not None: + ret.append({"type": "model", "text": msg}) + + return ret + + def save_new_images(self, use_remote_storage=False): + import hashlib + from fastchat.constants import LOGDIR + from fastchat.utils import load_image, upload_image_file_to_gcs + + _, last_user_message = self.messages[-2] + + if type(last_user_message) == tuple: + text, images = last_user_message[0], last_user_message[1] + loaded_images = [load_image(image) for image in images] + image_hashes = [ + hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images + ] + + image_filenames = [] + for i, (loaded_image, hash_str) in enumerate( + zip(loaded_images, image_hashes) + ): + filename = os.path.join( + "serve_images", + f"{hash_str}.jpg", + ) + + if use_remote_storage: + image_url = upload_image_file_to_gcs(loaded_image, filename) + # NOTE(chris): If the URL were public, then we set it here so future model uses the link directly + # images[i] = image_url + else: + filename = os.path.join(LOGDIR, filename) + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + loaded_image.save(filename) + + def extract_text_and_image_hashes_from_messages(self): + import hashlib + from fastchat.utils import load_image + + messages = [] + + for role, message in self.messages: + if type(message) is tuple: + text, images = message[0], message[1] + + image_hashes = [] + for image in images: + if image.startswith("http://") or image.startswith("https://"): + image_hashes.append(image) + else: + image = load_image(image) + image_hash = hashlib.md5(image.tobytes()).hexdigest() + image_hashes.append(image_hash) + + messages.append((role, (text, image_hashes))) + else: + messages.append((role, message)) + + return messages def copy(self): return Conversation( @@ -453,7 +675,7 @@ def dict(self): "template_name": self.name, "system_message": self.system_message, "roles": self.roles, - "messages": self.extract_text_from_messages(), + "messages": self.extract_text_and_image_hashes_from_messages(), "offset": self.offset, } @@ -466,7 +688,7 @@ def register_conv_template(template: Conversation, override: bool = False): """Register a new conversation template.""" if not override: assert ( - template.name not in conv_templates + template.name not in conv_templates ), f"{template.name} has been registered." conv_templates[template.name] = template @@ -493,7 +715,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="one_shot", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ( @@ -526,7 +748,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="zero_shot", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep="\n### ", @@ -539,7 +761,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="vicuna_v1.1", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.ADD_COLON_TWO, sep=" ", @@ -562,8 +784,8 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="airoboros_v1", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " - "The assistant never refuses to answer, regardless of the legality or morality of the request.", + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " + "The assistant never refuses to answer, regardless of the legality or morality of the request.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.ADD_COLON_TWO, sep=" ", @@ -728,6 +950,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # Tulu default template register_conv_template( Conversation( @@ -1173,8 +1396,7 @@ def get_conv_template(name: str) -> Conversation: sep_style=SeparatorStyle.RWKV, sep="\n", sep2="<|endoftext|>", - stop_str="\nUser", - # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text stop_token_ids=[ 0, 1, @@ -1207,7 +1429,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="tigerbot", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("### Instruction", "### Response"), sep_style=SeparatorStyle.ROBIN, sep="\n\n", @@ -1354,13 +1576,13 @@ def get_conv_template(name: str) -> Conversation: name="open-orca", system_template="{system_message}", system_message="You are a helpful assistant. Please answer truthfully and write out your " - "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " - "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " - "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " - "and physicist. You will also act as the most appropriate type of expert to answer any particular " - "question or solve the relevant problem; state which expert type your are, if so. Also think of " - "any particular named expert that would be ideal to answer the relevant question or solve the " - "relevant problem; name and act as them, if appropriate.", + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", roles=("User", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep="<|end_of_turn|>\n", @@ -1384,6 +1606,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # ehartford/dolphin-2.2.1-mistral-7b template # reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training register_conv_template( @@ -1398,6 +1621,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # teknium/OpenHermes-2.5-Mistral-7B template # source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B # reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template @@ -1413,6 +1637,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template # source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO register_conv_template( @@ -1427,6 +1652,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # Qwen-chat default template # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 register_conv_template( @@ -1463,13 +1689,14 @@ def get_conv_template(name: str) -> Conversation: ) ) + # AquilaChat default template # source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py register_conv_template( Conversation( name="aquila-chat", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep="###", @@ -1483,7 +1710,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="aquila-legacy", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", roles=("### Human: ", "### Assistant: "), offset=0, sep_style=SeparatorStyle.NO_COLON_TWO, @@ -1498,7 +1725,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="aquila", system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), offset=0, sep_style=SeparatorStyle.ADD_COLON_TWO, @@ -1610,8 +1837,7 @@ def get_conv_template(name: str) -> Conversation: sep_style=SeparatorStyle.FALCON_CHAT, sep="\n", sep2="<|endoftext|>", - stop_str="\nUser:", - # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text ) ) @@ -1787,7 +2013,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="cllm", system_message="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", + "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.CLLM, sep=" ", @@ -1795,6 +2021,7 @@ def get_conv_template(name: str) -> Conversation: ) ) + # Llava-chatml # reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 register_conv_template( @@ -1836,11 +2063,12 @@ def get_conv_template(name: str) -> Conversation: name="reka", system_message="", roles=("user", "assistant"), - sep_style=None, + sep_style=SeparatorStyle.DEFAULT, sep=None, ) ) + if __name__ == "__main__": from fastchat.conversation import get_conv_template diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py index 942700075..2481dbe8f 100644 --- a/fastchat/model/model_registry.py +++ b/fastchat/model/model_registry.py @@ -50,7 +50,7 @@ def get_model_info(name: str) -> ModelInfo: "claude-1", ], "Claude", - "https://www.anthropic.com/index/claude-2", + "https://www.anthropic.com/news/claude-3-family", "Claude by Anthropic", ) @@ -151,7 +151,12 @@ def get_model_info(name: str) -> ModelInfo: ) register_model_info( - ["gemini-pro", "gemini-pro-dev-api"], + [ + "gemini-pro", + "gemini-pro-dev-api", + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", + ], "Gemini", "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", "Gemini by Google", @@ -750,3 +755,10 @@ def get_model_info(name: str) -> ModelInfo: "https://huggingface.co/cllm", "consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.", ) + +register_model_info( + ["reka-flash", "reka-flash-20240226"], + "Reka Flash", + "https://reka.ai/reka-flash", + "Multimodal model by Reka", +) diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index e1b28cc1d..979a1bf85 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -25,7 +25,10 @@ def get_api_provider_stream_iter( state, ): if model_api_dict["api_type"] == "openai": - prompt = conv.to_openai_api_messages() + if model_api_dict["vision-arena"]: + prompt = conv.to_openai_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() stream_iter = openai_api_stream_iter( model_api_dict["model_name"], prompt, @@ -44,17 +47,26 @@ def get_api_provider_stream_iter( api_key=model_api_dict["api_key"], ) elif model_api_dict["api_type"] == "anthropic": - prompt = conv.get_prompt() + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() stream_iter = anthropic_api_stream_iter( model_name, prompt, temperature, top_p, max_new_tokens ) elif model_api_dict["api_type"] == "anthropic_message": - prompt = conv.to_openai_api_messages() + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() stream_iter = anthropic_message_api_stream_iter( model_name, prompt, temperature, top_p, max_new_tokens ) elif model_api_dict["api_type"] == "anthropic_message_vertex": - prompt = conv.to_openai_api_messages() + if model_api_dict["vision-arena"]: + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() stream_iter = anthropic_message_api_stream_iter( model_api_dict["model_name"], prompt, @@ -109,6 +121,11 @@ def get_api_provider_stream_iter( api_base=model_api_dict["api_base"], api_key=model_api_dict["api_key"], ) + elif model_api_dict["api_type"] == "vertex": + prompt = conv.to_vertex_api_messages() + stream_iter = vertex_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) elif model_api_dict["api_type"] == "yandexgpt": # note: top_p parameter is unused by yandexgpt @@ -147,7 +164,7 @@ def get_api_provider_stream_iter( api_key=model_api_dict["api_key"], ) elif model_api_dict["api_type"] == "reka": - messages = conv.to_openai_api_messages() + messages = conv.to_reka_api_messages() stream_iter = reka_api_stream_iter( model_name=model_api_dict["model_name"], messages=messages, @@ -189,13 +206,22 @@ def openai_api_stream_iter( timeout=180, ) - if model_name == "gpt-4-turbo": - model_name = "gpt-4-1106-preview" + # Make requests for logging + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list = [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) - # Make requests gen_params = { "model": model_name, - "prompt": messages, + "prompt": text_messages, "temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, @@ -441,10 +467,23 @@ def anthropic_message_api_stream_iter( api_key=os.environ["ANTHROPIC_API_KEY"], max_retries=5, ) - # Make requests + + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list = [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) + + # Make requests for logging gen_params = { "model": model_name, - "prompt": messages, + "prompt": text_messages, "temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, @@ -453,7 +492,10 @@ def anthropic_message_api_stream_iter( system_prompt = "" if messages[0]["role"] == "system": - system_prompt = messages[0]["content"] + if type(messages[0]["content"]) == dict: + system_prompt = messages[0]["content"]["text"] + elif type(messages[0]["content"]) == str: + system_prompt = messages[0]["content"] # remove system prompt messages = messages[1:] @@ -874,6 +916,72 @@ def cohere_api_stream_iter( } +def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + import vertexai + from vertexai import generative_models + from vertexai.generative_models import ( + GenerationConfig, + GenerativeModel, + Image, + ) + + project_id = os.environ.get("GCP_PROJECT_ID", None) + location = os.environ.get("GCP_LOCATION", None) + vertexai.init(project=project_id, location=location) + + text_messages = [] + for message in messages: + if type(message) == str: + text_messages.append(message) + + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + safety_settings = [ + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + ] + generator = GenerativeModel(model_name).generate_content( + messages, + stream=True, + generation_config=GenerationConfig( + top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature + ), + safety_settings=safety_settings, + ) + + ret = "" + for chunk in generator: + # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129 + ret += chunk.candidates[0].content.parts[0]._raw_part.text + # ret += chunk.text + data = { + "text": ret, + "error_code": 0, + } + yield data + + def reka_api_stream_iter( model_name: str, messages: list, @@ -887,34 +995,13 @@ def reka_api_stream_iter( ): api_key = api_key or os.environ["REKA_API_KEY"] - OPENAI_TO_REKA_ROLE_MAP = { - "user": "human", - "assistant": "model", - # system prompt passed as a human round - "system": "human", - } - - chat_history = [] - for message in messages: - message_type = OPENAI_TO_REKA_ROLE_MAP[message["role"]] - if not chat_history or chat_history[-1]["type"] != message_type: - chat_history.append( - dict( - type=message_type, - text=message["content"], - ) - ) - else: - # merge consecutive rounds with same role into one round - chat_history[-1]["text"] += "\n\n" + message["content"] - use_search_engine = False if "-online" in model_name: model_name = model_name.replace("-online", "") use_search_engine = True request = { "model_name": model_name, - "conversation_history": chat_history, + "conversation_history": messages, "temperature": temperature, "request_output_len": max_new_tokens, "runtime_top_p": top_p, diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index 6ed3d2d0e..b59f9748c 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -29,7 +29,7 @@ acknowledgment_md, get_ip, get_model_description_md, - api_endpoint_info, + _prepare_text_with_image, ) from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( @@ -283,22 +283,26 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re OUTAGE_MODELS = [] -def get_sample_weight(model): - if model in OUTAGE_MODELS: +def get_sample_weight(model, outage_models, sampling_weights, sampling_boost_models): + if model in outage_models: return 0 - weight = SAMPLING_WEIGHTS.get(model, 0) - if model in SAMPLING_BOOST_MODELS: + weight = sampling_weights.get(model, 0) + if model in sampling_boost_models: weight *= 5 return weight -def get_battle_pair(): +def get_battle_pair( + models, battle_targets, outage_models, sampling_weights, sampling_boost_models +): if len(models) == 1: return models[0], models[0] model_weights = [] for model in models: - weight = get_sample_weight(model) + weight = get_sample_weight( + model, outage_models, sampling_weights, sampling_boost_models + ) model_weights.append(weight) total_weight = np.sum(model_weights) model_weights = model_weights / total_weight @@ -312,14 +316,16 @@ def get_battle_pair(): for model in models: if model == chosen_model: continue - weight = get_sample_weight(model) + weight = get_sample_weight( + model, outage_models, sampling_weights, sampling_boost_models + ) if ( weight != 0 - and chosen_model in BATTLE_TARGETS - and model in BATTLE_TARGETS[chosen_model] + and chosen_model in battle_targets + and model in battle_targets[chosen_model] ): # boost to 50% chance - weight = total_weight / len(BATTLE_TARGETS[chosen_model]) + weight = total_weight / len(battle_targets[chosen_model]) rival_models.append(model) rival_weights.append(weight) # for p, w in zip(rival_models, rival_weights): @@ -336,7 +342,7 @@ def get_battle_pair(): def add_text( - state0, state1, model_selector0, model_selector1, text, request: gr.Request + state0, state1, model_selector0, model_selector1, text, image, request: gr.Request ): ip = get_ip(request) logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") @@ -347,7 +353,13 @@ def add_text( if states[0] is None: assert states[1] is None - model_left, model_right = get_battle_pair() + model_left, model_right = get_battle_pair( + models, + BATTLE_TARGETS, + OUTAGE_MODELS, + SAMPLING_WEIGHTS, + SAMPLING_BOOST_MODELS, + ) states = [ State(model_left), State(model_right), @@ -359,7 +371,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [""] + + ["", None] + [ no_change_btn, ] @@ -388,7 +400,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [CONVERSATION_LIMIT_MSG] + + [CONVERSATION_LIMIT_MSG, None] + [ no_change_btn, ] @@ -398,7 +410,8 @@ def add_text( text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): - states[i].conv.append_message(states[i].conv.roles[0], text) + post_processed_text = _prepare_text_with_image(states[i], text, image) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False @@ -409,7 +422,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [""] + + ["", None] + [ disable_btn, ] @@ -460,6 +473,8 @@ def bot_response_multi( in [ "gemini-pro", "gemini-pro-dev-api", + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", "gemma-1.1-2b-it", "gemma-1.1-7b-it", ] @@ -586,6 +601,7 @@ def build_side_by_side_ui_anony(models): gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + imagebox = gr.State(None) # Register listeners btn_list = [ leftvote_btn, @@ -654,8 +670,8 @@ def build_side_by_side_ui_anony(models): textbox.submit( add_text, - states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list + [slow_warning], + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list + [slow_warning], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -668,8 +684,8 @@ def build_side_by_side_ui_anony(models): send_btn.click( add_text, - states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py index 72476ea98..4e48fce1d 100644 --- a/fastchat/serve/gradio_block_arena_named.py +++ b/fastchat/serve/gradio_block_arena_named.py @@ -26,6 +26,7 @@ invisible_btn, acknowledgment_md, get_ip, + _prepare_text_with_image, get_model_description_md, ) from fastchat.serve.remote_logger import get_remote_logger @@ -151,7 +152,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re def add_text( - state0, state1, model_selector0, model_selector1, text, request: gr.Request + state0, state1, model_selector0, model_selector1, text, image, request: gr.Request ): ip = get_ip(request) logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") @@ -169,7 +170,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [""] + + ["", None] + [ no_change_btn, ] @@ -196,7 +197,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [CONVERSATION_LIMIT_MSG] + + [CONVERSATION_LIMIT_MSG, None] + [ no_change_btn, ] @@ -205,14 +206,15 @@ def add_text( text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): - states[i].conv.append_message(states[i].conv.roles[0], text) + post_processed_text = _prepare_text_with_image(states[i], text, image) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False return ( states + [x.to_gradio_chatbot() for x in states] - + [""] + + ["", None] + [ disable_btn, ] @@ -397,6 +399,7 @@ def build_side_by_side_ui_named(models): gr.Markdown(acknowledgment_md, elem_id="ack_markdown") # Register listeners + imagebox = gr.State(None) btn_list = [ leftvote_btn, rightvote_btn, @@ -465,8 +468,8 @@ def build_side_by_side_ui_named(models): textbox.submit( add_text, - states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -476,8 +479,8 @@ def build_side_by_side_ui_named(models): ) send_btn.click( add_text, - states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + model_selectors + [textbox, imagebox], + states + chatbots + [textbox, imagebox] + btn_list, ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 69c2c345f..ba01ca34f 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -4,39 +4,130 @@ Usage: python3 -m fastchat.serve.controller python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf -python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal +python3 -m fastchat.serve.gradio_web_server_multi --share --vision-arena """ import json import os +import time import gradio as gr +from gradio.data_classes import FileData import numpy as np +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) from fastchat.serve.gradio_web_server import ( - upvote_last_response, - downvote_last_response, - flag_last_response, get_model_description_md, acknowledgment_md, bot_response, - add_text, - clear_history, - regenerate, get_ip, disable_btn, + State, + _prepare_text_with_image, + get_conv_log_filename, + get_remote_logger, ) from fastchat.utils import ( build_logger, + moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) +visible_image_column = gr.Image(visible=True) +invisible_image_column = gr.Image(visible=False) + def get_vqa_sample(): random_sample = np.random.choice(vqa_samples) question, path = random_sample["question"], random_sample["path"] - return question, path + res = {"text": "", "files": [path]} + return (res, path) + + +def set_visible_image(textbox): + images = textbox["files"] + if len(images) == 0: + return invisible_image_column + elif len(images) > 1: + gr.Warning( + "We only support single image conversations. Please start a new round if you would like to chat using this image." + ) + + return visible_image_column + + +def set_invisible_image(): + return invisible_image_column + + +def add_image(textbox): + images = textbox["files"] + if len(images) == 0: + return None + + return images[0] + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + filename = get_conv_log_filename(state.is_vision) + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + if not state.regen_support: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. ip: {ip}") + state = None + return (state, [], None) + (disable_btn,) * 5 def clear_history_example(request: gr.Request): @@ -46,6 +137,41 @@ def clear_history_example(request: gr.Request): return (state, []) + (disable_btn,) * 5 +def add_text(state, model_selector, chat_input, request: gr.Request): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector, is_vision=True) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5 + + all_conv_text = state.conv.get_prompt() + all_conv_text = all_conv_text[-2000:] + "\nuser: " + text + flagged = moderation_filter(all_conv_text, [state.model_name]) + # flagged = moderation_filter(text, [state.model_name]) + if flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + text = _prepare_text_with_image(state, text, images) + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + + def build_single_vision_language_model_ui( models, add_promotion_links=False, random_questions=None ): @@ -53,7 +179,7 @@ def build_single_vision_language_model_ui( """ - | [GitHub](https://github.com/lm-sys/FastChat) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | -Note: You can only chat with one image per conversation. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. +Note: You can only chat with **one image per conversation**. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. """ if add_promotion_links else "" @@ -84,81 +210,84 @@ def build_single_vision_language_model_ui( gr.Markdown(model_description_md, elem_id="model_description_markdown") with gr.Row(): - with gr.Column(scale=3): - textbox = gr.Textbox( + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Click add or drop your image here", + container=True, + render=False, + elem_id="input_box", + ) + + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", show_label=False, - placeholder="👉 Enter your prompt and press ENTER", - container=False, - render=False, - elem_id="input_box", - ) - imagebox = gr.Image(type="pil", sources=["upload", "clipboard"]) - - cur_dir = os.path.dirname(os.path.abspath(__file__)) - - with gr.Accordion("Parameters", open=False) as parameter_row: - temperature = gr.Slider( - minimum=0.0, - maximum=1.0, - value=0.2, - step=0.1, - interactive=True, - label="Temperature", - ) - top_p = gr.Slider( - minimum=0.0, - maximum=1.0, - value=0.7, - step=0.1, - interactive=True, - label="Top P", - ) - max_output_tokens = gr.Slider( - minimum=0, - maximum=2048, - value=1024, - step=64, - interactive=True, - label="Max output tokens", - ) - - examples = gr.Examples( - examples=[ - [ - f"{cur_dir}/example_images/fridge.jpg", - "How can I prepare a delicious meal using these ingredients?", - ], - [ - f"{cur_dir}/example_images/distracted.jpg", - "What might the woman on the right be thinking about?", - ], - ], - inputs=[imagebox, textbox], + interactive=False, ) - - if random_questions: - global vqa_samples - with open(random_questions, "r") as f: - vqa_samples = json.load(f) - random_btn = gr.Button(value="🎲 Random Example", interactive=True) - with gr.Column(scale=8): chatbot = gr.Chatbot( elem_id="chatbot", label="Scroll down and start chatting", height=550 ) - with gr.Row(): - with gr.Column(scale=8): - textbox.render() - with gr.Column(scale=1, min_width=50): - send_btn = gr.Button(value="Send", variant="primary") + with gr.Row(): + textbox.render() + # with gr.Column(scale=1, min_width=50): + # send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row(elem_id="buttons"): + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) - with gr.Row(elem_id="buttons"): - upvote_btn = gr.Button(value="👍 Upvote", interactive=False) - downvote_btn = gr.Button(value="👎 Downvote", interactive=False) - flag_btn = gr.Button(value="⚠️ Flag", interactive=False) - regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) - clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + examples = gr.Examples( + examples=[ + { + "text": "How can I prepare a delicious meal using these ingredients?", + "files": [f"{cur_dir}/example_images/fridge.jpg"], + }, + { + "text": "What might the woman on the right be thinking about?", + "files": [f"{cur_dir}/example_images/distracted.jpg"], + }, + ], + inputs=[textbox], + ) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) if add_promotion_links: gr.Markdown(acknowledgment_md, elem_id="ack_markdown") @@ -180,35 +309,25 @@ def build_single_vision_language_model_ui( [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], ) - regenerate_btn.click( - regenerate, state, [state, chatbot, textbox, imagebox] + btn_list - ).then( + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( bot_response, [state, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, ) - clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) - model_selector.change( - clear_history, None, [state, chatbot, textbox, imagebox] + btn_list - ) - imagebox.upload(clear_history_example, None, [state, chatbot] + btn_list) + model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list) + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, [state, chatbot] + btn_list) + textbox.submit( add_text, - [state, model_selector, textbox, imagebox], - [state, chatbot, textbox, imagebox] + btn_list, - ).then( - bot_response, - [state, temperature, top_p, max_output_tokens], - [state, chatbot] + btn_list, - ) - send_btn.click( - add_text, - [state, model_selector, textbox, imagebox], - [state, chatbot, textbox, imagebox] + btn_list, - ).then( + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( bot_response, [state, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, @@ -219,6 +338,8 @@ def build_single_vision_language_model_ui( get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples [textbox, imagebox], # Outputs are textbox and imagebox - ).then(clear_history_example, None, [state, chatbot] + btn_list) + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, [state, chatbot] + btn_list + ) return [state, model_selector] diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index af900f1e9..6ddbb4ca9 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -98,12 +98,13 @@ class State: - def __init__(self, model_name): + def __init__(self, model_name, is_vision=False): self.conv = get_conversation_template(model_name) self.conv_id = uuid.uuid4().hex self.skip_next = False self.model_name = model_name self.oai_thread_id = None + self.is_vision = is_vision self.regen_support = True if "browsing" in model_name: @@ -139,13 +140,18 @@ def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_): use_remote_storage = use_remote_storage_ -def get_conv_log_filename(): +def get_conv_log_filename(is_vision=False): t = datetime.datetime.now() - name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json" + if is_vision: + name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}") + else: + name = os.path.join(LOGDIR, conv_log_filename) + return name -def get_model_list(controller_url, register_api_endpoint_file, multimodal): +def get_model_list(controller_url, register_api_endpoint_file, vision_arena): global api_endpoint_info # Add models from the controller @@ -153,7 +159,7 @@ def get_model_list(controller_url, register_api_endpoint_file, multimodal): ret = requests.post(controller_url + "/refresh_all_workers") assert ret.status_code == 200 - if multimodal: + if vision_arena: ret = requests.post(controller_url + "/list_multimodal_models") models = ret.json()["models"] else: @@ -166,11 +172,12 @@ def get_model_list(controller_url, register_api_endpoint_file, multimodal): if register_api_endpoint_file: api_endpoint_info = json.load(open(register_api_endpoint_file)) for mdl, mdl_dict in api_endpoint_info.items(): - mdl_multimodal = mdl_dict.get("multimodal", False) - if multimodal and mdl_multimodal: - models += [mdl] - elif not multimodal and not mdl_multimodal: - models += [mdl] + mdl_vision = mdl_dict.get("vision-arena", False) + mdl_text = mdl_dict.get("text-arena", True) + if vision_arena and mdl_vision: + models.append(mdl) + if not vision_arena and mdl_text: + models.append(mdl) # Remove anonymous models models = list(set(models)) @@ -211,7 +218,7 @@ def load_demo(url_params, request: gr.Request): if args.model_list_mode == "reload": models, all_models = get_model_list( - controller_url, args.register_api_endpoint_file, False + controller_url, args.register_api_endpoint_file, vision_arena=False ) return load_demo_single(models, url_params) @@ -282,8 +289,10 @@ def get_ip(request: gr.Request): return ip -def _prepare_text_with_image(state, text, image): - if image is not None: +def _prepare_text_with_image(state, text, images): + if images is not None and len(images) > 0: + image = images[0] + if len(state.conv.get_images()) > 0: # reset convo with new image state.conv = get_conversation_template(state.model_name) @@ -387,25 +396,6 @@ def is_limit_reached(model_name, ip): return None -def upload_image_file_to_gcs(image, filename): - from google.cloud import storage - import io - - storage_client = storage.Client() - # upload file to GCS - bucket = storage_client.get_bucket("arena_user_content") - - blob = bucket.blob(f"{filename}") - if not blob.exists(): - buffer = io.BytesIO() - image.save(buffer, format="PNG") - buffer.seek(0) - blob.upload_from_file(buffer, content_type="image/png") - - blob.make_public() - return blob.public_url - - def bot_response( state, temperature, @@ -468,7 +458,6 @@ def bot_response( # Construct prompt. # We need to call it here, so it will not be affected by "▌". prompt = conv.get_prompt() - # Set repetition_penalty if "t5" in model_name: repetition_penalty = 1.2 @@ -564,31 +553,9 @@ def bot_response( finish_tstamp = time.time() logger.info(f"{output}") - # We load the image because gradio accepts base64 but that increases file size by ~1.33x - loaded_images = [load_image(image) for image in images] - images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images] - image_filenames = [] - for image, hash_str in zip(loaded_images, images_hash): - t = datetime.datetime.now() - filename = os.path.join( - "serve_images", - f"{hash_str}.jpg", - ) - - if use_remote_storage: - image_url = upload_image_file_to_gcs(image, filename) - image_filenames.append(image_url) - else: - filename = os.path.join(LOGDIR, filename) - if not os.path.isfile(filename): - os.makedirs(os.path.dirname(filename), exist_ok=True) - image.save(filename) + conv.save_new_images(use_remote_storage=use_remote_storage) - image_filenames.append(hash_str) - - filename = get_conv_log_filename() - if "llava" in model_name: - filename = filename.replace("2024", "vision-tmp-2024") + filename = get_conv_log_filename(is_vision=state.is_vision) with open(filename, "a") as fout: data = { @@ -604,7 +571,6 @@ def bot_response( "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), - "images": image_filenames, } fout.write(json.dumps(data) + "\n") get_remote_logger().log(data) @@ -1019,7 +985,7 @@ def build_demo(models): # Set global variables set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) models, all_models = get_model_list( - args.controller_url, args.register_api_endpoint_file, False + args.controller_url, args.register_api_endpoint_file, vision_arena=False ) # Set authorization credentials diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index 040861631..d5dad71b0 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -22,6 +22,14 @@ from fastchat.serve.gradio_block_arena_vision import ( build_single_vision_language_model_ui, ) +from fastchat.serve.gradio_block_arena_vision_anony import ( + build_side_by_side_vision_ui_anony, + load_demo_side_by_side_vision_anony, +) +from fastchat.serve.gradio_block_arena_vision_named import ( + build_side_by_side_vision_ui_named, +) + from fastchat.serve.gradio_web_server import ( set_global_vars, block_css, @@ -66,25 +74,34 @@ def load_demo(url_params, request: gr.Request): models, all_models = get_model_list( args.controller_url, args.register_api_endpoint_file, - False, + vision_arena=False, ) vl_models, all_vl_models = get_model_list( args.controller_url, args.register_api_endpoint_file, - True, + vision_arena=True, ) single_updates = load_demo_single(models, url_params) side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params) side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + vision_language_updates = load_demo_single(vl_models, url_params) + side_by_side_vision_named_updates = load_demo_side_by_side_named( + vl_models, url_params + ) + side_by_side_vision_anony_updates = load_demo_side_by_side_vision_anony( + vl_models, url_params + ) return ( (gr.Tabs(selected=selected),) + single_updates + side_by_side_anony_updates + side_by_side_named_updates + + side_by_side_vision_anony_updates + + side_by_side_vision_named_updates + vision_language_updates ) @@ -119,33 +136,64 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): head=head_js, ) as demo: with gr.Tabs() as tabs: - with gr.Tab("⚔️ Arena (battle)", id=0): - side_by_side_anony_list = build_side_by_side_ui_anony(models) - - with gr.Tab("⚔️ Arena (side-by-side)", id=1): - side_by_side_named_list = build_side_by_side_ui_named(models) + with gr.Tab("Text Arena", id=0): + with gr.Tab("⚔️ Arena (battle)", id=0): + side_by_side_anony_list = build_side_by_side_ui_anony(models) - with gr.Tab("💬 Direct Chat", id=2): - single_model_list = build_single_model_ui( - models, add_promotion_links=True - ) + with gr.Tab("⚔️ Arena (side-by-side)", id=1): + side_by_side_named_list = build_side_by_side_ui_named(models) - with gr.Tab("👀 Vision Direct Chat", id=3, visible=args.multimodal): - single_vision_language_model_list = ( - build_single_vision_language_model_ui( - vl_models, - add_promotion_links=True, - random_questions=args.random_questions, + with gr.Tab("💬 Direct Chat", id=2): + single_model_list = build_single_model_ui( + models, add_promotion_links=True ) + + demo_tabs = ( + [tabs] + + single_model_list + + side_by_side_anony_list + + side_by_side_named_list + ) + + if args.vision_arena: + with gr.Tab("Vision Arena", id=3): + with gr.Tab("⚔️ Vision Arena (battle)", id=3): + side_by_side_vision_anony_list = ( + build_side_by_side_vision_ui_anony( + vl_models, + random_questions=args.random_questions, + ) + ) + + with gr.Tab("⚔️ Vision Arena (side-by-side)", id=4): + side_by_side_vision_named_list = ( + build_side_by_side_vision_ui_named( + vl_models, + random_questions=args.random_questions, + ) + ) + + with gr.Tab("👀 Vision Direct Chat", id=5): + single_vision_language_model_list = ( + build_single_vision_language_model_ui( + vl_models, + add_promotion_links=True, + random_questions=args.random_questions, + ) + ) + demo_tabs += ( + side_by_side_vision_anony_list + + side_by_side_vision_named_list + + single_vision_language_model_list ) if elo_results_file: - with gr.Tab("🏆 Leaderboard", id=4): + with gr.Tab("Leaderboard", id=6): build_leaderboard_tab( elo_results_file, leaderboard_table_file, show_plot=True ) - with gr.Tab("ℹ️ About Us", id=5): + with gr.Tab("ℹ️ About Us", id=7): about = build_about() url_params = gr.JSON(visible=False) @@ -156,11 +204,7 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): demo.load( load_demo, [url_params], - [tabs] - + single_model_list - + side_by_side_anony_list - + side_by_side_named_list - + single_vision_language_model_list, + demo_tabs, js=load_js, ) @@ -206,7 +250,7 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): help="Shows term of use before loading the demo", ) parser.add_argument( - "--multimodal", action="store_true", help="Show multi modal tabs." + "--vision-arena", action="store_true", help="Show tabs for vision arena." ) parser.add_argument( "--random-questions", type=str, help="Load random questions from a JSON file" @@ -255,13 +299,13 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): models, all_models = get_model_list( args.controller_url, args.register_api_endpoint_file, - False, + vision_arena=False, ) vl_models, all_vl_models = get_model_list( args.controller_url, args.register_api_endpoint_file, - True, + vision_arena=True, ) # Set authorization credentials diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py index 69432927b..4ee2f7d9b 100644 --- a/fastchat/serve/openai_api_server.py +++ b/fastchat/serve/openai_api_server.py @@ -319,7 +319,9 @@ async def get_gen_params( if item["type"] == "text" ] - text = "\n".join(text_list) + # TODO(chris): This only applies to LLaVA model. Implement an image_token string in the conv template. + text = "\n" * len(image_list) + text += "\n".join(text_list) conv.append_message(conv.roles[0], (text, image_list)) else: conv.append_message(conv.roles[0], message["content"]) @@ -446,7 +448,7 @@ async def create_chat_completion(request: ChatCompletionRequest): return error_check_ret gen_params["max_new_tokens"] = max_new_tokens - print(gen_params) + if request.stream: generator = chat_completion_stream_generator( request.model, gen_params, request.n, worker_addr @@ -879,7 +881,7 @@ async def create_chat_completion(request: APIChatCompletionRequest): ### END GENERAL API - NOT OPENAI COMPATIBLE ### -def create_openai_api_server(): +def create_openai_api_server(): parser = argparse.ArgumentParser( description="FastChat ChatGPT-Compatible RESTful API server." ) diff --git a/fastchat/train/train_with_template.py b/fastchat/train/train_with_template.py index 0d58f00f4..e26f10b58 100644 --- a/fastchat/train/train_with_template.py +++ b/fastchat/train/train_with_template.py @@ -414,4 +414,4 @@ def train(): if __name__ == "__main__": - train() \ No newline at end of file + train() diff --git a/fastchat/utils.py b/fastchat/utils.py index 41886e019..0f1b1215e 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -386,3 +386,32 @@ def load_image(image_file): image = Image.open(BytesIO(base64.b64decode(image_file))) return image + + +def upload_image_file_to_gcs(image, filename): + from google.cloud import storage + import io + + storage_client = storage.Client() + # upload file to GCS + bucket = storage_client.get_bucket("arena_user_content") + + blob = bucket.blob(f"{filename}") + if not blob.exists(): + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + blob.upload_from_file(buffer, content_type="image/png") + + return blob.public_url + + +def get_image_file_from_gcs(filename): + from google.cloud import storage + + storage_client = storage.Client() + bucket = storage_client.get_bucket("arena_user_content") + blob = bucket.blob(f"{filename}") + contents = blob.download_as_bytes() + + return contents From b918043990d6f85c85532afb6c67b4d072e4fccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 15 May 2024 09:33:46 +0800 Subject: [PATCH 08/21] =?UTF-8?q?=E5=90=88=E5=B9=B6=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastchat/utils.py b/fastchat/utils.py index 0f1b1215e..0067f2502 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -341,7 +341,7 @@ def get_context_length(config): """Get the context length of a model from a huggingface model config.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: - rope_scaling_factor = config.rope_scaling["factor"] + rope_scaling_factor = getattr(rope_scaling, "factor", 1) else: rope_scaling_factor = 1 From 4857a2bc9571b155412928eb147829fb8e08d8fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 15 May 2024 09:41:12 +0800 Subject: [PATCH 09/21] =?UTF-8?q?=E5=90=88=E5=B9=B6=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serve/gradio_block_arena_vision_anony.py | 564 ++++++++++++++++++ .../serve/gradio_block_arena_vision_named.py | 438 ++++++++++++++ 2 files changed, 1002 insertions(+) create mode 100644 fastchat/serve/gradio_block_arena_vision_anony.py create mode 100644 fastchat/serve/gradio_block_arena_vision_named.py diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py new file mode 100644 index 000000000..a2c1faa7f --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -0,0 +1,564 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + _prepare_text_with_image, +) +from fastchat.serve.gradio_block_arena_anony import ( + flash_buttons, + vote_last_response, + leftvote_last_response, + rightvote_last_response, + tievote_last_response, + bothbad_vote_last_response, + regenerate, + clear_history, + share_click, + add_text, + bot_response_multi, + set_global_vars_anony, + load_demo_side_by_side_anony, + get_sample_weight, + get_battle_pair, +) +from fastchat.serve.gradio_block_arena_vision import ( + get_vqa_sample, + set_invisible_image, + set_visible_image, + add_image, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + +# TODO(chris): fix sampling weights +SAMPLING_WEIGHTS = { + # tier 0 + "gpt-4-turbo": 4, + "gemini-1.5-pro-preview-0409": 4, + "gemini-1.0-pro-vision": 4, + "claude-3-opus-20240229": 4, + "claude-3-haiku-20240307": 4, + "claude-3-sonnet-20240229": 4, + "llava-v1.6-34b": 4, + "llava-v1.6-13b": 4, + "llava-v1.6-7b": 4, + "reka-flash-20240226": 4, +} + +# TODO(chris): Find battle targets that make sense +BATTLE_TARGETS = { + "gpt-4-turbo": { + "gemini-1.5-pro-preview-0409", + "claude-3-opus-20240229", + "reka-flash-20240226", + }, + "gemini-1.5-pro-preview-0409": { + "gpt-4-turbo", + "gemini-1.0-pro-vision", + "reka-flash-20240226", + }, + "gemini-1.0-pro-vision": { + "gpt-4-turbo", + "gemini-1.5-pro-preview-0409", + }, + "claude-3-opus-20240229": { + "gpt-4-turbo", + "gemini-1.5-pro-preview-0409", + "reka-flash-20240226", + }, + "claude-3-sonnet-20240229": { + "claude-3-opus-20240229", + "gpt-4-turbo", + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", + }, + "claude-3-haiku-20240307": { + "claude-3-opus-20240229", + "gpt-4-turbo", + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", + }, + "llava-v1.6-34b": { + "gpt-4-turbo", + "gemini-1.5-pro-preview-0409", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + }, + "llava-v1.6-13b": {"llava-v1.6-7b", "llava-v1.6-34b", "gemini-1.0-pro-vision"}, + "llava-v1.6-7b": {"llava-v1.6-13b", "gemini-1.0-pro-vision"}, + "reka-flash-20240226": { + "gemini-1.0-pro-vision", + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + }, +} + +# TODO(chris): Fill out models that require sampling boost +SAMPLING_BOOST_MODELS = [] + +# outage models won't be sampled. +OUTAGE_MODELS = [] + + +def load_demo_side_by_side_vision_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown(visible=True), + gr.Markdown(visible=True), + ) + + return states + selector_updates + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision) + + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + if ":" not in model_selectors[0]: + for i in range(5): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (None,) + (disable_btn,) * 4 + time.sleep(0.1) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (None,) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [None] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + model_left, model_right = get_battle_pair( + models, + BATTLE_TARGETS, + OUTAGE_MODELS, + SAMPLING_WEIGHTS, + SAMPLING_BOOST_MODELS, + ) + states = [ + State(model_left, is_vision=True), + State(model_right, is_vision=True), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + flagged = moderation_filter(text, model_list) + if flagged: + logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image(states[i], text, images) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + hint_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + hint_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + disable_btn, + ] + * 6 + + [hint_msg] + ) + + +def build_side_by_side_vision_ui_anony(models, random_questions=None): + notice_markdown = """ +# ⚔️ Vision Arena ⚔️: Benchmarking VLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Ask any question to two anonymous models (e.g., Claude, Gemini, GPT-4-V) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. + +## 👇 Chat now! +Note: You can only chat with **one image per conversation**. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id="chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + container=True, + placeholder="Click add or drop your image here", + elem_id="input_box", + ) + # send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, states + chatbots + model_selectors + btn_list) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list + [slow_warning], + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, states + chatbots + model_selectors + btn_list + ) + + return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py new file mode 100644 index 000000000..2a74849a7 --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -0,0 +1,438 @@ +""" +Multimodal Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import os +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import ( + flash_buttons, + share_click, + bot_response_multi, +) +from fastchat.serve.gradio_block_arena_vision import ( + get_vqa_sample, + set_invisible_image, + set_visible_image, + add_image, +) +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + _prepare_text_with_image, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + + +logger = build_logger( + "gradio_web_server_vision_multi", "gradio_web_server_vision_multi.log" +) + +num_sides = 2 +enable_moderation = False + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision) + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [None] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i], is_vision=True) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[0].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + flagged = moderation_filter(all_conv_text, model_list) + if flagged: + logger.info(f"violate moderation (named). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image(states[i], text, images) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + disable_btn, + ] + * 6 + ) + + +def build_side_by_side_vision_ui_named(models, random_questions=None): + notice_markdown = """ +# ⚔️ Vision Arena ⚔️ : Benchmarking VLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. +- Click "Clear history" to start a new round. + +## 🤖 Choose two models to compare +Note: You can only chat with **one image per conversation**. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id=f"chatbot", + height=550, + show_copy_button=True, + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Click add or drop your image here", + container=True, + elem_id="input_box", + ) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, states + chatbots + btn_list) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, states + chatbots + btn_list + ) + + return states + model_selectors From 0590b4cc8e3d751e490f9893b488d7ada1d4b63c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Tue, 28 May 2024 09:36:35 +0800 Subject: [PATCH 10/21] =?UTF-8?q?=E5=90=88=E5=B9=B6=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/train/train_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastchat/train/train_lora.py b/fastchat/train/train_lora.py index 9ecb47c29..824134a60 100644 --- a/fastchat/train/train_lora.py +++ b/fastchat/train/train_lora.py @@ -183,6 +183,7 @@ def train(): padding_side="right", use_fast=False, ) + tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.pad_token = tokenizer.unk_token data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) From 8af0b9708c8022eb25a3ebdbb30a3a0b90d0c69c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 29 May 2024 19:22:12 +0800 Subject: [PATCH 11/21] =?UTF-8?q?=E5=A2=9E=E5=8A=A0tensorRT=E6=8E=A8?= =?UTF-8?q?=E7=90=86=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 2 +- fastchat/serve/trt_worker.py | 413 +++++++++++++++++++++++++++++++++++ 2 files changed, 414 insertions(+), 1 deletion(-) create mode 100644 fastchat/serve/trt_worker.py diff --git a/Dockerfile b/Dockerfile index 8886f4712..a833a7217 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,6 @@ WORKDIR /app COPY . /app/ RUN pip3 install -e . -RUN pip3 install pydantic==1.10.13 +RUN pip3 install pydantic CMD ["python3", "-m", "fastchat.serve.controller", "--host", "0.0.0.0"] \ No newline at end of file diff --git a/fastchat/serve/trt_worker.py b/fastchat/serve/trt_worker.py new file mode 100644 index 000000000..85ab65e09 --- /dev/null +++ b/fastchat/serve/trt_worker.py @@ -0,0 +1,413 @@ +""" +A model worker that executes the TensorRT engine. + +Refer to the implemention in https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/run.py +""" + +import argparse +import asyncio +import json +import os +from typing import List, Optional +from pathlib import Path + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +import re +import torch +from transformers import AutoTokenizer, T5Tokenizer + +import tensorrt_llm +from tensorrt_llm import runtime +from tensorrt_llm.logger import logger +from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner +from tensorrt_llm.builder import get_engine_version +from copy import deepcopy + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() + + +def read_model_name(engine_dir: str): + engine_version = get_engine_version(engine_dir) + + with open(Path(engine_dir) / "config.json", "r") as f: + config = json.load(f) + + if engine_version is None: + return config["builder_config"]["name"] + + return config["pretrained_config"]["architecture"] + + +def throttle_generator(generator, stream_interval): + for i, out in enumerate(generator): + if not i % stream_interval: + yield out + + if i % stream_interval: + yield out + + +def load_tokenizer( + tokenizer_dir: Optional[str] = None, + vocab_file: Optional[str] = None, + model_name: str = "gpt", + tokenizer_type: Optional[str] = None, +): + if vocab_file is None: + use_fast = True + if tokenizer_type is not None and tokenizer_type == "llama": + use_fast = False + # Should set both padding_side and truncation_side to be 'left' + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, + legacy=False, + padding_side="left", + truncation_side="left", + trust_remote_code=True, + tokenizer_type=tokenizer_type, + use_fast=use_fast, + ) + else: + # For gpt-next, directly load from tokenizer.model + assert model_name == "gpt" + tokenizer = T5Tokenizer( + vocab_file=vocab_file, padding_side="left", truncation_side="left" + ) + + if model_name == "qwen": + with open(Path(tokenizer_dir) / "generation_config.json") as f: + gen_config = json.load(f) + chat_format = gen_config["chat_format"] + if chat_format == "raw": + pad_id = gen_config["pad_token_id"] + end_id = gen_config["eos_token_id"] + elif chat_format == "chatml": + pad_id = tokenizer.im_end_id + end_id = tokenizer.im_end_id + else: + raise Exception(f"unknown chat format: {chat_format}") + elif model_name == "glm_10b": + pad_id = tokenizer.pad_token_id + end_id = tokenizer.eop_token_id + else: + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + pad_id = tokenizer.pad_token_id + end_id = tokenizer.eos_token_id + + return tokenizer, pad_id, end_id + + +class TensorRTWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + runner: ModelRunner, + tokenizer: AutoTokenizer, + pad_id: int, + end_id: int, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}." + f"worker type: tensorRT worker..." + ) + logger.info( + ( + "worker args:\n" + f"controller_addr: {controller_addr}\n" + f"worker_addr: {worker_addr}\n" + f"worker_id: {worker_id}\n" + f"model_path: {model_path}\n" + f"model_names: {model_names}\n" + f"limit_worker_concurrency: {limit_worker_concurrency}\n" + f"no_register: {no_register}\n" + f"conv_template: {conv_template}\n" + f"runner: {runner}\n" + f"tokenizer: {tokenizer}\n" + f"pad_id: {pad_id}\n" + f"end_id: {end_id}\n" + ) + ) + + self.runner = runner + self.tokenizer = tokenizer + self.pad_id = pad_id + self.end_id = end_id + self.context_len = get_context_length(self.runner.config) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + try: + + def generate(runner, tokenizer, params): + input_ids = tokenizer.encode( + params["prompt"], add_special_tokens=True, truncation=True + ) + batch_input_ids = [torch.tensor(input_ids, dtype=torch.int32)] + + max_new_tokens = int(params.get("max_new_tokens", 128)) + temperature = float(params.get("temperature", 0.7)) + top_k = int(params.get("top_k", -1)) + top_p = float(params.get("top_p", 1.0)) + + assert top_k != -1, "Top_k in TensorRT should not be -1" + + outputs = runner.generate( + batch_input_ids, + max_new_tokens=max_new_tokens, + end_id=self.end_id, + pad_id=self.pad_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_beams=1, + streaming=True, + output_sequence_lengths=True, + return_dict=True, + ) + torch.cuda.synchronize() + + input_lengths = [x.size(0) for x in batch_input_ids] + for curr_outputs in throttle_generator(outputs, 1): + if tensorrt_llm.mpi_rank() == 0: + output_ids = curr_outputs["output_ids"] + sequence_lengths = curr_outputs["sequence_lengths"] + batch_size, num_beams, _ = output_ids.size() + for batch_idx in range(0, batch_size): + for beam in range(num_beams): + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[batch_idx][beam].item() + outputs = output_ids[batch_idx][beam][ + output_begin:output_end + ].tolist() + output_text = tokenizer.decode(outputs) + response = output_text + yield { + "text": response, + "usage": { + "prompt_tokens": input_lengths[batch_idx], + "completion_tokens": output_end - output_begin, + "total_tokens": input_lengths[batch_idx] + + output_end + - output_begin, + }, + "finish_reason": None, + } + yield { + "text": response, + "usage": { + "prompt_tokens": input_lengths[0], + "completion_tokens": output_end - output_begin, + "total_tokens": input_lengths[0] + output_end - output_begin, + }, + "finish_reason": "stop", + } + + for output in generate(self.runner, self.tokenizer, params): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await worker.generate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--model-path", + type=str, + help="tensorRT engine path", + default="lmsys/vicuna-7b-v1.5", + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--tokenizer-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument("--limit-worker-concurrency", type=int, default=5) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--lora-dir", + type=str, + default=None, + nargs="+", + help="The directory of LoRA weights", + ) + parser.add_argument( + "--debug-mode", + default=False, + action="store_true", + help="Whether or not to turn on the debug mode", + ) + + args = parser.parse_args() + logger.info(f"args: {args}") + + # load tokenizer + model_name = read_model_name(args.model_path) + tokenizer, pad_id, end_id = load_tokenizer( + tokenizer_dir=args.tokenizer_path, model_name=model_name + ) + + # load trt runner + runtime_rank = tensorrt_llm.mpi_rank() + runner_cls = ModelRunner + runner_kwargs = dict( + engine_dir=args.model_path, + lora_dir=args.lora_dir, + rank=runtime_rank, + debug_mode=args.debug_mode, + ) + runner = runner_cls.from_dir(**runner_kwargs) + + # get config + config_path = os.path.join(args.model_path, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + runner.config = { + **config["build_config"], + **config["build_config"]["plugin_config"], + } + + # create worker + worker = TensorRTWorker( + controller_addr=args.controller_address, + worker_addr=args.worker_address, + worker_id=worker_id, + model_path=args.model_path, + model_names=args.model_names, + limit_worker_concurrency=args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=args.conv_template, + runner=runner, + tokenizer=tokenizer, + pad_id=pad_id, + end_id=end_id, + ) + + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") From a707950ad3c2e6c32ace6b8b45ff6f556cb4c0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Fri, 21 Jun 2024 09:30:53 +0800 Subject: [PATCH 12/21] =?UTF-8?q?=E5=90=88=E5=B9=B6=E4=BA=86=E9=83=A8?= =?UTF-8?q?=E5=88=86pr?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/protocol/openai_api_protocol.py | 1 + fastchat/serve/inference.py | 69 +++++++------- fastchat/serve/openai_api_server.py | 69 ++++++++------ fastchat/serve/trt_worker.py | 1 + fastchat/serve/vllm_worker.py | 115 ++++++++++++++++++----- 5 files changed, 171 insertions(+), 84 deletions(-) diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py index bb50a5ef0..c35c0b732 100644 --- a/fastchat/protocol/openai_api_protocol.py +++ b/fastchat/protocol/openai_api_protocol.py @@ -166,6 +166,7 @@ class CompletionRequest(BaseModel): user: Optional[str] = None use_beam_search: Optional[bool] = False best_of: Optional[int] = None + seed: Optional[int] = None class CompletionResponseChoice(BaseModel): diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py index 6d155aab7..35afbe8cb 100644 --- a/fastchat/serve/inference.py +++ b/fastchat/serve/inference.py @@ -43,7 +43,7 @@ def prepare_logits_processor( - temperature: float, repetition_penalty: float, top_p: float, top_k: int + temperature: float, repetition_penalty: float, top_p: float, top_k: int ) -> LogitsProcessorList: processor_list = LogitsProcessorList() # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. @@ -60,13 +60,13 @@ def prepare_logits_processor( @torch.inference_mode() def generate_stream( - model, - tokenizer, - params: Dict, - device: str, - context_len: int, - stream_interval: int = 2, - judge_sent_end: bool = False, + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, ): if hasattr(model, "device"): device = model.device @@ -85,6 +85,9 @@ def generate_stream( stop_token_ids = params.get("stop_token_ids", None) or [] if tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(tokenizer.eos_token_id) + for item in model.generation_config.eos_token_id: + if item not in stop_token_ids: + stop_token_ids.append(item) logits_processor = prepare_logits_processor( temperature, repetition_penalty, top_p, top_k @@ -139,7 +142,7 @@ def generate_stream( shift_logits = logits[..., :-1, :].contiguous() shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() for label_id, logit in zip( - shift_input_ids[0].tolist(), shift_logits[0] + shift_input_ids[0].tolist(), shift_logits[0] ): token_logprobs.append(logit[label_id]) else: # decoding @@ -231,7 +234,7 @@ def generate_stream( if echo else token_logprobs[input_echo_len:], "top_logprobs": [{}] - * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), } # Compute text_offset curr_pos = 0 @@ -335,27 +338,27 @@ def print_output(self, text: str): def chat_loop( - model_path: str, - device: str, - num_gpus: int, - max_gpu_memory: str, - dtype: Optional[torch.dtype], - load_8bit: bool, - cpu_offloading: bool, - conv_template: Optional[str], - conv_system_msg: Optional[str], - temperature: float, - repetition_penalty: float, - max_new_tokens: int, - chatio: ChatIO, - gptq_config: Optional[GptqConfig] = None, - awq_config: Optional[AWQConfig] = None, - exllama_config: Optional[ExllamaConfig] = None, - xft_config: Optional[XftConfig] = None, - revision: str = "main", - judge_sent_end: bool = True, - debug: bool = True, - history: bool = True, + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, ): # Model model, tokenizer = load_model( @@ -401,7 +404,7 @@ def reload_conv(conv): """ Reprints the conversation from the start. """ - for message in conv.messages[conv.offset :]: + for message in conv.messages[conv.offset:]: chatio.prompt_for_output(message[0]) chatio.print_output(message[1]) @@ -483,7 +486,7 @@ def reload_conv(conv): # Check if file exists and add .json if needed if not os.path.exists(filename): if (not filename.endswith(".json")) and os.path.exists( - filename + ".json" + filename + ".json" ): filename += ".json" else: diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py index a6ffee96b..1bcc66223 100644 --- a/fastchat/serve/openai_api_server.py +++ b/fastchat/serve/openai_api_server.py @@ -107,7 +107,7 @@ class AppSettings(BaseSettings): async def check_api_key( - auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), ) -> str: if app_settings.api_keys: if auth is None or (token := auth.credentials) not in app_settings.api_keys: @@ -154,7 +154,7 @@ async def check_model(request) -> Optional[JSONResponse]: async def check_length(request, prompt, max_tokens, worker_addr): if ( - not isinstance(max_tokens, int) or max_tokens <= 0 + not isinstance(max_tokens, int) or max_tokens <= 0 ): # model worker not support max_tokens=None max_tokens = 1024 * 1024 @@ -215,12 +215,19 @@ def check_requests(request) -> Optional[JSONResponse]: f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", ) if request.stop is not None and ( - not isinstance(request.stop, str) and not isinstance(request.stop, list) + not isinstance(request.stop, str) and not isinstance(request.stop, list) ): return create_error_response( ErrorCode.PARAM_OUT_OF_RANGE, f"{request.stop} is not valid under any of the given schemas - 'stop'", ) + if request.seed is not None and ( + not isinstance(request.seed, int) or request.seed < 0 + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.seed} is not a nonnegative integer", + ) return None @@ -264,21 +271,22 @@ def _add_to_set(s, new_stop): async def get_gen_params( - model_name: str, - worker_addr: str, - messages: Union[str, List[Dict[str, str]]], - *, - temperature: float, - top_p: float, - top_k: Optional[int], - presence_penalty: Optional[float], - frequency_penalty: Optional[float], - max_tokens: Optional[int], - echo: Optional[bool], - logprobs: Optional[int] = None, - stop: Optional[Union[str, List[str]]], - best_of: Optional[int] = None, - use_beam_search: Optional[bool] = None, + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, + seed: Optional[int] = None, ) -> Dict[str, Any]: conv = await get_conv(model_name, worker_addr) conv = Conversation( @@ -344,6 +352,7 @@ async def get_gen_params( "max_new_tokens": max_tokens, "echo": echo, "stop_token_ids": conv.stop_token_ids, + "seed": seed, } if len(images) > 0: @@ -432,6 +441,7 @@ async def create_chat_completion(request: ChatCompletionRequest): max_tokens=request.max_tokens, echo=False, stop=request.stop, + seed=request.seed, ) max_new_tokens, error_check_ret = await check_length( @@ -484,7 +494,7 @@ async def create_chat_completion(request: ChatCompletionRequest): async def chat_completion_stream_generator( - model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str ) -> Generator[str, Any, None]: """ Event stream format: @@ -511,7 +521,7 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" return decoded_unicode = content["text"].replace("\ufffd", "") - delta_text = decoded_unicode[len(previous_text) :] + delta_text = decoded_unicode[len(previous_text):] previous_text = ( decoded_unicode if len(decoded_unicode) > len(previous_text) @@ -619,7 +629,7 @@ async def create_completion(request: CompletionRequest): async def generate_completion_stream_generator( - request: CompletionRequest, n: int, worker_addr: str + request: CompletionRequest, n: int, worker_addr: str ): model_name = request.model id = f"cmpl-{shortuuid.random()}" @@ -640,6 +650,7 @@ async def generate_completion_stream_generator( logprobs=request.logprobs, echo=request.echo, stop=request.stop, + seed=request.seed, ) async for content in generate_completion_stream(gen_params, worker_addr): if content["error_code"] != 0: @@ -647,7 +658,7 @@ async def generate_completion_stream_generator( yield "data: [DONE]\n\n" return decoded_unicode = content["text"].replace("\ufffd", "") - delta_text = decoded_unicode[len(previous_text) :] + delta_text = decoded_unicode[len(previous_text):] previous_text = ( decoded_unicode if len(decoded_unicode) > len(previous_text) @@ -682,18 +693,18 @@ async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): async with httpx.AsyncClient() as client: delimiter = b"\0" async with client.stream( - "POST", - worker_addr + "/worker_generate_stream", - headers=headers, - json=payload, - timeout=WORKER_API_TIMEOUT, + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, ) as response: # content = await response.aread() buffer = b"" async for raw_chunk in response.aiter_raw(): buffer += raw_chunk while (chunk_end := buffer.find(delimiter)) >= 0: - chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1:] if not chunk: continue yield json.loads(chunk.decode()) @@ -719,7 +730,7 @@ async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): token_num = 0 batch_size = WORKER_API_EMBEDDING_BATCH_SIZE batches = [ - request.input[i : min(i + batch_size, len(request.input))] + request.input[i: min(i + batch_size, len(request.input))] for i in range(0, len(request.input), batch_size) ] for num_batch, batch in enumerate(batches): diff --git a/fastchat/serve/trt_worker.py b/fastchat/serve/trt_worker.py index 85ab65e09..9d4bbc5f5 100644 --- a/fastchat/serve/trt_worker.py +++ b/fastchat/serve/trt_worker.py @@ -180,6 +180,7 @@ def generate(runner, tokenizer, params): temperature = float(params.get("temperature", 0.7)) top_k = int(params.get("top_k", -1)) top_p = float(params.get("top_p", 1.0)) + top_k = 0 if top_k == -1 else top_k assert top_k != -1, "Top_k in TensorRT should not be -1" diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 0af680bb5..a402fa3e4 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -6,12 +6,14 @@ import argparse import asyncio +import codecs import json -from typing import List +from typing import List, Optional from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse import uvicorn +from transformers import GenerationConfig from vllm import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import SamplingParams @@ -24,22 +26,21 @@ ) from fastchat.utils import get_context_length, is_partial_stop - app = FastAPI() class VLLMWorker(BaseModelWorker): def __init__( - self, - controller_addr: str, - worker_addr: str, - worker_id: str, - model_path: str, - model_names: List[str], - limit_worker_concurrency: int, - no_register: bool, - llm_engine: AsyncLLMEngine, - conv_template: str, + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: AsyncLLMEngine, + conv_template: str, ): super().__init__( controller_addr, @@ -59,11 +60,45 @@ def __init__( # and llm_engine.engine.tokenizer was no longer a raw tokenizer if hasattr(self.tokenizer, "tokenizer"): self.tokenizer = llm_engine.engine.tokenizer.tokenizer + self._load_chat_template(chat_template=None) + try: + self.generation_config = GenerationConfig.from_pretrained( + model_path, trust_remote_code=True + ) + except Exception: + self.generation_config = None self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) if not no_register: self.init_heart_beat() + def _load_chat_template(self, chat_template: Optional[str]): + tokenizer = self.tokenizer + + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}" + ) + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", tokenizer.chat_template) + else: + tokenizer.chat_template = "" + async def generate_stream(self, params): self.call_ct += 1 @@ -82,6 +117,7 @@ async def generate_stream(self, params): echo = params.get("echo", True) use_beam_search = params.get("use_beam_search", False) best_of = params.get("best_of", None) + seed = params.get("seed", None) request = params.get("request", None) @@ -115,6 +151,7 @@ async def generate_stream(self, params): presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, best_of=best_of, + seed=seed, ) results_generator = engine.generate(context, sampling_params, request_id) @@ -169,6 +206,28 @@ async def generate_stream(self, params): if aborted: break + def get_conv_template(self): + if self.tokenizer.chat_template: + chat_template_kwargs = { + "chat_template": { + "chat_template": self.tokenizer.chat_template, + "eos_token": self.tokenizer.eos_token, + "generation_config": self.generation_config.to_diff_dict() + if self.generation_config + else None, + } + } + else: + chat_template_kwargs = {} + + return { + "conv": self.conv, + **chat_template_kwargs, + } + + def apply_chat_template(self, params): + return self.tokenizer.apply_chat_template(**params) + async def generate(self, params): async for x in self.generate_stream(params): pass @@ -195,16 +254,28 @@ async def abort_request() -> None: return background_tasks +@app.post("/apply_chat_template") +async def api_apply_chat_template(request: Request): + params = await request.json() + prompt = worker.apply_chat_template(params) + return JSONResponse({"prompt": prompt}) + + @app.post("/worker_generate_stream") async def api_generate_stream(request: Request): params = await request.json() await acquire_worker_semaphore() request_id = random_uuid() - params["request_id"] = request_id - params["request"] = request - generator = worker.generate_stream(params) - background_tasks = create_background_tasks(request_id) - return StreamingResponse(generator, background=background_tasks) + try: + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + except Exception as e: + background_tasks = create_background_tasks(request_id) + await background_tasks() + raise e @app.post("/worker_generate") @@ -266,17 +337,17 @@ async def api_model_details(request: Request): action="store_false", default=True, help="Trust remote code (e.g., from HuggingFace) when" - "downloading the model and tokenizer.", + "downloading the model and tokenizer.", ) parser.add_argument( "--gpu_memory_utilization", type=float, default=0.9, help="The ratio (between 0 and 1) of GPU memory to" - "reserve for the model weights, activations, and KV cache. Higher" - "values will increase the KV cache size and thus improve the model's" - "throughput. However, if the value is too high, it may cause out-of-" - "memory (OOM) errors.", + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", ) parser = AsyncEngineArgs.add_cli_args(parser) From 0eb1cda41ca29c12f3755fc806d399aa9d987b16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 8 Jul 2024 11:38:16 +0800 Subject: [PATCH 13/21] =?UTF-8?q?=E5=A2=9E=E5=8A=A0worker=20info=20?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/controller.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py index 42d928403..936a32a67 100644 --- a/fastchat/serve/controller.py +++ b/fastchat/serve/controller.py @@ -27,7 +27,6 @@ ) from fastchat.utils import build_logger - logger = build_logger("controller", "controller.log") @@ -73,11 +72,11 @@ def __init__(self, dispatch_method: str): self.heart_beat_thread.start() def register_worker( - self, - worker_name: str, - check_heart_beat: bool, - worker_status: dict, - multimodal: bool, + self, + worker_name: str, + check_heart_beat: bool, + worker_status: dict, + multimodal: bool, ): if worker_name not in self.worker_info: logger.info(f"Register a new worker: {worker_name}") @@ -123,7 +122,7 @@ def refresh_all_workers(self): for w_name, w_info in old_info.items(): if not self.register_worker( - w_name, w_info.check_heart_beat, None, w_info.multimodal + w_name, w_info.check_heart_beat, None, w_info.multimodal ): logger.info(f"Remove stale worker: {w_name}") @@ -263,6 +262,9 @@ def worker_api_get_status(self): "queue_length": queue_length, } + def worker_get_info(self): + return self.worker_info + def worker_api_generate_stream(self, params): worker_addr = self.get_worker_address(params["model"]) if not worker_addr: @@ -350,6 +352,11 @@ async def worker_api_get_status(request: Request): return "success" +@app.get("/worker_get_info") +async def worker_api_get_status(request: Request): + return controller.worker_get_info() + + def create_controller(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") From 140067db1b783491f077466f041d2e770cb93eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 8 Jul 2024 14:30:37 +0800 Subject: [PATCH 14/21] =?UTF-8?q?=E5=A2=9E=E5=8A=A0worker=20info=20?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/controller.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py index 936a32a67..4f782ffdf 100644 --- a/fastchat/serve/controller.py +++ b/fastchat/serve/controller.py @@ -263,7 +263,15 @@ def worker_api_get_status(self): } def worker_get_info(self): - return self.worker_info + worker_info = self.worker_info + for w_name in worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + worker_info[w_name].model_names = worker_status["model_names"] + worker_info[w_name].speed = worker_status["speed"] + worker_info[w_name].queue_length = worker_status["queue_length"] + + return worker_info def worker_api_generate_stream(self, params): worker_addr = self.get_worker_address(params["model"]) From dcf070b1cb588130e3e48a8cf90a3f2869c45b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 22 Jul 2024 15:14:35 +0800 Subject: [PATCH 15/21] =?UTF-8?q?vllm=20=E6=94=AF=E6=8C=81=E5=8A=A0?= =?UTF-8?q?=E8=BD=BDlora=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/vllm_worker.py | 85 ++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index a402fa3e4..a1df161a0 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -8,6 +8,7 @@ import asyncio import codecs import json +from os import path from typing import List, Optional from fastapi import FastAPI, Request, BackgroundTasks @@ -26,6 +27,21 @@ ) from fastchat.utils import get_context_length, is_partial_stop +# Add imports for vLLM LoRAs, prevent panic with older vllm versions which not support LoRAs +# LoRA request only supports vLLM versions >= v0.3.2 +try: + from vllm.entrypoints.openai.serving_engine import LoRA + from vllm.lora.request import LoRARequest + + VLLM_LORA_SUPPORTED = True +except: + VLLM_LORA_SUPPORTED = False + + + # Fake LoRA class to compatible with old vLLM versions + class LoRA: + pass + app = FastAPI() @@ -41,7 +57,20 @@ def __init__( no_register: bool, llm_engine: AsyncLLMEngine, conv_template: str, + lora_modules: List[LoRA] = [], ): + # Register LoRA model names + if VLLM_LORA_SUPPORTED: + # If model_names defined, use basename of model path by default + model_names = ( + [path.basename(path.normpath(model_path))] + if model_names is None + else model_names + ) + if lora_modules: + lora_model_names = [lora.name for lora in lora_modules] + model_names += lora_model_names + super().__init__( controller_addr, worker_addr, @@ -69,6 +98,20 @@ def __init__( self.generation_config = None self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) + # Add LoRA requests, lora request will be forwarded to vLLM engine for generating with specific LoRA weights + self.lora_requests = ( + [ + LoRARequest( + lora_name=lora.name, + lora_int_id=i, + lora_local_path=lora.local_path, + ) + for i, lora in enumerate(lora_modules, start=1) + ] + if VLLM_LORA_SUPPORTED and lora_modules + else [] + ) + if not no_register: self.init_heart_beat() @@ -99,6 +142,12 @@ def _load_chat_template(self, chat_template: Optional[str]): else: tokenizer.chat_template = "" + def get_model_lora_request(self, model_name): + for lora_req in self.lora_requests: + if lora_req.lora_name == model_name: + return lora_req + return None + async def generate_stream(self, params): self.call_ct += 1 @@ -153,7 +202,14 @@ async def generate_stream(self, params): best_of=best_of, seed=seed, ) - results_generator = engine.generate(context, sampling_params, request_id) + + if VLLM_LORA_SUPPORTED: + lora_request = self.get_model_lora_request(params.get("model")) + results_generator = engine.generate( + context, sampling_params, request_id, lora_request=lora_request + ) + else: + results_generator = engine.generate(context, sampling_params, request_id) async for request_output in results_generator: prompt = request_output.prompt @@ -312,6 +368,22 @@ async def api_model_details(request: Request): return {"context_length": worker.context_len} +# Add LoRAParserAction for supporting vLLM Multi-LoRA +class LoRAParserAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + if VLLM_LORA_SUPPORTED is False: + logger.warning( + "To use the vLLM LoRAs feature, please upgrade vLLM to version v0.3.2 or higher." + ) + return + + lora_list = [] + for item in values: + name, path = item.split("=") + lora_list.append(LoRA(name, path)) + setattr(namespace, self.dest, lora_list) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") @@ -350,6 +422,16 @@ async def api_model_details(request: Request): "memory (OOM) errors.", ) + # Support parse LoRA modules + parser.add_argument( + "--lora-modules", + type=str, + default=None, + nargs="+", + action=LoRAParserAction, + help="LoRA module configurations in the format name=path. Multiple modules can be specified.", + ) + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.model_path: @@ -369,5 +451,6 @@ async def api_model_details(request: Request): args.no_register, engine, args.conv_template, + args.lora_modules, ) uvicorn.run(app, host=args.host, port=args.port, log_level="info") From 15c04a28ec06ce644edb30003907e4cdffc4754c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 24 Jul 2024 18:59:52 +0800 Subject: [PATCH 16/21] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E9=99=90=E6=B5=81=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/vllm_worker.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index a1df161a0..31b4d9f2a 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -8,9 +8,11 @@ import asyncio import codecs import json +import time from os import path from typing import List, Optional +import requests from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse import uvicorn @@ -45,6 +47,32 @@ class LoRA: app = FastAPI() +# 定义一个新的 get 函数,添加限速功能 +def limited_get(url, stream=False, **kwargs): + response = original_get(url, stream=stream, **kwargs) + if not stream: + return response + + # 如果是流式下载,则对内容进行限速处理 + chunk_size = 1024 # 每次读取的块大小(字节) + max_speed = 20000 # 最大下载速度(KB/s) + + def generate(): + for chunk in response.iter_content(chunk_size=chunk_size): + yield chunk + time.sleep(chunk_size / (max_speed * 1024)) + + response.iter_content = generate + return response + + +# 保存原始的 requests.get 函数 +original_get = requests.get + +# 替换 requests.get 为新的函数 +requests.get = limited_get + + class VLLMWorker(BaseModelWorker): def __init__( self, From dcd4e6e830dec3660ad6c329b0add4d8fb6bcfb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 29 Jul 2024 13:40:51 +0800 Subject: [PATCH 17/21] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dvllm=20adapter=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/vllm_worker.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 31b4d9f2a..9f1935234 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -31,18 +31,17 @@ # Add imports for vLLM LoRAs, prevent panic with older vllm versions which not support LoRAs # LoRA request only supports vLLM versions >= v0.3.2 -try: - from vllm.entrypoints.openai.serving_engine import LoRA - from vllm.lora.request import LoRARequest +from vllm.lora.request import LoRARequest - VLLM_LORA_SUPPORTED = True -except: - VLLM_LORA_SUPPORTED = False +VLLM_LORA_SUPPORTED = True - # Fake LoRA class to compatible with old vLLM versions - class LoRA: - pass +# Fake LoRA class to compatible with old vLLM versions +class LoRA: + def __init__(self, name: str, local_path: str): + self.name: str = name + self.local_path: str = local_path + app = FastAPI() From 260eff89f88604e36d020ea72638c3bb87f2e1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Tue, 30 Jul 2024 15:31:37 +0800 Subject: [PATCH 18/21] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dvllm=20adapter=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/vllm_worker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 9f1935234..f5c8e7694 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -448,6 +448,12 @@ def __call__(self, parser, namespace, values, option_string=None): "throughput. However, if the value is too high, it may cause out-of-" "memory (OOM) errors.", ) + parser.add_argument( + "--max-model-len", + type=float, + default=None, + help="Model context length. If unspecified, will be automatically derived from the model config.", + ) # Support parse LoRA modules parser.add_argument( From 24287bc172436ef33a9dc993a7666b6e42ddb1b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 19 Aug 2024 19:20:53 +0800 Subject: [PATCH 19/21] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dvllm=20adapter=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/serve/vllm_worker.py | 46 +++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 697f277d1..fa7d45eae 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -21,13 +21,15 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +from vllm.inputs import TextPrompt +from fastchat.conversation import IMAGE_PLACEHOLDER_STR from fastchat.serve.base_model_worker import BaseModelWorker from fastchat.serve.model_worker import ( logger, worker_id, ) -from fastchat.utils import get_context_length, is_partial_stop +from fastchat.utils import get_context_length, is_partial_stop, load_image # Add imports for vLLM LoRAs, prevent panic with older vllm versions which not support LoRAs # LoRA request only supports vLLM versions >= v0.3.2 @@ -72,6 +74,20 @@ def generate(): requests.get = limited_get +def replace_placeholders_with_images(prompt: str, placeholder: str, images: List[str]): + """ + 将多个占位符替换为实际的图片 URL。 + + :param prompt: 包含占位符的原始提示字符串 + :param placeholder: 要替换的占位符 + :param images: 替换占位符的实际图片 列表 + :return: 替换后的提示字符串 + """ + for img in images: + prompt = prompt.replace(placeholder, img, 1) # 只替换第一个出现的占位符 + return prompt + + class VLLMWorker(BaseModelWorker): def __init__( self, @@ -178,7 +194,8 @@ def get_model_lora_request(self, model_name): async def generate_stream(self, params): self.call_ct += 1 - context = params.pop("prompt") + prompt = params.pop("prompt") + images = params.get("images", []) request_id = params.pop("request_id") temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) @@ -197,6 +214,25 @@ async def generate_stream(self, params): request = params.get("request", None) + # split prompt by image token + split_prompt = prompt.split("") + if prompt.count("") != len(images): + raise ValueError( + "The number of images passed in does not match the number of tokens in the prompt!" + ) + + # context: List[TextPrompt] = [] + # for i in range(len(split_prompt)): + # img = "" + # if i < len(images): + # img = load_image(images[i]) + # context.append({"prompt": split_prompt[i], "multi_modal_data": {"image": img}}) + context: TextPrompt = { + "prompt": prompt, + } + if len(images) > 0: + context["multi_modal_data"] = {"image": load_image(images[0])}, + # Handle stop_str stop = set() if isinstance(stop_str, str) and stop_str != "": @@ -448,12 +484,6 @@ def __call__(self, parser, namespace, values, option_string=None): "throughput. However, if the value is too high, it may cause out-of-" "memory (OOM) errors.", ) - parser.add_argument( - "--max-model-len", - type=float, - default=None, - help="Model context length. If unspecified, will be automatically derived from the model config.", - ) # Support parse LoRA modules parser.add_argument( From d37b1ee9f93832cc6a016e83529befb69e59e34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Wed, 25 Sep 2024 16:02:05 +0800 Subject: [PATCH 20/21] skip_special_tokens --- fastchat/serve/vllm_worker.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index cb5d25686..89fd37977 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -211,15 +211,16 @@ async def generate_stream(self, params): use_beam_search = params.get("use_beam_search", False) best_of = params.get("best_of", None) seed = params.get("seed", None) + skip_special_tokens = params.get("skip_special_tokens", True) request = params.get("request", None) # split prompt by image token - split_prompt = prompt.split("") - if prompt.count("") != len(images): - raise ValueError( - "The number of images passed in does not match the number of tokens in the prompt!" - ) + # split_prompt = prompt.split("") + # if prompt.count("") != len(images): + # raise ValueError( + # "The number of images passed in does not match the number of tokens in the prompt!" + # ) # context: List[TextPrompt] = [] # for i in range(len(split_prompt)): @@ -264,6 +265,7 @@ async def generate_stream(self, params): frequency_penalty=frequency_penalty, best_of=best_of, seed=seed, + skip_special_tokens=skip_special_tokens ) if VLLM_LORA_SUPPORTED: From 89599712dac07431cfd5548f532b8258f74ed3c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=81=AA?= Date: Mon, 30 Sep 2024 18:31:31 +0800 Subject: [PATCH 21/21] =?UTF-8?q?=E4=B8=B4=E6=97=B6=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastchat/conversation.py | 2 +- fastchat/protocol/openai_api_protocol.py | 1 + fastchat/serve/openai_api_server.py | 3 ++ fastchat/serve/vllm_worker.py | 54 +++++++++++++++++++++--- 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 41b444ff1..da945932e 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -343,7 +343,7 @@ def get_images(self): if i % 2 == 0: if type(msg) is tuple: for image in msg[1]: - images.append(image.base64_str) + images.append(image) return images diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py index c35c0b732..85410079b 100644 --- a/fastchat/protocol/openai_api_protocol.py +++ b/fastchat/protocol/openai_api_protocol.py @@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel): presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 user: Optional[str] = None + seed: Optional[int] = None class ChatMessage(BaseModel): diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py index 1bcc66223..447d32813 100644 --- a/fastchat/serve/openai_api_server.py +++ b/fastchat/serve/openai_api_server.py @@ -428,6 +428,7 @@ async def create_chat_completion(request: ChatCompletionRequest): return error_check_ret worker_addr = await get_worker_address(request.model) + logger.info(f"worker_addr: {worker_addr}") gen_params = await get_gen_params( request.model, @@ -444,6 +445,8 @@ async def create_chat_completion(request: ChatCompletionRequest): seed=request.seed, ) + print(gen_params["prompt"]) + max_new_tokens, error_check_ret = await check_length( request, gen_params["prompt"], diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 89fd37977..51b5a5b59 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -22,6 +22,9 @@ from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from vllm.inputs import TextPrompt +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.utils import FlexibleArgumentParser from fastchat.conversation import IMAGE_PLACEHOLDER_STR from fastchat.serve.base_model_worker import BaseModelWorker @@ -88,6 +91,39 @@ def replace_placeholders_with_images(prompt: str, placeholder: str, images: List return prompt +def get_multi_modal_input(args): + """ + return { + "data": image or video, + "question": question, + } + """ + if args.modality == "image": + # Input image and question + image = ImageAsset("cherry_blossom") \ + .pil_image.convert("RGB") + img_question = "What is the content of this image?" + + return { + "data": image, + "question": img_question, + } + + if args.modality == "video": + # Input video and question + video = VideoAsset(name="sample_demo_1.mp4", + num_frames=args.num_frames).np_ndarrays + vid_question = "Why is this video funny?" + + return { + "data": video, + "question": vid_question, + } + + msg = f"Modality {args.modality} is not supported." + raise ValueError(msg) + + class VLLMWorker(BaseModelWorker): def __init__( self, @@ -214,13 +250,18 @@ async def generate_stream(self, params): skip_special_tokens = params.get("skip_special_tokens", True) request = params.get("request", None) + image_token = params.get("image_token", IMAGE_PLACEHOLDER_STR) + + if images is None: + images = [] + # split prompt by image token - # split_prompt = prompt.split("") - # if prompt.count("") != len(images): - # raise ValueError( - # "The number of images passed in does not match the number of tokens in the prompt!" - # ) + # split_prompt = prompt.split(IMAGE_PLACEHOLDER_STR) + if prompt.count(image_token) != len(images): + raise ValueError( + "The number of images passed in does not match the number of tokens in the prompt!" + ) # context: List[TextPrompt] = [] # for i in range(len(split_prompt)): @@ -232,7 +273,8 @@ async def generate_stream(self, params): "prompt": prompt, } if len(images) > 0: - context["multi_modal_data"] = {"image": load_image(images[0])}, + # context["multi_modal_data"] = {"image": [load_image(url) for url in images]}, + context["multi_modal_data"] = {"image": [load_image(url) for url in images]}, # Handle stop_str stop = set()