diff --git a/chatglm_cpp/convert.py b/chatglm_cpp/convert.py index 32275b9..b9304ac 100644 --- a/chatglm_cpp/convert.py +++ b/chatglm_cpp/convert.py @@ -1,6 +1,7 @@ """ Convert Hugging Face ChatGLM/ChatGLM2 models to GGML format """ + import argparse import platform import struct @@ -534,7 +535,7 @@ def main(): args = parser.parse_args() with open(args.save_path, "wb") as f: - convert(f, args.model_name_or_path, dtype=args.type) + convert(f, args.model_name_or_path, args.lora_model_name_or_path, dtype=args.type) print(f"GGML model saved to {args.save_path}") diff --git a/chatglm_cpp/openai_api.py b/chatglm_cpp/openai_api.py index 2a31008..42c0a47 100644 --- a/chatglm_cpp/openai_api.py +++ b/chatglm_cpp/openai_api.py @@ -63,7 +63,9 @@ class ChatCompletionRequest(BaseModel): tools: Optional[List[ChatCompletionTool]] = None model_config = { - "json_schema_extra": {"examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}]} + "json_schema_extra": { + "examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}] + } } @@ -108,7 +110,10 @@ class ChatCompletionResponse(BaseModel): "choices": [ { "index": 0, - "message": {"role": "assistant", "content": "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"}, + "message": { + "role": "assistant", + "content": "你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。", + }, "finish_reason": "stop", } ], diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py index ae1bb37..ba36cf5 100644 --- a/tests/test_chatglm_cpp.py +++ b/tests/test_chatglm_cpp.py @@ -37,17 +37,29 @@ def check_pipeline(model_path, prompt, target, gen_kwargs={}): @pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found") def test_chatglm_pipeline(): - check_pipeline(model_path=CHATGLM_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。") + check_pipeline( + model_path=CHATGLM_MODEL_PATH, + prompt="你好", + target="你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。", + ) @pytest.mark.skipif(not CHATGLM2_MODEL_PATH.exists(), reason="model file not found") def test_chatglm2_pipeline(): - check_pipeline(model_path=CHATGLM2_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。") + check_pipeline( + model_path=CHATGLM2_MODEL_PATH, + prompt="你好", + target="你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。", + ) @pytest.mark.skipif(not CHATGLM3_MODEL_PATH.exists(), reason="model file not found") def test_chatglm3_pipeline(): - check_pipeline(model_path=CHATGLM3_MODEL_PATH, prompt="你好", target="你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。") + check_pipeline( + model_path=CHATGLM3_MODEL_PATH, + prompt="你好", + target="你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。", + ) @pytest.mark.skipif(not CODEGEEX2_MODEL_PATH.exists(), reason="model file not found")