diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 63a5f9627..53fbc5c5b 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -840,6 +840,19 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Mistral template
+# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="",
+        roles=("[INST] ", " [/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep="",
+        sep2=" ",
+    )
+)
+
 # llama2 template
 # reference: https://huggingface.co/blog/codellama#conversational-instructions
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index d2ac56f8d..db9da37b7 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1256,6 +1256,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("starchat")
 
 
+class MistralAdapter(BaseModelAdapter):
+    """The model adapter for Mistral AI models"""
+
+    def match(self, model_path: str):
+        return "mistral" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("mistral")
+
+
 class Llama2Adapter(BaseModelAdapter):
     """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)"""
 
@@ -1653,6 +1669,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(InternLMChatAdapter)
 register_model_adapter(StarChatAdapter)
 register_model_adapter(Llama2Adapter)
+register_model_adapter(MistralAdapter)
 register_model_adapter(CuteGPTAdapter)
 register_model_adapter(OpenOrcaAdapter)
 register_model_adapter(WizardCoderAdapter)