From b035c02f78924020e2329de36b473189851bb409 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 23 Aug 2024 23:52:25 +0800 Subject: [PATCH] chore(api/tests): apply ruff reformat #7590 (#7591) Co-authored-by: -LAN- --- api/pyproject.toml | 1 - .../model_runtime/__mock/anthropic.py | 68 +-- .../model_runtime/__mock/google.py | 58 +- .../model_runtime/__mock/huggingface.py | 7 +- .../model_runtime/__mock/huggingface_chat.py | 27 +- .../model_runtime/__mock/huggingface_tei.py | 42 +- .../model_runtime/__mock/openai.py | 21 +- .../model_runtime/__mock/openai_chat.py | 178 ++++--- .../model_runtime/__mock/openai_completion.py | 77 +-- .../model_runtime/__mock/openai_embeddings.py | 50 +- .../model_runtime/__mock/openai_moderation.py | 102 ++-- .../model_runtime/__mock/openai_remote.py | 11 +- .../__mock/openai_speech2text.py | 19 +- .../model_runtime/__mock/xinference.py | 135 +++-- .../model_runtime/anthropic/test_llm.py | 75 +-- .../model_runtime/anthropic/test_provider.py | 12 +- .../model_runtime/azure_openai/test_llm.py | 265 ++++------ .../azure_openai/test_text_embedding.py | 49 +- .../model_runtime/baichuan/test_llm.py | 132 ++--- .../model_runtime/baichuan/test_provider.py | 12 +- .../baichuan/test_text_embedding.py | 42 +- .../model_runtime/bedrock/test_llm.py | 66 +-- .../model_runtime/bedrock/test_provider.py | 6 +- .../model_runtime/chatglm/test_llm.py | 215 +++----- .../model_runtime/chatglm/test_provider.py | 14 +- .../model_runtime/cohere/test_llm.py | 183 ++----- .../model_runtime/cohere/test_provider.py | 10 +- .../model_runtime/cohere/test_rerank.py | 26 +- .../cohere/test_text_embedding.py | 38 +- .../model_runtime/google/test_llm.py | 185 +++---- .../model_runtime/google/test_provider.py | 12 +- .../model_runtime/huggingface_hub/test_llm.py | 260 ++++----- .../huggingface_hub/test_text_embedding.py | 89 ++-- .../huggingface_tei/test_embeddings.py | 38 +- .../huggingface_tei/test_rerank.py | 38 +- .../model_runtime/hunyuan/test_llm.py | 69 +-- .../model_runtime/hunyuan/test_provider.py | 11 +- .../hunyuan/test_text_embedding.py | 50 +- .../model_runtime/jina/test_provider.py | 12 +- .../model_runtime/jina/test_text_embedding.py | 32 +- .../model_runtime/localai/test_embedding.py | 6 +- .../model_runtime/localai/test_llm.py | 154 ++---- .../model_runtime/localai/test_rerank.py | 56 +- .../model_runtime/localai/test_speech2text.py | 30 +- .../model_runtime/minimax/test_embedding.py | 41 +- .../model_runtime/minimax/test_llm.py | 107 ++-- .../model_runtime/minimax/test_provider.py | 8 +- .../model_runtime/novita/test_llm.py | 70 +-- .../model_runtime/novita/test_provider.py | 6 +- .../model_runtime/ollama/test_llm.py | 190 +++---- .../ollama/test_text_embedding.py | 48 +- .../model_runtime/openai/test_llm.py | 294 ++++------- .../model_runtime/openai/test_moderation.py | 33 +- .../model_runtime/openai/test_provider.py | 12 +- .../model_runtime/openai/test_speech2text.py | 35 +- .../openai/test_text_embedding.py | 45 +- .../openai_api_compatible/test_llm.py | 149 +++--- .../openai_api_compatible/test_speech2text.py | 17 +- .../test_text_embedding.py | 52 +- .../model_runtime/openllm/test_embedding.py | 33 +- .../model_runtime/openllm/test_llm.py | 69 ++- .../model_runtime/openrouter/test_llm.py | 71 +-- .../model_runtime/replicate/test_llm.py | 74 ++- .../replicate/test_text_embedding.py | 95 ++-- .../model_runtime/sagemaker/test_provider.py | 8 +- .../model_runtime/sagemaker/test_rerank.py | 12 +- .../sagemaker/test_text_embedding.py | 32 +- .../model_runtime/siliconflow/test_llm.py | 71 +-- .../siliconflow/test_provider.py | 10 +- .../model_runtime/siliconflow/test_rerank.py | 12 +- .../siliconflow/test_speech2text.py | 16 +- .../siliconflow/test_text_embedding.py | 4 +- .../model_runtime/spark/test_llm.py | 77 +-- .../model_runtime/spark/test_provider.py | 10 +- .../model_runtime/stepfun/test_llm.py | 123 ++--- .../test_model_provider_factory.py | 29 +- .../model_runtime/togetherai/test_llm.py | 74 +-- .../model_runtime/tongyi/test_llm.py | 67 +-- .../model_runtime/tongyi/test_provider.py | 8 +- .../tongyi/test_response_format.py | 18 +- .../model_runtime/upstage/test_llm.py | 177 +++---- .../model_runtime/upstage/test_provider.py | 12 +- .../upstage/test_text_embedding.py | 37 +- .../volcengine_maas/test_embedding.py | 70 ++- .../model_runtime/volcengine_maas/test_llm.py | 113 ++-- .../model_runtime/wenxin/test_embedding.py | 44 +- .../model_runtime/wenxin/test_llm.py | 195 +++---- .../model_runtime/wenxin/test_provider.py | 12 +- .../xinference/test_embeddings.py | 46 +- .../model_runtime/xinference/test_llm.py | 201 +++---- .../model_runtime/xinference/test_rerank.py | 32 +- .../model_runtime/zhinao/test_llm.py | 71 +-- .../model_runtime/zhinao/test_provider.py | 10 +- .../model_runtime/zhipuai/test_llm.py | 108 ++-- .../model_runtime/zhipuai/test_provider.py | 10 +- .../zhipuai/test_text_embedding.py | 38 +- .../integration_tests/tools/__mock/http.py | 15 +- .../tools/__mock_server/openapi_todo.py | 6 +- .../tools/api_tool/test_api_tool.py | 51 +- .../tools/test_all_provider.py | 9 +- .../integration_tests/utils/parent_class.py | 2 +- .../utils/test_module_import_helper.py | 24 +- .../vdb/__mock/tcvectordb.py | 143 +++-- .../vdb/analyticdb/test_analyticdb.py | 5 +- .../vdb/chroma/test_chroma.py | 4 +- .../vdb/elasticsearch/test_elasticsearch.py | 11 +- .../vdb/milvus/test_milvus.py | 10 +- .../vdb/myscale/test_myscale.py | 2 +- .../vdb/opensearch/test_opensearch.py | 109 ++-- .../vdb/pgvecto_rs/test_pgvecto_rs.py | 13 +- .../vdb/qdrant/test_qdrant.py | 8 +- .../vdb/tcvectordb/test_tencent.py | 28 +- .../vdb/test_vector_store.py | 14 +- .../vdb/tidb_vector/test_tidb_vector.py | 14 +- .../vdb/weaviate/test_weaviate.py | 8 +- .../workflow/nodes/__mock/code_executor.py | 15 +- .../workflow/nodes/__mock/http.py | 30 +- .../nodes/code_executor/test_code_executor.py | 6 +- .../code_executor/test_code_javascript.py | 15 +- .../nodes/code_executor/test_code_jinja2.py | 25 +- .../nodes/code_executor/test_code_python3.py | 11 +- .../workflow/nodes/test_code.py | 278 +++++----- .../workflow/nodes/test_http.py | 487 ++++++++--------- .../workflow/nodes/test_llm.py | 244 ++++----- .../nodes/test_parameter_extractor.py | 499 +++++++++--------- .../workflow/nodes/test_template_transform.py | 45 +- .../workflow/nodes/test_tool.py | 101 ++-- .../unit_tests/configs/test_dify_config.py | 57 +- .../core/app/segments/test_factory.py | 98 ++-- .../core/app/segments/test_segment.py | 26 +- .../core/app/segments/test_variables.py | 50 +- .../unit_tests/core/helper/test_ssrf_proxy.py | 14 +- .../wenxin/test_text_embedding.py | 20 +- .../prompt/test_advanced_prompt_transform.py | 114 ++-- .../test_agent_history_prompt_transform.py | 38 +- .../core/prompt/test_prompt_transform.py | 26 +- .../prompt/test_simple_prompt_transform.py | 133 +++-- .../rag/datasource/vdb/milvus/test_milvus.py | 11 +- .../rag/extractor/firecrawl/test_firecrawl.py | 12 +- .../rag/extractor/test_notion_extractor.py | 57 +- .../unit_tests/core/test_model_manager.py | 35 +- .../unit_tests/core/test_provider_manager.py | 180 +++---- .../tools/test_tool_parameter_converter.py | 48 +- .../core/workflow/nodes/test_answer.py | 37 +- .../core/workflow/nodes/test_if_else.py | 248 ++++----- .../workflow/nodes/test_variable_assigner.py | 98 ++-- api/tests/unit_tests/libs/test_pandas.py | 40 +- api/tests/unit_tests/libs/test_rsa.py | 2 +- api/tests/unit_tests/libs/test_yarl.py | 26 +- api/tests/unit_tests/models/test_account.py | 10 +- .../models/test_conversation_variable.py | 14 +- api/tests/unit_tests/models/test_workflow.py | 94 ++-- .../workflow/test_workflow_converter.py | 187 +++---- .../position_helper/test_position_helper.py | 66 +-- .../unit_tests/utils/yaml/test_yaml_utils.py | 38 +- 155 files changed, 4272 insertions(+), 5918 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 6175fdbda71040..e05a51dc135ca7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -76,7 +76,6 @@ exclude = [ "migrations/**/*", "services/**/*.py", "tasks/**/*.py", - "tests/**/*.py", ] [tool.pytest_env] diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 3326f874b09505..79a3dc03941c9f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -22,23 +22,20 @@ ) from anthropic.types.message_delta_event import Delta -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( - id='msg-123', - type='message', - role='assistant', - content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')], + id="msg-123", + type="message", + role="assistant", + content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")], model=model, - stop_reason='stop_sequence', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) + stop_reason="stop_sequence", + usage=Usage(input_tokens=1, output_tokens=1), ) @staticmethod @@ -46,52 +43,43 @@ def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent full_response_text = "hello, I'm a chatbot from anthropic" yield MessageStartEvent( - type='message_start', + type="message_start", message=Message( - id='msg-123', + id="msg-123", content=[], - role='assistant', + role="assistant", model=model, stop_reason=None, - type='message', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) - ) + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ), ) index = 0 for i in range(0, len(full_response_text)): yield ContentBlockDeltaEvent( - type='content_block_delta', - delta=TextDelta(text=full_response_text[i], type='text_delta'), - index=index + type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index ) index += 1 yield MessageDeltaEvent( - type='message_delta', - delta=Delta( - stop_reason='stop_sequence' - ), - usage=MessageDeltaUsage( - output_tokens=1 - ) + type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1) ) - yield MessageStopEvent(type='message_stop') - - def mocked_anthropic(self: Messages, *, - max_tokens: int, - messages: Iterable[MessageParam], - model: str, - stream: Literal[True], - **kwargs: Any - ) -> Union[Message, Stream[MessageStreamEvent]]: + yield MessageStopEvent(type="message_stop") + + def mocked_anthropic( + self: Messages, + *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any, + ) -> Union[Message, Stream[MessageStreamEvent]]: if len(self._client.api_key) < 18: - raise anthropic.AuthenticationError('Invalid API key') + raise anthropic.AuthenticationError("Invalid API key") if stream: return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model) @@ -102,7 +90,7 @@ def mocked_anthropic(self: Messages, *, @pytest.fixture def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic) + monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic) yield diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index d838e9890ff562..bc0684086f0b20 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -12,63 +12,46 @@ from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -current_api_key = '' +current_api_key = "" + class MockGoogleResponseClass: _done = False def __iter__(self): - full_response_text = 'it\'s google!' + full_response_text = "it's google!" for i in range(0, len(full_response_text) + 1, 1): if i == len(full_response_text): self._done = True yield GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) else: yield GenerateContentResponse( - done=False, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) + class MockGoogleResponseCandidateClass: - finish_reason = 'stop' + finish_reason = "stop" @property def content(self) -> gag_content.Content: - return gag_content.Content( - parts=[ - gag_content.Part(text='it\'s google!') - ] - ) + return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) + class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: - return GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] - ) + return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: return MockGoogleResponseClass() - def generate_content(self: GenerativeModel, + def generate_content( + self: GenerativeModel, contents: content_types.ContentsType, *, generation_config: generation_config_types.GenerationConfigType | None = None, @@ -79,21 +62,21 @@ def generate_content(self: GenerativeModel, global current_api_key if len(current_api_key) < 16: - raise Exception('Invalid API key') + raise Exception("Invalid API key") if stream: return MockGoogleClass.generate_content_stream() - + return MockGoogleClass.generate_content_sync() - + @property def generative_response_text(self) -> str: - return 'it\'s google!' - + return "it's google!" + @property def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] - + def make_client(self: _ClientManager, name: str): global current_api_key @@ -121,7 +104,8 @@ def nop(self, *args, **kwargs): if not self.default_metadata: return client - + + @pytest.fixture def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) @@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): yield - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index a75b058d92c16a..97038ef5963e87 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -6,14 +6,15 @@ from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): if MOCK: monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation) - + yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 1607624c3c9056..9ee76c935c9873 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -22,10 +22,8 @@ def generate_create_sync(model: str) -> TextGenerationResponse: details=Details( finish_reason="length", generated_tokens=6, - tokens=[ - Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6) - ] - ) + tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], + ), ) return response @@ -36,26 +34,23 @@ def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse for i in range(0, len(full_text)): response = TextGenerationStreamResponse( - token = Token(id=i, text=full_text[i], logprob=0.0, special=False), + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] - response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1) + response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1) yield response - def text_generation(self: InferenceClient, prompt: str, *, - stream: Literal[False] = ..., - model: Optional[str] = None, - **kwargs: Any + def text_generation( + self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: # check if key is valid - if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']): - raise BadRequestError('Invalid API key') - + if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): + raise BadRequestError("Invalid API key") + if model is None: - raise BadRequestError('Invalid model') - + raise BadRequestError("Invalid model") + if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model) - diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index c2fe95974b10f1..b37b109ebae58c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -5,10 +5,10 @@ class MockTEIClass: @staticmethod def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: # During mock, we don't have a real server to query, so we just return a dummy value - if 'rerank' in model_name: - model_type = 'reranker' + if "rerank" in model_name: + model_type = "reranker" else: - model_type = 'embedding' + model_type = "embedding" return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) @@ -17,16 +17,16 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: # Use space as token separator, and split the text into tokens tokenized_texts = [] for text in texts: - tokens = text.split(' ') + tokens = text.split(" ") current_index = 0 tokenized_text = [] for idx, token in enumerate(tokens): s_token = { - 'id': idx, - 'text': token, - 'special': False, - 'start': current_index, - 'stop': current_index + len(token), + "id": idx, + "text": token, + "special": False, + "start": current_index, + "stop": current_index + len(token), } current_index += len(token) + 1 tokenized_text.append(s_token) @@ -55,18 +55,18 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict: embedding = [0.1] * 768 embeddings.append( { - 'object': 'embedding', - 'embedding': embedding, - 'index': idx, + "object": "embedding", + "embedding": embedding, + "index": idx, } ) return { - 'object': 'list', - 'data': embeddings, - 'model': 'MODEL_NAME', - 'usage': { - 'prompt_tokens': sum(len(text.split(' ')) for text in texts), - 'total_tokens': sum(len(text.split(' ')) for text in texts), + "object": "list", + "data": embeddings, + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": sum(len(text.split(" ")) for text in texts), + "total_tokens": sum(len(text.split(" ")) for text in texts), }, } @@ -83,9 +83,9 @@ def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: for idx, text in enumerate(texts): reranked_docs.append( { - 'index': idx, - 'text': text, - 'score': 0.9, + "index": idx, + "text": text, + "score": 0.9, } ) # For mock, only return the first document diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 0d3f0fbbeaab5f..6637f4f212a50e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: +def mock_openai( + monkeypatch: MonkeyPatch, + methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]], +) -> Callable[[], None]: """ - mock openai module + mock openai module - :param monkeypatch: pytest monkeypatch fixture - :return: unpatch function + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function """ + def unpatch() -> None: monkeypatch.undo() @@ -52,15 +56,16 @@ def unpatch() -> None: return unpatch -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_openai_mock(request, monkeypatch): - methods = request.param if hasattr(request, 'param') else [] + methods = request.param if hasattr(request, "param") else [] if MOCK: unpatch = mock_openai(monkeypatch, methods=methods) - + yield if MOCK: - unpatch() \ No newline at end of file + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index ba902e32ea6b22..d9cd7b046e001c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -43,62 +43,64 @@ def generate_function_call( if not functions or len(functions) == 0: return None function: completion_create_params.Function = functions[0] - function_name = function['name'] - function_description = function['description'] - function_parameters = function['parameters'] - function_parameters_type = function_parameters['type'] - if function_parameters_type != 'object': + function_name = function["name"] + function_description = function["description"] + function_parameters = function["parameters"] + function_parameters_type = function_parameters["type"] + if function_parameters_type != "object": return None - function_parameters_properties = function_parameters['properties'] - function_parameters_required = function_parameters['required'] + function_parameters_properties = function_parameters["properties"] + function_parameters_required = function_parameters["required"] parameters = {} for parameter_name, parameter in function_parameters_properties.items(): if parameter_name not in function_parameters_required: continue - parameter_type = parameter['type'] - if parameter_type == 'string': - if 'enum' in parameter: - if len(parameter['enum']) == 0: + parameter_type = parameter["type"] + if parameter_type == "string": + if "enum" in parameter: + if len(parameter["enum"]) == 0: continue - parameters[parameter_name] = parameter['enum'][0] + parameters[parameter_name] = parameter["enum"][0] else: - parameters[parameter_name] = 'kawaii' - elif parameter_type == 'integer': + parameters[parameter_name] = "kawaii" + elif parameter_type == "integer": parameters[parameter_name] = 114514 - elif parameter_type == 'number': + elif parameter_type == "number": parameters[parameter_name] = 1919810.0 - elif parameter_type == 'boolean': + elif parameter_type == "boolean": parameters[parameter_name] = True return FunctionCall(name=function_name, arguments=dumps(parameters)) - + @staticmethod - def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: + def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) == 0: return None tool = tools[0] - if 'type' in tools and tools['type'] != 'function': + if "type" in tools and tools["type"] != "function": return None - function = tool['function'] + function = tool["function"] function_call = MockChatClass.generate_function_call(functions=[function]) if function_call is None: return None - - list_tool_calls.append(ChatCompletionMessageToolCall( - id='sakurajima-mai', - function=Function( - name=function_call.name, - arguments=function_call.arguments, - ), - type='function' - )) + + list_tool_calls.append( + ChatCompletionMessageToolCall( + id="sakurajima-mai", + function=Function( + name=function_call.name, + arguments=function_call.arguments, + ), + type="function", + ) + ) return list_tool_calls - + @staticmethod def mocked_openai_chat_create_sync( model: str, @@ -111,30 +113,27 @@ def mocked_openai_chat_create_sync( tool_calls = MockChatClass.generate_tool_calls(tools=tools) return _ChatCompletion( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ _ChatCompletionChoice( - finish_reason='content_filter', + finish_reason="content_filter", index=0, message=ChatCompletionMessage( - content='elaina', - role='assistant', - function_call=function_call, - tool_calls=tool_calls - ) + content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls + ), ) ], created=int(time()), model=model, - object='chat.completion', - system_fingerprint='', + object="chat.completion", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod def mocked_openai_chat_create_stream( model: str, @@ -150,36 +149,40 @@ def mocked_openai_chat_create_stream( for i in range(0, len(full_text) + 1): if i == len(full_text): yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( - content='', + content="", function_call=ChoiceDeltaFunctionCall( name=function_call.name, arguments=function_call.arguments, - ) if function_call else None, - role='assistant', + ) + if function_call + else None, + role="assistant", tool_calls=[ ChoiceDeltaToolCall( index=0, - id='misaka-mikoto', + id="misaka-mikoto", function=ChoiceDeltaToolCallFunction( name=tool_calls[0].function.name, arguments=tool_calls[0].function.arguments, ), - type='function' + type="function", ) - ] if tool_calls and len(tool_calls) > 0 else None + ] + if tool_calls and len(tool_calls) > 0 + else None, ), - finish_reason='function_call', + finish_reason="function_call", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=17, @@ -188,30 +191,45 @@ def mocked_openai_chat_create_stream( ) else: yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( content=full_text[i], - role='assistant', + role="assistant", ), - finish_reason='content_filter', + finish_reason="content_filter", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", ) - def chat_create(self: Completions, *, + def chat_create( + self: Completions, + *, messages: list[ChatCompletionMessageParam], - model: Union[str,Literal[ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"], + model: Union[ + str, + Literal[ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ], ], functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, @@ -220,24 +238,32 @@ def chat_create(self: Completions, *, **kwargs: Any, ): openai_models = [ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", - ] - azure_openai_models = [ - "gpt35", "gpt-4v", "gpt-35-turbo" + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", ] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if stream: return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools) - - return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) \ No newline at end of file + + return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index b0d26759055bff..c27e89248f4403 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -17,9 +17,7 @@ class MockCompletionsClass: @staticmethod - def mocked_openai_completion_create_sync( - model: str - ) -> CompletionMessage: + def mocked_openai_completion_create_sync(model: str) -> CompletionMessage: return CompletionMessage( id="cmpl-3QJQa5jXJ5Z5X", object="text_completion", @@ -38,13 +36,11 @@ def mocked_openai_completion_create_sync( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod - def mocked_openai_completion_create_stream( - model: str - ) -> Generator[CompletionMessage, None, None]: + def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]: full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" for i in range(0, len(full_text) + 1): if i == len(full_text): @@ -76,46 +72,59 @@ def mocked_openai_completion_create_stream( model=model, system_fingerprint="", choices=[ - CompletionChoice( - text=full_text[i], - index=0, - logprobs=None, - finish_reason="content_filter" - ) + CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter") ], ) - def completion_create(self: Completions, *, model: Union[ - str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", - "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", - "text-ada-001"], + def completion_create( + self: Completions, + *, + model: Union[ + str, + Literal[ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ], ], prompt: Union[str, list[str], list[int], list[list[int]], None], stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ): openai_models = [ - "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", - ] - azure_openai_models = [ - "gpt-35-turbo-instruct" + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", ] + azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') - + raise InvokeAuthorizationError("Invalid api key") + if not prompt: - raise BadRequestError('Invalid prompt') + raise BadRequestError("Invalid prompt") if stream: return MockCompletionsClass.mocked_openai_completion_create_stream(model=model) - - return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) \ No newline at end of file + + return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index eccdbd34795c6e..4138cdd40d822c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -12,48 +12,39 @@ class MockEmbeddingsClass: def create_embeddings( - self: Embeddings, *, + self: Embeddings, + *, input: Union[str, list[str], list[int], list[list[int]]], model: Union[str, Literal["text-embedding-ada-002"]], encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> CreateEmbeddingResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - if encoding_format == 'float': + raise InvokeAuthorizationError("Invalid API key") + + if encoding_format == "float": return CreateEmbeddingResponse( data=[ - Embedding( - embedding=[0.23333 for _ in range(233)], - index=i, - object='embedding' - ) for i in range(len(input)) + Embedding(embedding=[0.23333 for _ in range(233)], index=i, object="embedding") + for i in range(len(input)) ], model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) + usage=Usage(prompt_tokens=2, total_tokens=2), ) - - embeddings = 'VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7' + + embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" data = [] for i, text in enumerate(input): - obj = Embedding( - embedding=[], - index=i, - object='embedding' - ) + obj = Embedding(embedding=[], index=i, object="embedding") obj.embedding = embeddings data.append(obj) @@ -61,10 +52,7 @@ def create_embeddings( return CreateEmbeddingResponse( data=data, model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) - ) \ No newline at end of file + usage=Usage(prompt_tokens=2, total_tokens=2), + ) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 9466f4bfb8e794..270a88e85ffedd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -10,58 +10,92 @@ class MockModerationClass: - def moderation_create(self: Moderations,*, + def moderation_create( + self: Moderations, + *, input: Union[str, list[str]], model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> ModerationCreateResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') + raise InvokeAuthorizationError("Invalid API key") for text in input: result = [] - if 'kill' in text: + if "kill" in text: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0, - 'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0, - 'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0 + "harassment": 1.0, + "harassment/threatening": 1.0, + "hate": 1.0, + "hate/threatening": 1.0, + "self-harm": 1.0, + "self-harm/instructions": 1.0, + "self-harm/intent": 1.0, + "sexual": 1.0, + "sexual/minors": 1.0, + "violence": 1.0, + "violence/graphic": 1.0, } - result.append(Moderation( - flagged=True, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=True, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) else: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0, - 'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0, - 'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0 + "harassment": 0.0, + "harassment/threatening": 0.0, + "hate": 0.0, + "hate/threatening": 0.0, + "self-harm": 0.0, + "self-harm/instructions": 0.0, + "self-harm/intent": 0.0, + "sexual": 0.0, + "sexual/minors": 0.0, + "violence": 0.0, + "violence/graphic": 0.0, } - result.append(Moderation( - flagged=False, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=False, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) - return ModerationCreateResponse( - id='shiroii kuloko', - model=model, - results=result - ) \ No newline at end of file + return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 0124ac045b6877..cb8f2495438783 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -6,17 +6,18 @@ class MockModelClass: """ - mock class for openai.models.Models + mock class for openai.models.Models """ + def list( self, **kwargs, ) -> list[Model]: return [ Model( - id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ', + id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ", created=int(time()), - object='model', - owned_by='organization:org-123', + object="model", + owned_by="organization:org-123", ) - ] \ No newline at end of file + ] diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index 755fec4c1fbc9e..ef361e86139427 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -9,7 +9,8 @@ class MockSpeech2TextClass: - def speech2text_create(self: Transcriptions, + def speech2text_create( + self: Transcriptions, *, file: FileTypes, model: Union[str, Literal["whisper-1"]], @@ -17,14 +18,12 @@ def speech2text_create(self: Transcriptions, prompt: str | NotGiven = NOT_GIVEN, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> Transcription: - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - return Transcription( - text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10' - ) \ No newline at end of file + raise InvokeAuthorizationError("Invalid API key") + + return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10") diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 7cb0a1318e9b8d..777737187e259c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -19,40 +19,43 @@ class MockXinferenceClass: - def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: - if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): - raise RuntimeError('404 Not Found') - - if 'generate' == model_uid: + def get_chat_model( + self: Client, model_uid: str + ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: + if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): + raise RuntimeError("404 Not Found") + + if "generate" == model_uid: return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'chat' == model_uid: + if "chat" == model_uid: return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'embedding' == model_uid: + if "embedding" == model_uid: return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'rerank' == model_uid: + if "rerank" == model_uid: return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - raise RuntimeError('404 Not Found') - + raise RuntimeError("404 Not Found") + def get(self: Session, url: str, **kwargs): response = Response() - if 'v1/models/' in url: + if "v1/models/" in url: # get model uid - model_uid = url.split('/')[-1] or '' - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ - model_uid not in ['generate', 'chat', 'embedding', 'rerank']: + model_uid = url.split("/")[-1] or "" + if not re.match( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid + ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response # check if url is valid - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url): response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response - - if model_uid in ['generate', 'chat']: + + if model_uid in ["generate", "chat"]: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "LLM", "address": "127.0.0.1:43877", "accelerators": [ @@ -75,12 +78,12 @@ def get(self: Session, url: str, **kwargs): "revision": null, "context_length": 2048, "replica": 1 - }''' + }""" return response - - elif model_uid == 'embedding': + + elif model_uid == "embedding": response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "embedding", "address": "127.0.0.1:43877", "accelerators": [ @@ -93,51 +96,48 @@ def get(self: Session, url: str, **kwargs): ], "revision": null, "max_tokens": 512 - }''' + }""" return response - - elif 'v1/cluster/auth' in url: + + elif "v1/cluster/auth" in url: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "auth": true - }''' + }""" return response - + def _check_cluster_authenticated(self): self._cluster_authed = True - - def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict: + + def rerank( + self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool + ) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'rerank': - raise RuntimeError('404 Not Found') - - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url): - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "rerank" + ): + raise RuntimeError("404 Not Found") + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url): + raise RuntimeError("404 Not Found") if top_n is None: top_n = 1 return { - 'results': [ - { - 'index': i, - 'document': doc, - 'relevance_score': 0.9 - } - for i, doc in enumerate(documents[:top_n]) + "results": [ + {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n]) ] } - - def create_embedding( - self: RESTfulGenerateModelHandle, - input: Union[str, list[str]], - **kwargs - ) -> dict: + + def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'embedding': - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "embedding" + ): + raise RuntimeError("404 Not Found") if isinstance(input, str): input = [input] @@ -147,32 +147,27 @@ def create_embedding( object="list", model=self._model_uid, data=[ - EmbeddingData( - index=i, - object="embedding", - embedding=[1919.810 for _ in range(768)] - ) + EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)]) for i in range(ipt_len) ], - usage=EmbeddingUsage( - prompt_tokens=ipt_len, - total_tokens=ipt_len - ) + usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len), ) return embedding -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_xinference_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) - monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) - monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) - monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) - monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) + monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated) + monkeypatch.setattr(Session, "get", MockXinferenceClass.get) + monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding) + monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank) yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index 0d54d97daad24e..8f7e9ec48743bf 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -10,79 +10,60 @@ from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_credentials(setup_anthropic_mock): model = AnthropicLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")} ) -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', + model="claude-instant-1.2", credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'), - 'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL') + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,18 +79,14 @@ def test_get_num_tokens(): model = AnthropicLargeLanguageModel() num_tokens = model.get_num_tokens( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 7eaa40dfddc2af..6f1e50f431849f 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_anthropic_mock): provider = AnthropicProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 6afec540ade181..8f50ebf7a6d03f 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -17,101 +17,90 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo", + }, ) model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo-instruct", + }, ) model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -122,66 +111,60 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -194,109 +177,87 @@ def test_invoke_stream_chat_model(setup_openai_mock): assert chunk.delta.usage is not None assert chunk.delta.usage.completion_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-4v', + model="gpt-4v", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-4-vision-preview' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-4-vision-preview", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo', + model="gpt-35-turbo", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -308,32 +269,22 @@ def test_get_num_tokens(): model = AzureOpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-35-turbo-instruct', - credentials={ - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-35-turbo-instruct", + credentials={"base_model_name": "gpt-35-turbo-instruct"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt35', - credentials={ - 'base_model_name': 'gpt-35-turbo' - }, + model="gpt35", + credentials={"base_model_name": "gpt-35-turbo"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 8b838eb8fc8183..a1ae2b2e5b740c 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -8,45 +8,43 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "text-embedding-ada-002", + }, ) model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() result = model.invoke( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -58,14 +56,7 @@ def test_get_num_tokens(): model = AzureOpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embedding', - credentials={ - 'base_model_name': 'text-embedding-ada-002' - }, - texts=[ - "hello", - "world" - ] + model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index 1cae9a6dd0962e..ad586102879a10 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -17,111 +17,99 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = BaichuanLarguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='baichuan2-turbo', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') - } + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, ) + def test_invoke_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_with_system_message(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, prompt_messages=[ - SystemPromptMessage( - content='请记住你是Kasumi。' - ), - UserPromptMessage( - content='现在告诉我你是谁?' - ) + SystemPromptMessage(content="请记住你是Kasumi。"), + UserPromptMessage(content="现在告诉我你是谁?"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -131,34 +119,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'with_search_enhance': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "with_search_enhance": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -166,25 +151,22 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True total_message += chunk.delta.message.content - assert '不' not in total_message + assert "不" not in total_message + def test_get_num_tokens(): sleep(3) model = BaichuanLarguageModel() response = model.get_num_tokens( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 9 \ No newline at end of file + assert response == 9 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index 87b3d9a6099839..4036edfb7a7062 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = BaichuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index 1210ebc53d96f8..cbc63f3978fb99 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = BaichuanTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } + model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")} ) @@ -30,44 +22,40 @@ def test_invoke_model(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = BaichuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, texts=[ "hello", @@ -92,8 +80,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py index 20dc11151a70a7..c19ec35a6e45fc 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -13,77 +13,63 @@ def test_validate_credentials(): model = BedrockLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='meta.llama2-13b-chat-v1', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") - } + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, ) + def test_invoke_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens_to_sample': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens_to_sample': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,20 +86,18 @@ def test_get_num_tokens(): model = BedrockLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta.llama2-13b-chat-v1', - credentials = { + model="meta.llama2-13b-chat-v1", + credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py index e53d4c1db2133b..080727829e9e2f 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = BedrockProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index e32f01a315b7d5..418e88874d1572 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -23,79 +23,64 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': 'invalid_key' - } - ) - - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"}) + + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) + -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。' + content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。" ), - UserPromptMessage( - content='波士顿天气如何?' - ) + UserPromptMessage(content="波士顿天气如何?"), ], model_parameters={ - 'temperature': 0, - 'top_p': 1.0, + "temperature": 0, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=True, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, Generator) - + call: LLMResultChunk = None chunks = [] @@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock): break assert call is not None - assert call.delta.message.tool_calls[0].function.name == 'get_current_weather' + assert call.delta.message.tool_calls[0].function.name == "get_current_weather" -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, - prompt_messages=[ - UserPromptMessage( - content='What is the weather like in San Francisco?' - ) - ], + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=False, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 - assert response.message.tool_calls[0].function.name == 'get_current_weather' + assert response.message.tool_calls[0].function.name == "get_current_weather" def test_get_num_tokens(): model = ChatGLMLargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index e9c5c4da751b75..7907805d072772 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -7,19 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = ChatGLMProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_base': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_base": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py index 5ce4f8ecfe874f..b7f707e935dbea 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_validate_credentials_for_completion_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_completion_model(): model = CohereLargeLanguageModel() - credentials = { - 'api_key': os.environ.get('COHERE_API_KEY') - } + credentials = {"api_key": os.environ.get("COHERE_API_KEY")} result = model.invoke( - model='command-light', + model="command-light", credentials=credentials, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1 + assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1 def test_invoke_stream_completion_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -109,28 +71,24 @@ def test_invoke_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'p': 0.99, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "p": 0.99, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -141,24 +99,17 @@ def test_invoke_stream_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -177,32 +128,22 @@ def test_get_num_tokens(): model = CohereLargeLanguageModel() num_tokens = model.get_num_tokens( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 15 @@ -213,25 +154,17 @@ def test_fine_tuned_model(): # test invoke result = model.invoke( - model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'completion' - }, + model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -242,25 +175,17 @@ def test_fine_tuned_chat_model(): # test invoke result = model.invoke( - model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'chat' - }, + model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index a8f56b61943c8b..fb7e6d34984a61 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = CohereProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index 415c5fbfda56d0..a1b6922128570e 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -11,29 +11,17 @@ def test_validate_credentials(): model = CohereRerankModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_model(): model = CohereRerankModel() result = model.invoke( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="rerank-english-v2.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, query="What is the capital of the United States?", docs=[ "Carson City is the capital city of the American state of Nevada. At the 2010 United States " @@ -41,9 +29,9 @@ def test_invoke_model(): "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) " "is the capital of the United States. It is a federal district. The President of the USA and many major " "national government offices are in the territory. This makes it the political center of the United " - "States of America." + "States of America.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py index 5017ba47e11033..ae26d36635d1b5 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = CohereTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } + model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")} ) @@ -30,17 +22,10 @@ def test_invoke_model(): model = CohereTextEmbeddingModel() result = model.invoke( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -52,14 +37,9 @@ def test_get_num_tokens(): model = CohereTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world"], ) assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 00d907d19ef7ef..4d9d490a8720c2 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -16,103 +16,73 @@ from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_credentials(setup_google_mock): model = GoogleLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": "invalid_key"}) + + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) + -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.5, - 'top_p': 1.0, - 'max_tokens_to_sample': 2048 - }, - stop=['How'], + model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.2, - 'top_k': 5, - 'max_tokens_to_sample': 2048 - }, + model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,88 +93,66 @@ def test_invoke_stream_model(setup_google_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.' - ), + SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ), - AssistantPromptMessage( - content="I see a blue letter 'D' with a gradient from light blue to dark blue." - ), + AssistantPromptMessage(content="I see a blue letter 'D' with a gradient from light blue to dark blue."), UserPromptMessage( content=[ - TextPromptMessageContent( - data="what about now?" - ), + TextPromptMessageContent(data="what about now?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) print(f"resultz: {result.message.content}") @@ -212,23 +160,18 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): assert len(result.message.content) > 0 - def test_get_num_tokens(): model = GoogleLargeLanguageModel() num_tokens = model.get_num_tokens( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 103107ed5ae6c5..c217e4fe058870 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_google_mock): provider = GoogleProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index 28cd0955b33109..6a6cc874fa2f30 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -10,87 +10,75 @@ from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="HuggingFaceH4/zephyr-7b-beta", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='fake-model', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="fake-model", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -286,18 +264,14 @@ def test_get_num_tokens(): model = HuggingfaceHubLargeLanguageModel() num_tokens = model.get_num_tokens( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index d03b3186cb4657..0ee593f38a494a 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key', - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": "invalid_key", + }, ) model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) @@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) @@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -104,18 +98,15 @@ def test_get_num_tokens(): model = HuggingfaceHubTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index ed371fbc07aa8d..b1fa9d5ca5097f 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -10,61 +10,59 @@ ) from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() # model name is only used in mock - model_name = 'embedding' + model_name = "embedding" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='reranker', + model="reranker", credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() - model_name = 'embedding' + model_name = "embedding" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py index 57e229e6be94e9..45370d9fba41b0 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py index 305f967ef0a785..b3049a06d9b98a 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py @@ -14,19 +14,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-standard', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -34,23 +30,16 @@ def test_invoke_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -61,23 +50,15 @@ def test_invoke_stream_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,19 +74,17 @@ def test_get_num_tokens(): model = HunyuanLargeLanguageModel() num_tokens = model.get_num_tokens( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py index bdec3d0e22d6d5..e3748c2ce713d4 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py @@ -10,16 +10,11 @@ def test_validate_provider_credentials(): provider = HunyuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } - ) + provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}) provider.validate_provider_credentials( credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py index 7ae6c0e45635db..69d14dffeebf35 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -12,19 +12,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-embedding', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -32,47 +28,43 @@ def test_invoke_model(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = HunyuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, texts=[ "hello", @@ -97,8 +89,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 2b43248388e845..e3b6128c59d8df 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = JinaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index ac175661746a17..290735ec49e625 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = JinaTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } + model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")} ) @@ -30,15 +22,12 @@ def test_invoke_model(): model = JinaTextEmbeddingModel() result = model.invoke( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +39,11 @@ def test_get_num_tokens(): model = JinaTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py index e05345ee56e67d..7fd9f2b3000a31 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py @@ -1,4 +1,4 @@ """ - LocalAI Embedding Interface is temporarily unavailable due to - we could not find a way to test it for now. -""" \ No newline at end of file +LocalAI Embedding Interface is temporarily unavailable due to +we could not find a way to test it for now. +""" diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index 6f421403d4d517..aa5436c34fc7dd 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,28 +102,21 @@ def test_invoke_stream_completion_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_stream_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -154,64 +126,48 @@ def test_invoke_stream_chat_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = LocalAILanguageModel() num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py index 99847bc8528a0a..13c7df6d1473b0 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-v2-m3', + model="bge-reranker-v2-m3", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_rerank_model(): model = LocalaiRerankModel() response = model.invoke( - model='bge-reranker-base', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -45,43 +44,38 @@ def test_invoke_rerank_model(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(response, RerankResult) assert len(response.docs) == 3 + def test__invoke(): model = LocalaiRerankModel() # Test case 1: Empty docs result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 0 # Test case 2: Valid invocation result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -91,12 +85,12 @@ def test__invoke(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 3 - assert all(isinstance(doc, RerankDocument) for doc in result.docs) \ No newline at end of file + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py index 3fd2ebed4f0be1..91b7a5752ce973 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -10,19 +10,9 @@ def test_validate_credentials(): model = LocalAISpeech2text() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': 'invalid_url' - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}) def test_invoke_model(): @@ -32,23 +22,21 @@ def test_invoke_model(): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, + model="whisper-1", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index 6f4b8a163f96da..cf2a28eb9eb2fe 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -12,54 +12,47 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embo-01', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + model="embo-01", + credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")}, ) model.validate_credentials( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): model = MinimaxTextEmbeddingModel() result = model.invoke( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 16 + def test_get_num_tokens(): model = MinimaxTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index 570e4901a9e7e2..aacde04d326caf 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -17,79 +17,70 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = MinimaxLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='abab5.5-chat', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': 'invalid_key' - } + model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"} ) model.validate_credentials( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5-chat', + model="abab5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -99,34 +90,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -134,25 +122,22 @@ def test_invoke_with_search(): total_message += chunk.delta.message.content assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True - assert '参考资料' in total_message + assert "参考资料" in total_message + def test_get_num_tokens(): sleep(3) model = MinimaxLargeLanguageModel() response = model.get_num_tokens( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 30 \ No newline at end of file + assert response == 30 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 4c5462c6dff551..575ed13eef124a 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -12,14 +12,14 @@ def test_validate_provider_credentials(): with pytest.raises(CredentialsValidateFailedError): provider.validate_provider_credentials( credentials={ - 'minimax_api_key': 'hahahaha', - 'minimax_group_id': '123', + "minimax_api_key": "hahahaha", + "minimax_group_id": "123", } ) provider.validate_provider_credentials( credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'), + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), } ) diff --git a/api/tests/integration_tests/model_runtime/novita/test_llm.py b/api/tests/integration_tests/model_runtime/novita/test_llm.py index 4ebc68493f26d5..35fa0dc1904f7b 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_llm.py +++ b/api/tests/integration_tests/model_runtime/novita/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'completion' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_p': 0.5, - 'max_tokens': 10, + "temperature": 1.0, + "top_p": 0.5, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="novita" + user="novita", ) assert isinstance(response, LLMResult) @@ -70,27 +58,17 @@ def test_invoke_stream_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'max_tokens': 100 - }, + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100}, stream=True, - user="novita" + user="novita", ) assert isinstance(response, Generator) @@ -105,18 +83,16 @@ def test_get_num_tokens(): model = NovitaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta-llama/llama-3-8b-instruct', + model="meta-llama/llama-3-8b-instruct", credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/novita/test_provider.py b/api/tests/integration_tests/model_runtime/novita/test_provider.py index bb3f19dc851ea5..191af99db20bd9 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_provider.py +++ b/api/tests/integration_tests/model_runtime/novita/test_provider.py @@ -10,12 +10,10 @@ def test_validate_provider_credentials(): provider = NovitaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 272e639a8ac11e..58a1339f506458 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -20,23 +20,23 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) @@ -44,26 +44,17 @@ def test_invoke_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -74,29 +65,22 @@ def test_invoke_stream_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -111,26 +95,17 @@ def test_invoke_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -141,29 +116,22 @@ def test_invoke_stream_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -178,29 +146,26 @@ def test_invoke_completion_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -212,29 +177,26 @@ def test_invoke_chat_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -246,18 +208,14 @@ def test_get_num_tokens(): model = OllamaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py index c5f5918235d3a2..3c4f740a4fd09c 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -12,21 +12,21 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 4096, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, ) @@ -34,17 +34,14 @@ def test_invoke_model(): model = OllamaEmbeddingModel() result = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -56,16 +53,13 @@ def test_get_num_tokens(): model = OllamaEmbeddingModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index bf4ac53579fb6e..3b3ea9ec800cbb 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -28,92 +28,61 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-davinci-003", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-davinci-003", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1 + assert model._num_tokens_from_string("gpt-3.5-turbo-instruct", result.message.content) == 1 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', + model="gpt-3.5-turbo-instruct", credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'), - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "openai_organization": os.environ.get("OPENAI_ORGANIZATION"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -124,166 +93,131 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-4-vision-preview', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-4-vision-preview", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -302,68 +236,46 @@ def test_get_num_tokens(): model = OpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 72 -@pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat", "remote"]], indirect=True) def test_fine_tuned_models(setup_openai_mock): model = OpenAILargeLanguageModel() - remote_models = model.remote_models(credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }) + remote_models = model.remote_models(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) if not remote_models: assert isinstance(remote_models, list) @@ -379,29 +291,23 @@ def test_fine_tuned_models(setup_openai_mock): # test invoke result = model.invoke( model=llm_model.model, - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) + def test__get_num_tokens_by_gpt2(): model = OpenAILargeLanguageModel() - num_tokens = model._get_num_tokens_by_gpt2('Hello World!') + num_tokens = model._get_num_tokens_by_gpt2("Hello World!") assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 04f9b9f33b3537..6de262471798ad 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -7,48 +7,37 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAIModerationModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAIModerationModel() result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="hello", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) assert result is False result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="i will kill you", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index 5314bffbdf37b1..4d56cfcf3c25f0 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = OpenAIProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index f1a5c4fd23f091..aa92c8b61fb684 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -7,26 +7,17 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAISpeech2TextModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"}) + + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAISpeech2TextModel() @@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="whisper-1", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index e2c4c74ee7636b..f5dd73f2d4cd60 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -8,42 +8,27 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAITextEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -55,15 +40,9 @@ def test_get_num_tokens(): model = OpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world" - ] + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index c8335085695fd5..f2302ef05e06de 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -23,21 +23,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"}, ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, ) @@ -45,28 +41,26 @@ def test_invoke_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'completion' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -77,29 +71,27 @@ def test_invoke_stream_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat', - 'stream_mode_delimiter': '\\n\\n' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + "stream_mode_delimiter": "\\n\\n", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools(): model = OAIAPICompatLargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', + model="gpt-3.5-turbo", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'mode': 'chat' + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1024 - }, + model_parameters={"temperature": 0.0, "max_tokens": 1024}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -207,19 +183,14 @@ def test_get_num_tokens(): model = OAIAPICompatLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py index 61079104dcad73..cf805eafff4968 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -14,18 +14,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="whisper-1", - credentials={ - "api_key": "invalid_key", - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"}, ) model.validate_credentials( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, ) @@ -47,13 +41,10 @@ def test_invoke_model(): result = model.invoke( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, file=file, user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 77d27ec1615fe0..052b41605f6da2 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -12,27 +12,23 @@ Using OpenAI's API as testing endpoint """ + def test_validate_credentials(): model = OAICompatEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - - } + model="text-embedding-ada-002", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184}, ) model.validate_credentials( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - } + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, + }, ) @@ -40,19 +36,14 @@ def test_invoke_model(): model = OAICompatEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -64,16 +55,13 @@ def test_get_num_tokens(): model = OAICompatEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/embeddings', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/embeddings", + "context_size": 8184, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) - assert num_tokens == 2 \ No newline at end of file + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 9eb05a111d9340..14d47217af62c8 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -12,17 +12,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"), + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) @@ -30,33 +30,28 @@ def test_invoke_model(): model = OpenLLMTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = OpenLLMTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 853a0fbe3c9a16..35939e3cfe8bfd 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'invalid_key', - } + "server_url": "invalid_key", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) + def test_invoke_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -84,21 +78,18 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = OpenLLMLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 3 \ No newline at end of file + assert response == 3 diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py index 8f1fb4c4ad7990..ce4876a73a740e 100644 --- a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -70,27 +58,22 @@ def test_invoke_stream_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,18 +88,16 @@ def test_get_num_tokens(): model = OpenRouterLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/mixtral-8x7b-instruct', + model="mistralai/mixtral-8x7b-instruct", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index e248f064c05de3..b940005b715760 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -14,19 +14,19 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": "invalid_key", + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) @@ -34,27 +34,25 @@ def test_invoke_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,27 +63,25 @@ def test_invoke_stream_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct-v0.1', + model="mistralai/mixtral-8x7b-instruct-v0.1", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,19 +96,17 @@ def test_get_num_tokens(): model = ReplicateLargeLanguageModel() num_tokens = model.get_num_tokens( - model='', + model="", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 5708ec9e5a219e..397715f2252083 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -12,19 +12,19 @@ def test_validate_credentials_one(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": "invalid_key", + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) @@ -33,19 +33,19 @@ def test_validate_credentials_two(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": "invalid_key", + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) @@ -53,16 +53,13 @@ def test_invoke_model_one(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -74,16 +71,13 @@ def test_invoke_model_two(): model = ReplicateEmbeddingModel() result = model.invoke( - model='andreasjansson/clip-features', + model="andreasjansson/clip-features", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -95,16 +89,13 @@ def test_invoke_model_three(): model = ReplicateEmbeddingModel() result = model.invoke( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -116,16 +107,13 @@ def test_invoke_model_four(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -137,15 +125,12 @@ def test_get_num_tokens(): model = ReplicateEmbeddingModel() num_tokens = model.get_num_tokens( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py index 639227e7450343..9f0b439d6c32a1 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -10,10 +10,6 @@ def test_validate_provider_credentials(): provider = SageMakerProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py index c67849dd798883..d5a6798a1ef735 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -12,11 +12,11 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -25,7 +25,7 @@ def test_validate_credentials(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) @@ -33,11 +33,11 @@ def test_invoke_model(): model = SageMakerRerankModel() result = model.invoke( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -46,7 +46,7 @@ def test_invoke_model(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py index e817e8f04ab67c..e4e404c7a86ae6 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -11,45 +11,23 @@ def test_validate_credentials(): model = SageMakerEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='bge-m3', - credentials={ - } - ) + model.validate_credentials(model="bge-m3", credentials={}) - model.validate_credentials( - model='bge-m3-embedding', - credentials={ - } - ) + model.validate_credentials(model="bge-m3-embedding", credentials={}) def test_invoke_model(): model = SageMakerEmbeddingModel() - result = model.invoke( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - "hello", - "world" - ], - user="abc-123" - ) + result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123") assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 + def test_get_num_tokens(): model = SageMakerEmbeddingModel() - num_tokens = model.get_num_tokens( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - ] - ) + num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[]) assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py index befdd82352780e..f47c9c558808af 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = SiliconflowLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")}) def test_invoke_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = SiliconflowLargeLanguageModel() num_tokens = model.get_num_tokens( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py index 7b9211a5dbe9c7..8f70210b7a2ace 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = SiliconflowProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py index 7b3ff8272738a4..ad794613f91013 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -13,9 +13,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-reranker-v2-m3", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( @@ -30,17 +28,17 @@ def test_invoke_model(): model = SiliconflowRerankModel() result = model.invoke( - model='BAAI/bge-reranker-v2-m3', + model="BAAI/bge-reranker-v2-m3", credentials={ "api_key": os.environ.get("API_KEY"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py index 82b7921c8506f0..0502ba5ab404bc 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -12,16 +12,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, + credentials={"api_key": os.environ.get("API_KEY")}, ) @@ -42,12 +38,8 @@ def test_invoke_model(): file = audio_file result = model.invoke( - model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, - file=file + model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file ) assert isinstance(result, str) - assert result == '1,2,3,4,5,6,7,8,9,10.' + assert result == "1,2,3,4,5,6,7,8,9,10." diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py index 18bd2e893ae10a..ab143c10613a88 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py @@ -15,9 +15,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-large-zh-v1.5", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 706316449d3142..4fe2fd8c0a3eac 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -13,20 +13,15 @@ def test_validate_credentials(): model = SparkLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='spark-1.5', - credentials={ - 'app_id': 'invalid_key' - } - ) + model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"}) model.validate_credentials( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - } + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, ) @@ -34,24 +29,17 @@ def test_invoke_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -62,23 +50,16 @@ def test_invoke_stream_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -94,20 +75,18 @@ def test_get_num_tokens(): model = SparkLargeLanguageModel() num_tokens = model.get_num_tokens( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8e22815a86fc84..9da0df6bb3d556 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = SparkProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py index d703147d638be3..c03b1bae1f1574 100644 --- a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -21,40 +21,22 @@ def test_validate_credentials(): model = StepfunLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } - ) + model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}) + def test_invoke_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['Hi'], + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["Hi"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,24 +47,17 @@ def test_invoke_stream_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,10 +73,7 @@ def test_get_customizable_model_schema(): model = StepfunLargeLanguageModel() schema = model.get_customizable_model_schema( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } + model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")} ) assert isinstance(schema, AIModelEntity) @@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools(): model = StepfunLargeLanguageModel() result = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in Shanghai?", - ) + ), ], - model_parameters={ - 'temperature': 0.9, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.9, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) - assert len(result.message.tool_calls) > 0 \ No newline at end of file + assert len(result.message.tool_calls) > 0 diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py index fd8aa3f610648c..0ec4b0b7243176 100644 --- a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -24,13 +24,8 @@ def test_get_models(): providers = factory.get_models( model_type=ModelType.LLM, provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) logger.debug(providers) @@ -44,29 +39,21 @@ def test_get_models(): assert provider_model.model_type == ModelType.LLM providers = factory.get_models( - provider='openai', + provider="openai", provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) assert len(providers) == 1 assert isinstance(providers[0], SimpleProviderEntity) - assert providers[0].provider == 'openai' + assert providers[0].provider == "openai" def test_provider_credentials_validate(): factory = ModelProviderFactory() factory.provider_credentials_validate( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) @@ -79,4 +66,4 @@ def test__get_model_provider_map(): logger.debug(model_provider.provider_instance) assert len(model_providers) >= 1 - assert isinstance(model_providers['openai'], ModelProviderExtension) + assert isinstance(model_providers["openai"], ModelProviderExtension) diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 698f53451779b8..06ebc2a82dc754 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -19,76 +19,61 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) + def test_invoke_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,22 +83,21 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) + def test_get_num_tokens(): model = TogetherAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 81fb676018b992..61650735f2ad3f 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -13,18 +13,10 @@ def test_validate_credentials(): model = TongyiLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"}) model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) @@ -32,22 +24,13 @@ def test_invoke_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +41,12 @@ def test_invoke_stream_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +62,14 @@ def test_get_num_tokens(): model = TongyiLargeLanguageModel() num_tokens = model.get_num_tokens( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 6145c1dc37d00b..0bc96c84e73195 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -10,12 +10,8 @@ def test_validate_provider_credentials(): provider = TongyiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py index 1b0a38d5d15a00..905e7907fde5a8 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py @@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"): response = model.invoke( model=model_name, - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ - UserPromptMessage( - content='output json data with format `{"data": "test", "code": 200, "msg": "success"}' - ) + UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}') ], model_parameters={ - 'temperature': 0.5, - 'max_tokens': 50, - 'response_format': 'JSON', + "temperature": 0.5, + "max_tokens": 50, + "response_format": "JSON", }, stream=True, - user="abc-123" + user="abc-123", ) print("=====================================") print(response) @@ -81,4 +77,4 @@ def is_json(s): json.loads(s) except ValueError: return False - return True \ No newline at end of file + return True diff --git a/api/tests/integration_tests/model_runtime/upstage/test_llm.py b/api/tests/integration_tests/model_runtime/upstage/test_llm.py index c35580a8b1ec00..bc7517acbe2601 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_llm.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_llm.py @@ -26,151 +26,113 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): # model name to gpt-3.5-turbo because of mocking - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'upstage_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"}) model.validate_credentials( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -189,57 +151,36 @@ def test_get_num_tokens(): model = UpstageLargeLanguageModel() num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 13 num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 106 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_provider.py b/api/tests/integration_tests/model_runtime/upstage/test_provider.py index c33eef49b2a79e..9d83779aa00a49 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_provider.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_provider.py @@ -7,17 +7,11 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = UpstageProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py index 54135a0e748d40..8c83172fa3ff7e 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py @@ -8,41 +8,31 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = UpstageTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': 'invalid_key' - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"} ) model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = UpstageTextEmbeddingModel() result = model.invoke( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -54,14 +44,11 @@ def test_get_num_tokens(): model = UpstageTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 5 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py index 3b399d604ec910..f831c063a42630 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -14,26 +14,26 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - 'base_model_name': 'Doubao-embedding', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + "base_model_name": "Doubao-embedding", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, ) @@ -42,20 +42,17 @@ def test_invoke_model(): model = VolcengineMaaSTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -67,19 +64,16 @@ def test_get_num_tokens(): model = VolcengineMaaSTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py index 63835d0263ead0..8ff9c414046e7d 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + }, ) @@ -40,28 +40,24 @@ def test_invoke_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) @@ -73,28 +69,24 @@ def test_invoke_stream_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -102,29 +94,24 @@ def test_invoke_stream_model(): assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len( - chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True def test_get_num_tokens(): model = VolcengineMaaSLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py index d886226cf9f393..ac38340aecf7d2 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -10,13 +10,10 @@ def test_invoke_embedding_v1(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='embedding-v1', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="embedding-v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-en', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-en", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-zh', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-zh", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='tao-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="tao-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 164e8253d966ae..e2e58f15e025d8 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -17,161 +17,125 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, ) + def test_invoke_model_ernie_bot(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_turbo(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-turbo', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-turbo", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_8k(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_4(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-4', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-4", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-3.5-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-3.5-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -181,63 +145,48 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_model_with_system(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='你是Kasumi' - ), - UserPromptMessage( - content='你是谁?' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) - assert 'kasumi' in response.message.content.lower() + assert "kasumi" in response.message.content.lower() + def test_invoke_with_search(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'disable_search': True, + "temperature": 0.7, + "top_p": 1.0, + "disable_search": True, }, stop=[], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -247,25 +196,19 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True # there should be 对不起、我不能、不支持…… - assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message) + assert "不" in total_message or "抱歉" in total_message or "无法" in total_message + def test_get_num_tokens(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 10 \ No newline at end of file + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 8922aa18681087..337c3d2a8010dd 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -10,16 +10,8 @@ def test_validate_provider_credentials(): provider = WenxinProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha', - 'secret_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"}) provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index f0a5151f3dbac0..8e778d005a4bc3 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -8,61 +8,57 @@ from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceTextEmbeddingModel() result = model.invoke( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = XinferenceTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 47730406de94b8..48d1ae323d6ab6 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='aaaaa', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + """ Funtion calling of xinference does not support stream mode currently """ @@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # ) # assert isinstance(response, Generator) - + # call: LLMResultChunk = None # chunks = [] @@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.usage.total_tokens > 0 # assert response.message.tool_calls[0].function.name == 'get_current_weather' -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='alapaca', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = XinferenceAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index 9012c16a7e6d3b..71ac4eef7c22be 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -8,44 +8,42 @@ from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-base', - credentials={ - 'server_url': 'awdawdaw', - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + model="bge-reranker-base", + credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")}, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceRerankModel() result = model.invoke( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py index 47a5b6cae23587..4ca1b864764818 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = ZhinaoLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) def test_invoke_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = ZhinaoLargeLanguageModel() num_tokens = model.get_num_tokens( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py index 87b0e6c2d9b5de..c22f797919597c 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhinaoProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 0f92b50cb0f350..20380513eaa789 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -18,41 +18,22 @@ def test_validate_credentials(): model = ZhipuAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['How'], + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -63,21 +44,12 @@ def test_invoke_stream_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,63 +65,45 @@ def test_get_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 + def test_get_tools_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='tools', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="tools", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) ], prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 88 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 51b9cccf2ea752..cb5bc0b20aafc1 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhipuaiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 7308c5729669c8..9c97c91ecbdd94 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -11,34 +11,19 @@ def test_validate_credentials(): model = ZhipuAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) - - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAITextEmbeddingModel() result = model.invoke( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ], - user="abc-123" + model="text_embedding", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +35,7 @@ def test_get_num_tokens(): model = ZhipuAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 41bb3daeb5e531..4dfc530010fa93 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,20 +7,17 @@ class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ request = httpx.Request( - method, - url, - params=kwargs.get('params'), - headers=kwargs.get('headers'), - cookies=kwargs.get('cookies') + method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") ) - data = kwargs.get('data', None) - resp = json.dumps(data).encode('utf-8') if data else b'OK' + data = kwargs.get("data", None) + resp = json.dumps(data).encode("utf-8") if data else b"OK" response = httpx.Response( status_code=200, request=request, diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index ba14d365c5d2d6..83f4d70ce9ac2f 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -10,6 +10,7 @@ "user1": ["Go for a run", "Read a book"], } + class TodosResource(Resource): def get(self, username): todos = todos_data.get(username, []) @@ -32,7 +33,8 @@ def delete(self, username): return {"error": "Invalid todo index"}, 400 -api.add_resource(TodosResource, '/todos/') -if __name__ == '__main__': +api.add_resource(TodosResource, "/todos/") + +if __name__ == "__main__": app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index f6e7b153dde7f6..09729a961eff33 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -3,37 +3,40 @@ from tests.integration_tests.tools.__mock.http import setup_http_mock tool_bundle = { - 'server_url': 'http://www.example.com/{path_param}', - 'method': 'post', - 'author': '', - 'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'}, - {'in': 'query', 'name': 'query_param'}, - {'in': 'cookie', 'name': 'cookie_param'}, - {'in': 'header', 'name': 'header_param'}, - ], - 'requestBody': { - 'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}} - }, - 'parameters': [] + "server_url": "http://www.example.com/{path_param}", + "method": "post", + "author": "", + "openapi": { + "parameters": [ + {"in": "path", "name": "path_param"}, + {"in": "query", "name": "query_param"}, + {"in": "cookie", "name": "cookie_param"}, + {"in": "header", "name": "header_param"}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}} + }, + }, + "parameters": [], } parameters = { - 'path_param': 'p_param', - 'query_param': 'q_param', - 'cookie_param': 'c_param', - 'header_param': 'h_param', - 'body_param': 'b_param', + "path_param": "p_param", + "query_param": "q_param", + "cookie_param": "c_param", + "header_param": "h_param", + "body_param": "b_param", } def test_api_tool(setup_http_mock): - tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'})) + tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) headers = tool.assembling_request(parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert '/p_param' == response.request.url.path - assert b'query_param=q_param' == response.request.url.query - assert 'h_param' == response.request.headers.get('header_param') - assert 'application/json' == response.request.headers.get('content-type') - assert 'cookie_param=c_param' == response.request.headers.get('cookie') - assert 'b_param' in response.content.decode() + assert "/p_param" == response.request.url.path + assert b"query_param=q_param" == response.request.url.query + assert "h_param" == response.request.headers.get("header_param") + assert "application/json" == response.request.headers.get("content-type") + assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index 2811bc816dadbe..2dfce749b3e16f 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -7,16 +7,17 @@ ToolManager.clear_builtin_providers_cache() provider_generator = ToolManager.list_builtin_providers() -@pytest.mark.parametrize('name', provider_names) + +@pytest.mark.parametrize("name", provider_names) def test_tool_providers(benchmark, name): """ Test that all tool providers can be loaded """ - + def test(generator): try: return next(generator) except StopIteration: return None - - benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) \ No newline at end of file + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py index 39fc95256e512b..6a6de1cc41aaf2 100644 --- a/api/tests/integration_tests/utils/parent_class.py +++ b/api/tests/integration_tests/utils/parent_class.py @@ -3,4 +3,4 @@ def __init__(self, name): self.name = name def get_name(self): - return self.name \ No newline at end of file + return self.name diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 256c9a911f104f..7d32f5ae66f5df 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -7,26 +7,26 @@ def test_loading_subclass_from_source(): current_path = os.getcwd() module = load_single_subclass_from_source( - module_name='ChildClass', - script_path=os.path.join(current_path, 'child_class.py'), - parent_type=ParentClass) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass + ) + assert module and module.__name__ == "ChildClass" def test_load_import_module_from_source(): current_path = os.getcwd() module = import_module_from_source( - module_name='ChildClass', - py_file_path=os.path.join(current_path, 'child_class.py')) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") + ) + assert module and module.__name__ == "ChildClass" def test_lazy_loading_subclass_from_source(): current_path = os.getcwd() clz = load_single_subclass_from_source( - module_name='LazyLoadChildClass', - script_path=os.path.join(current_path, 'lazy_load_class.py'), + module_name="LazyLoadChildClass", + script_path=os.path.join(current_path, "lazy_load_class.py"), parent_type=ParentClass, - use_lazy_loader=True) - instance = clz('dify') - assert instance.get_name() == 'dify' + use_lazy_loader=True, + ) + instance = clz("dify") + assert instance.get_name() == "dify" diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index f8165cba9468aa..571c1e3d440508 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,11 +13,15 @@ class MockTcvectordbClass: - - def VectorDBClient(self, url=None, username='', key='', - read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, - timeout=5, - adapter: HTTPAdapter = None): + def VectorDBClient( + self, + url=None, + username="", + key="", + read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, + timeout=5, + adapter: HTTPAdapter = None, + ): self._conn = None self._read_consistency = read_consistency @@ -26,105 +30,96 @@ def list_databases(self) -> list[Database]: Database( conn=self._conn, read_consistency=self._read_consistency, - name='dify', - )] + name="dify", + ) + ] def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: return [] def drop_collection(self, name: str, timeout: Optional[float] = None): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def create_collection( - self, - name: str, - shard: int, - replicas: int, - description: str, - index: Index, - embedding: Embedding = None, - timeout: float = None, + self, + name: str, + shard: int, + replicas: int, + description: str, + index: Index, + embedding: Embedding = None, + timeout: float = None, ) -> Collection: - return Collection(self, name, shard, replicas, description, index, embedding=embedding, - read_consistency=self._read_consistency, timeout=timeout) - - def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: - collection = Collection( + return Collection( self, name, - shard=1, - replicas=2, - description=name, - timeout=timeout + shard, + replicas, + description, + index, + embedding=embedding, + read_consistency=self._read_consistency, + timeout=timeout, ) + + def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: + collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) return collection def collection_upsert( - self, - documents: list[Document], - timeout: Optional[float] = None, - build_index: bool = True, - **kwargs + self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def collection_search( - self, - vectors: list[list[float]], - filter: Filter = None, - params=None, - retrieve_vector: bool = False, - limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + vectors: list[list[float]], + filter: Filter = None, + params=None, + retrieve_vector: bool = False, + limit: int = 10, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[list[dict]]: - return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]] + return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] def collection_query( - self, - document_ids: Optional[list] = None, - retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + document_ids: Optional[list] = None, + retrieve_vector: bool = False, + limit: Optional[int] = None, + offset: Optional[int] = None, + filter: Optional[Filter] = None, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[dict]: - return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}] + return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] def collection_delete( - self, - document_ids: list[str] = None, - filter: Filter = None, - timeout: float = None, + self, + document_ids: list[str] = None, + filter: Filter = None, + timeout: float = None, ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient) - monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases) - monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection) - monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections) - monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection) - monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection) - monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert) - monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search) - monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query) - monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) + monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) + monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) + monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) + monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) yield diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index d6067af73b70bd..970b98edc3d83b 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -26,6 +26,7 @@ def __init__(self): def run_all_tests(self): self.vector.delete() return super().run_all_tests() - + + def test_chroma_vector(setup_mock_redis): - AnalyticdbVectorTest().run_all_tests() \ No newline at end of file + AnalyticdbVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/tests/integration_tests/vdb/chroma/test_chroma.py index 033f9a54da678c..ac7b5cbda45b23 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/tests/integration_tests/vdb/chroma/test_chroma.py @@ -14,13 +14,13 @@ def __init__(self): self.vector = ChromaVector( collection_name=self.collection_name, config=ChromaConfig( - host='localhost', + host="localhost", port=8000, tenant=chromadb.DEFAULT_TENANT, database=chromadb.DEFAULT_DATABASE, auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", auth_credentials="difyai123456", - ) + ), ) def search_by_full_text(self): diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py index b1c1cc10d9375d..2a0c1bb0389187 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -8,16 +8,11 @@ class ElasticSearchVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = ElasticSearchVector( index_name=self.collection_name.lower(), - config=ElasticSearchConfig( - host='http://localhost', - port='9200', - username='elastic', - password='elastic' - ), - attributes=self.attributes + config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 9c0917ef307d0f..7b5f19ea629f3c 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -12,11 +12,11 @@ def __init__(self): self.vector = MilvusVector( collection_name=self.collection_name, config=MilvusConfig( - host='localhost', + host="localhost", port=19530, - user='root', - password='Milvus', - ) + user="root", + password="Milvus", + ), ) def search_by_full_text(self): @@ -25,7 +25,7 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/tests/integration_tests/vdb/myscale/test_myscale.py index b6260d549ac99a..55b2fde4276105 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/tests/integration_tests/vdb/myscale/test_myscale.py @@ -21,7 +21,7 @@ def __init__(self): ) def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index ea1e05da9007f4..a99b81d41eba78 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -29,54 +29,55 @@ def setup_method(self): self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig( - host='localhost', - port=9200, - user='admin', - password='password', - secure=False - ) + config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), ) self.vector._client = MagicMock() - @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [ - ({ - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] - } - }, 1, "example_doc_id"), - ({ - 'hits': { - 'total': {'value': 0}, - 'hits': [] - } - }, 0, None) - ]) + @pytest.mark.parametrize( + "search_response, expected_length, expected_doc_id", + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): self.vector._client.search.return_value = search_response hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == expected_length if expected_length > 0: - assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id + assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id def test_search_by_vector(self): vector = [0.1] * 128 mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ + "hits": { + "total": {"value": 1}, + "hits": [ { - '_source': { + "_source": { Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id} + Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, }, - '_score': 1.0 + "_score": 1.0, } - ] + ], } } self.vector._client.search.return_value = mock_response @@ -85,53 +86,45 @@ def test_search_by_vector(self): print("Hits by vector:", hits_by_vector) print("Expected document ID:", self.example_doc_id) - print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits") + print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" - assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ - f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + assert ( + hits_by_vector[0].metadata["document_id"] == self.example_doc_id + ), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" def test_get_ids_by_metadata_field(self): - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" def test_add_texts(self): - self.vector._client.index.return_value = {'result': 'created'} + self.vector._client.index.return_value = {"result": "created"} doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" + @pytest.mark.usefixtures("setup_mock_redis") class TestOpenSearchVectorWithRedis: @@ -141,11 +134,11 @@ def setup_method(self): def test_search_by_full_text(self): self.tester.setup_method() search_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], } } expected_length = 1 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index e6ce8aab3db173..6b33217d157ea7 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -12,13 +12,13 @@ def __init__(self): self.vector = PGVectoRS( collection_name=self.collection_name.lower(), config=PgvectoRSConfig( - host='localhost', + host="localhost", port=5431, - user='postgres', - password='difyai123456', - database='dify', + user="postgres", + password="difyai123456", + database="dify", ), - dim=128 + dim=128, ) def search_by_full_text(self): @@ -27,8 +27,9 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 + def test_pgvecot_rs(setup_mock_redis): PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 34beb25d450618..61d9a9e712aade 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -8,14 +8,14 @@ class QdrantVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = QdrantVector( collection_name=self.collection_name, group_id=self.dataset_id, config=QdrantConfig( - endpoint='http://localhost:6333', - api_key='difyai123456', - ) + endpoint="http://localhost:6333", + api_key="difyai123456", + ), ) diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py index 8937fe0ea16c50..1b9466e27f4c8f 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -7,18 +7,22 @@ mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] + class TencentVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.vector = TencentVector("dify", TencentConfig( - url="http://127.0.0.1", - api_key="dify", - timeout=30, - username="dify", - database="dify", - shard=1, - replicas=2, - )) + self.vector = TencentVector( + "dify", + TencentConfig( + url="http://127.0.0.1", + api_key="dify", + timeout=30, + username="dify", + database="dify", + shard=1, + replicas=2, + ), + ) def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) @@ -28,8 +32,6 @@ def search_by_full_text(self): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 -def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock): - TencentVectorTest().run_all_tests() - - +def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): + TencentVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index cb35822709fa1a..a11cd225b3ba37 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -10,7 +10,7 @@ def get_example_text() -> str: - return 'test_text' + return "test_text" def get_example_document(doc_id: str) -> Document: @@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document: "doc_hash": doc_id, "document_id": doc_id, "dataset_id": doc_id, - } + }, ) return doc @@ -45,7 +45,7 @@ class AbstractVectorTest: def __init__(self): self.vector = None self.dataset_id = str(uuid.uuid4()) - self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test' + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] @@ -58,12 +58,12 @@ def create_vector(self) -> None: def search_by_vector(self): hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 - assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 1 - assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id def delete_vector(self): self.vector.delete() @@ -76,14 +76,14 @@ def add_texts(self) -> list[str]: documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] embeddings = [self.example_embedding] * batch_size self.vector.add_texts(documents=documents, embeddings=embeddings) - return [doc.metadata['doc_id'] for doc in documents] + return [doc.metadata["doc_id"] for doc in documents] def text_exists(self): assert self.vector.text_exists(self.example_doc_id) def get_ids_by_metadata_field(self): with pytest.raises(NotImplementedError): - self.vector.get_ids_by_metadata_field(key='key', value='value') + self.vector.get_ids_by_metadata_field(key="key", value="value") def run_all_tests(self): self.create_vector() diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 18e00dbeddbd77..2a5320c7d5e752 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -10,15 +10,15 @@ @pytest.fixture def tidb_vector(): return TiDBVector( - collection_name='test_collection', + collection_name="test_collection", config=TiDBVectorConfig( host="xxx.eu-central-1.xxx.aws.tidbcloud.com", port="4000", user="xxx.root", password="xxxxxx", database="dify", - program_name="langgenius/dify" - ) + program_name="langgenius/dify", + ), ) @@ -40,7 +40,7 @@ def search_by_full_text(self): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 0 @@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_ @pytest.fixture def mock_session(): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session: yield mock_session @pytest.fixture def setup_tidbvector_mock(tidb_vector, mock_session): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): - with patch.object(tidb_vector._engine, 'connect'): + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"): + with patch.object(tidb_vector._engine, "connect"): yield tidb_vector diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 3d540cee32bfe2..a6f55420d312ee 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -8,14 +8,14 @@ class WeaviateVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = WeaviateVector( collection_name=self.collection_name, config=WeaviateConfig( - endpoint='http://localhost:8080', - api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", ), - attributes=self.attributes + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 51398ccb329d81..6fb8c86b82a132 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -7,25 +7,22 @@ from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], - code: str, inputs: dict) -> dict: + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: # invoke directly match language: case CodeLanguage.PYTHON3: - return { - "result": 3 - } + return {"result": 3} case CodeLanguage.JINJA2: - return { - "result": Template(code).render(inputs) - } + return {"result": Template(code).render(inputs)} case _: raise Exception("Language not supported") + @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): if not MOCK: diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index beb5c040097b65..cfc47bcad4675e 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -6,38 +6,32 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ - if url == 'http://404.com': - response = httpx.Response( - status_code=404, - request=httpx.Request(method, url), - content=b'Not Found' - ) + if url == "http://404.com": + response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found") return response # get data, files - data = kwargs.get('data', None) - files = kwargs.get('files', None) + data = kwargs.get("data", None) + files = kwargs.get("files", None) if data is not None: - resp = dumps(data).encode('utf-8') + resp = dumps(data).encode("utf-8") elif files is not None: - resp = dumps(files).encode('utf-8') + resp = dumps(files).encode("utf-8") else: - resp = b'OK' + resp = b"OK" response = httpx.Response( - status_code=200, - request=httpx.Request(method, url), - headers=kwargs.get('headers', {}), - content=resp + status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp ) return response diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index ae6e7ceaa71295..44dcf9a10fa2d7 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -2,10 +2,10 @@ from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -CODE_LANGUAGE = 'unsupported_language' +CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): with pytest.raises(CodeExecutionException) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={}) - assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}' + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 0757caba7b4936..09fcb68cf032d4 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -9,8 +9,8 @@ def test_javascript_plain(): code = 'console.log("Hello World")' - result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result_message == 'Hello World\n' + result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result_message == "Hello World\n" def test_javascript_json(): @@ -18,15 +18,18 @@ def test_javascript_json(): obj = {'Hello': 'World'} console.log(JSON.stringify(obj)) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello":"World"}\n' def test_javascript_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), - inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} + language=CODE_LANGUAGE, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + def test_javascript_get_runner_script(): runner_script = NodeJsTemplateTransformer.get_runner_script() diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index 425f4cbdd4b9be..94903cf79688e5 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -7,21 +7,24 @@ def test_jinja2(): - template = 'Hello {{template}}' - inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') - code = (Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._code_placeholder, template) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, - preload=Jinja2TemplateTransformer.get_preload_script(), - code=code) - assert result == '<>Hello World<>\n' + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" def test_jinja2_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'}) - assert result == {'result': 'Hello World'} + language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} def test_jinja2_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index 9d7e86cd68c720..cbe4a5d335a7c8 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -10,8 +10,8 @@ def test_python3_plain(): code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result == 'Hello World\n' + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == "Hello World\n" def test_python3_json(): @@ -19,14 +19,15 @@ def test_python3_json(): import json print(json.dumps({'Hello': 'World'})) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello": "World"}\n' def test_python3_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} + ) + assert result == {"result": "HelloWorld"} def test_python3_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 5c952585208d7b..6f5421e108f0c7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,137 +9,134 @@ from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -CODE_MAX_STRING_LENGTH = int(getenv('CODE_MAX_STRING_LENGTH', '10000')) +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + "id": "1", + "data": { + "outputs": { + "result": { + "type": "number", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] == 3 + assert result.outputs["result"] == 3 assert result.error is None -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code_output_validator(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "result": { "type": "string", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == 'Output variable `result` must be a string' + assert result.error == "Output variable `result` must be a string" + def test_execute_code_output_validator_depth(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "string_validator": { "type": "string", @@ -168,29 +165,26 @@ def main(args1: int, args2: int) -> dict: "depth": { "type": "number", } - } + }, } - } - } - } + }, + }, + }, }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result @@ -199,14 +193,7 @@ def main(args1: int, args2: int) -> dict: "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -218,14 +205,7 @@ def main(args1: int, args2: int) -> dict: "string_validator": 1, "number_array_validator": ["1", "2", "3", "3.333"], "string_array_validator": [1, 2, 3], - "object_validator": { - "result": "1", - "depth": { - "depth": { - "depth": "1" - } - } - } + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, } # validate @@ -238,34 +218,20 @@ def main(args1: int, args2: int) -> dict: "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) - + # construct result result = { "number_validator": 1, "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333] * 2000, "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -274,58 +240,59 @@ def main(args1: int, args2: int) -> dict: def test_execute_code_output_object_list(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "object_list": { "type": "array[object]", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] } # validate @@ -333,13 +300,18 @@ def main(args1: int, args2: int) -> dict: # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }, 1] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] } # validate diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index a1354bd6a5f220..acb616b3256ac3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,322 +9,337 @@ from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock BASIC_NODE_DATA = { - 'tenant_id': '1', - 'app_id': '1', - 'workflow_id': '1', - 'user_id': '1', - 'user_from': UserFrom.ACCOUNT, - 'invoke_from': InvokeFrom.WEB_APP, + "tenant_id": "1", + "app_id": "1", + "workflow_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.WEB_APP, } # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(['a', 'b123', 'args1'], 1) -pool.add(['a', 'b123', 'args2'], 2) +pool.add(["a", "b123", "args1"], 1) +pool.add(["a", "b123", "args2"], 2) -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_no_auth(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'custom', - 'api_key': 'Auth', - 'header': 'X-Auth', + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "Auth", + "header": "X-Auth", + }, }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com/{{#a.b123.args2#}}', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.b123.args2#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", + "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "body": None, }, - 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', - 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'Template=2' in data - assert 'X-Header: 123' in data - assert 'X-Header2: 2' in data + assert "?A=b" in data + assert "Template=2" in data + assert "X-Header: 123" in data + assert "X-Header2: 2" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'json', - 'data': '{"a": "{{#a.b123.args1#}}"}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert '{"a": "1"}' in data - assert 'X-Header: 123' in data + assert "X-Header: 123" in data def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'x-www-form-urlencoded', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'a=1&b=2' in data - assert 'X-Header: 123' in data + assert "a=1&b=2" in data + assert "X-Header: 123" in data def test_form_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'form-data', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert 'form-data; name="a"' in data - assert '1' in data + assert "1" in data assert 'form-data; name="b"' in data - assert '2' in data - assert 'X-Header: 123' in data + assert "2" in data + assert "X-Header: 123" in data def test_none_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'none', - 'data': '123123123' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "none", "data": "123123123"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'X-Header: 123' in data - assert '123123123' not in data + assert "X-Header: 123" in data + assert "123123123" not in data def test_mock_404(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://404.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://404.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "body": None, + "params": "", + "headers": "X-Header:123", }, - 'body': None, - 'params': '', - 'headers': 'X-Header:123', - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert 404 == resp.get('status_code') - assert 'Not Found' in resp.get('body') + assert 404 == resp.get("status_code") + assert "Not Found" in resp.get("body") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, - }, - 'params': 'Referer:http://example1.com\nRedirect:http://example2.com', - 'headers': 'Referer:http://example3.com\nRedirect:http://example4.com', - 'body': { - 'type': 'form-data', - 'data': 'Referer:http://example5.com\nRedirect:http://example6.com' + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "params": "Referer:http://example1.com\nRedirect:http://example2.com", + "headers": "Referer:http://example3.com\nRedirect:http://example4.com", + "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request') - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request') - assert 'http://example3.com' == resp.get('headers').get('referer') + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") + assert "http://example3.com" == resp.get("headers").get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 1b27af5af7d793..6bab83a0191bca 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -23,90 +23,71 @@ from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm(setup_openai_mock): node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'prompt_template': [ - { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.' - }, - { - 'role': 'user', - 'text': '{{#sys.query#}}' - } + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather today?', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') - - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - - provider_instance = ModelProviderFactory().get_provider_instance('openai') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -118,112 +99,97 @@ def test_execute_llm(setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['text'] is not None - assert result.outputs['usage']['total_tokens'] > 0 + assert result.outputs["text"] is not None + assert result.outputs["usage"]["total_tokens"] > 0 + -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - 'prompt_config': { - 'jinja2_variables': [{ - 'variable': 'sys_query', - 'value_selector': ['sys', 'query'] - }, { - 'variable': 'output', - 'value_selector': ['abc', 'output'] - }] - }, - 'prompt_template': [ + "prompt_template": [ { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', - 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', - 'edition_type': 'jinja2' + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", }, { - 'role': 'user', - 'text': '{{#sys.query#}}', - 'jinja2_text': '{{sys_query}}', - 'edition_type': 'basic' - } + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather today?', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') - - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - - provider_instance = ModelProviderFactory().get_provider_instance('openai') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") + + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -235,5 +201,5 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) + assert "sunny" in json.dumps(result.process_data) + assert "what's the weather today?" in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index e32fa59df36ee7..ca2bae5c536090 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -26,29 +26,25 @@ def get_mocked_fetch_model_config( - provider: str, model: str, mode: str, + provider: str, + model: str, + mode: str, credentials: dict, ): provider_instance = ModelProviderFactory().get_provider_instance(provider) model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) model_config = ModelConfigWithCredentialsEntity( @@ -58,268 +54,268 @@ def get_mocked_fetch_model_config( credentials=credentials, parameters={}, model_schema=model_type_instance.get_model_schema(model), - provider_model_bundle=provider_model_bundle + provider_model_bundle=provider_model_bundle, ) return MagicMock(return_value=(model_instance, model_config)) + def get_mocked_fetch_memory(memory_text: str): class MemoryMock: - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None): + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ): return memory_text return MagicMock(return_value=MemoryMock()) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'instruction': '', - 'reasoning_mode': 'function_call', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "instruction": "", + "reasoning_mode": "function_call", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'function_call', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "function_call", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None process_data = result.process_data - process_data.get('prompts') + process_data.get("prompts") + + for prompt in process_data.get("prompts"): + if prompt.get("role") == "system": + assert "what's the weather in SF" in prompt.get("text") - for prompt in process_data.get('prompts'): - if prompt.get('role') == 'system': - assert 'what\'s the weather in SF' in prompt.get('text') -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo-instruct", + mode="completion", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - assert len(result.process_data.get('prompts')) == 1 - assert 'SF' in result.process_data.get('prompts')[0].get('text') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert len(result.process_data.get("prompts")) == 1 + assert "SF" in result.process_data.get("prompts")[0].get("text") + def test_extract_json_response(): """ @@ -327,35 +323,30 @@ def test_extract_json_response(): """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) result = node._extract_complete_json_response(""" @@ -366,83 +357,77 @@ def test_extract_json_response(): hello world. """) - assert result['location'] == 'kawaii' + assert result["location"] == "kawaii" -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': { - 'window': { - 'enabled': True, - 'size': 50 - } - }, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": {"window": {"enabled": True, "size": 50}}, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory('customized memory') + node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") latest_role = None for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') - elif prompt.get('role') == 'system': - assert 'customized memory' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + elif prompt.get("role") == "system": + assert "customized memory" in prompt.get("text") if latest_role is not None: - assert latest_role != prompt.get('role') + assert latest_role != prompt.get("role") - if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') + if prompt.get("role") in ["user", "assistant"]: + latest_role = prompt.get("role") diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 781dfbc50fdba3..617b6370c9f410 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -8,42 +8,39 @@ from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = '''{{args2}}''' + code = """{{args2}}""" node = TemplateTransformNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ - 'id': '1', - 'data': { - 'title': '123', - 'variables': [ + "id": "1", + "data": { + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'template': code, - } - } + "template": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 3) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 3) + # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['output'] == '3' + assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 01d62280e837b3..29c1efa8e78b2a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -7,78 +7,79 @@ def test_tool_variable_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], '1+1') + pool.add(["1", "123", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'variable', - 'value': ['1', '123', 'args1'], + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] + def test_tool_mixed_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', 'args1'], '1+1') + pool.add(["1", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'mixed', - 'value': '{{#1.args1#}}', + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "mixed", + "value": "{{#1.args1#}}", } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] \ No newline at end of file + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 949a5a17693457..39f313b51344b7 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -6,18 +6,21 @@ from configs.app_config import DifyConfig -EXAMPLE_ENV_FILENAME = '.env' +EXAMPLE_ENV_FILENAME = ".env" @pytest.fixture def example_env_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME) - file_path.write_text(dedent( - """ + file_path.write_text( + dedent( + """ CONSOLE_API_URL=https://example.com CONSOLE_WEB_URL=https://example.com - """)) + """ + ) + ) return str(file_path) @@ -29,7 +32,7 @@ def test_dify_config_undefined_entry(example_env_file): # entries not defined in app settings with pytest.raises(TypeError): # TypeError: 'AppSettings' object is not subscriptable - assert config['LOG_LEVEL'] == 'INFO' + assert config["LOG_LEVEL"] == "INFO" def test_dify_config(example_env_file): @@ -37,10 +40,10 @@ def test_dify_config(example_env_file): config = DifyConfig(_env_file=example_env_file) # constant values - assert config.COMMIT_SHA == '' + assert config.COMMIT_SHA == "" # default values - assert config.EDITION == 'SELF_HOSTED' + assert config.EDITION == "SELF_HOSTED" assert config.API_COMPRESSION_ENABLED is False assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 @@ -48,36 +51,36 @@ def test_dify_config(example_env_file): # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(example_env_file): - flask_app = Flask('app') + flask_app = Flask("app") # clear system environment variables os.environ.clear() flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings - assert config['LOG_LEVEL'] == 'INFO' - assert config['COMMIT_SHA'] == '' - assert config['EDITION'] == 'SELF_HOSTED' - assert config['API_COMPRESSION_ENABLED'] is False - assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0 - assert config['TESTING'] == False + assert config["LOG_LEVEL"] == "INFO" + assert config["COMMIT_SHA"] == "" + assert config["EDITION"] == "SELF_HOSTED" + assert config["API_COMPRESSION_ENABLED"] is False + assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 + assert config["TESTING"] == False # value from env file - assert config['CONSOLE_API_URL'] == 'https://example.com' + assert config["CONSOLE_API_URL"] == "https://example.com" # fallback to alias choices value as CONSOLE_API_URL - assert config['FILES_URL'] == 'https://example.com' + assert config["FILES_URL"] == "https://example.com" - assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify' - assert config['SQLALCHEMY_ENGINE_OPTIONS'] == { - 'connect_args': { - 'options': '-c timezone=UTC', + assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" + assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { + "connect_args": { + "options": "-c timezone=UTC", }, - 'max_overflow': 10, - 'pool_pre_ping': False, - 'pool_recycle': 3600, - 'pool_size': 30, + "max_overflow": 10, + "pool_pre_ping": False, + "pool_recycle": 3600, + "pool_size": 30, } - assert config['CONSOLE_WEB_URL']=='https://example.com' - assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com'] - assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*'] + assert config["CONSOLE_WEB_URL"] == "https://example.com" + assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] + assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index afd0fa50b590f8..0824c8e9e978e2 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -17,31 +17,31 @@ def test_string_variable(): - test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): - test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + test_data = {"value_type": "number", "name": "test_int", "value": 42} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): - test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): - test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): - test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) @@ -49,51 +49,51 @@ def test_invalid_value_type(): def test_build_a_blank_string(): result = factory.build_variable_from_mapping( { - 'value_type': 'string', - 'name': 'blank', - 'value': '', + "value_type": "string", + "name": "blank", + "value": "", } ) assert isinstance(result, StringVariable) - assert result.value == '' + assert result.value == "" def test_build_a_object_variable_with_none_value(): var = factory.build_segment( { - 'key1': None, + "key1": None, } ) assert isinstance(var, ObjectSegment) - assert var.value['key1'] is None + assert var.value["key1"] is None def test_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'test_object', - 'description': 'Description of the variable.', - 'value': { - 'key1': 'text', - 'key2': 2, + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], str) - assert isinstance(variable.value['key2'], int) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) def test_array_string_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[string]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ - 'text', - 'text', + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", ], } variable = factory.build_variable_from_mapping(mapping) @@ -104,11 +104,11 @@ def test_array_string_variable(): def test_array_number_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[number]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ 1, 2.0, ], @@ -121,18 +121,18 @@ def test_array_number_variable(): def test_array_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[object]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, ], } @@ -140,19 +140,19 @@ def test_array_object_variable(): assert isinstance(variable, ArrayObjectVariable) assert isinstance(variable.value[0], dict) assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]['key1'], str) - assert isinstance(variable.value[0]['key2'], int) - assert isinstance(variable.value[1]['key1'], str) - assert isinstance(variable.value[1]['key2'], int) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) def test_variable_cannot_large_than_5_kb(): with pytest.raises(VariableError): factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'test_text', - 'value': 'a' * 1024 * 6, + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 6, } ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 50d991316d9d26..7cc339d21200a1 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -7,20 +7,20 @@ def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariableKey('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[ - SecretVariable(name='secret_key', value='fake-secret-key'), + SecretVariable(name="secret_key", value="fake-secret-key"), ], ) - variable_pool.add(('node_id', 'custom_query'), 'fake-user-query') + variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( - 'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.' + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.' + assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." assert ( segments_group.log == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}." @@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group(): user_inputs={}, environment_variables=[], ) - template = 'Hello, world!' + template = "Hello, world!" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, world!' - assert segments_group.log == 'Hello, world!' + assert segments_group.text == "Hello, world!" + assert segments_group.log == "Hello, world!" def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariableKey('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[], ) - template = '{{#sys.user_id#}}' + template = "{{#sys.user_id#}}" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'fake-user-id' - assert segments_group.log == 'fake-user-id' - assert segments_group.value == [StringSegment(value='fake-user-id')] + assert segments_group.text == "fake-user-id" + assert segments_group.log == "fake-user-id" + assert segments_group.value == [StringSegment(value="fake-user-id")] diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 1f45c15f8712a9..b3f0ae626cf24b 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -13,60 +13,60 @@ def test_frozen_variables(): - var = StringVariable(name='text', value='text') + var = StringVariable(name="text", value="text") with pytest.raises(ValidationError): - var.value = 'new value' + var.value = "new value" - int_var = IntegerVariable(name='integer', value=42) + int_var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): int_var.value = 100 - float_var = FloatVariable(name='float', value=3.14) + float_var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): float_var.value = 2.718 - secret_var = SecretVariable(name='secret', value='secret_value') + secret_var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): - secret_var.value = 'new_secret_value' + secret_var.value = "new_secret_value" def test_variable_value_type_immutable(): with pytest.raises(ValidationError): - StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text') + StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text") with pytest.raises(ValidationError): - StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) + StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"}) - var = IntegerVariable(name='integer', value=42) + var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = SecretVariable(name='secret', value='secret_value') + var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) def test_object_variable_to_object(): var = ObjectVariable( - name='object', + name="object", value={ - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': ['value5_1', 42, {}], + "key2": ["value5_1", 42, {}], }, ) assert var.to_object() == { - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': [ - 'value5_1', + "key2": [ + "value5_1", 42, {}, ], @@ -74,11 +74,11 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name='text', value='text') - assert var.to_object() == 'text' - var = IntegerVariable(name='integer', value=42) + var = StringVariable(name="text", value="text") + assert var.to_object() == "text" + var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) assert var.to_object() == 3.14 - var = SecretVariable(name='secret', value='secret_value') - assert var.to_object() == 'secret_value' + var = SecretVariable(name="secret", value="secret_value") + assert var.to_object() == "secret_value" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d917bb10036faa..7a0bc70c63eeab 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,17 +4,17 @@ from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request -@patch('httpx.request') +@patch("httpx.request") def test_successful_request(mock_request): mock_response = MagicMock() mock_response.status_code = 200 mock_request.return_value = mock_response - response = make_request('GET', 'http://example.com') + response = make_request("GET", "http://example.com") assert response.status_code == 200 -@patch('httpx.request') +@patch("httpx.request") def test_retry_exceed_max_retries(mock_request): mock_response = MagicMock() mock_response.status_code = 500 @@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request): mock_request.side_effect = side_effects try: - make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) raise AssertionError("Expected Exception not raised") except Exception as e: assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch('httpx.request') +@patch("httpx.request") def test_retry_logic_success(mock_request): side_effects = [] @@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request): mock_request.side_effect = side_effects - response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 - assert mock_request.call_args_list[0][1].get('method') == 'GET' + assert mock_request.call_args_list[0][1].get("method") == "GET" diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py index 68334fde82f350..5b159b49b61f37 100644 --- a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -21,18 +21,18 @@ def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: return _MockTextEmbedding() - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) max_chunks = embedding_model._get_max_chunks(model, credentials) embedding_model._create_text_embedding = _create_text_embedding - texts = ['0123456789' for i in range(0, max_chunks * 2)] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + texts = ["0123456789" for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert len(result.embeddings) == max_chunks * 2 @@ -41,16 +41,16 @@ def get_num_tokens_by_gpt2(text: str) -> int: return GPT2Tokenizer.get_num_tokens(text) def mock_text(token_size: int) -> str: - _text = "".join(['0' for i in range(token_size)]) + _text = "".join(["0" for i in range(token_size)]) num_tokens = get_num_tokens_by_gpt2(_text) ratio = int(np.floor(len(_text) / num_tokens)) m_text = "".join([_text for i in range(ratio)]) return m_text - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) @@ -71,5 +71,5 @@ def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: assert get_num_tokens_by_gpt2(text) == context_size * 2 texts = [text] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index d24cd4aae98ded..24bbde6d4ebb66 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -14,39 +14,24 @@ def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." - prompt_template_config = CompletionModelPromptTemplate( - text=prompt_template - ) + prompt_template_config = CompletionModelPromptTemplate(text=prompt_template) memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix( - user="Human", - assistant="Assistant" - ), - window=MemoryConfig.WindowConfig( - enabled=False - ) + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), ) - inputs = { - "name": "John" - } + inputs = {"name": "John"} files = [] context = "I am superman." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages(): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 1 - assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ - "#context#": context, - "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " - f"{prompt.content}" for prompt in history_prompt_messages]), - **inputs, - }) + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format( + { + "#context#": context, + "#histories#": "\n".join( + [ + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}" + for prompt in history_prompt_messages + ] + ), + **inputs, + } + ) def test__get_chat_model_prompt_messages(get_chat_model_args): @@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): files = [] query = "Hi2." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi1."), - AssistantPromptMessage(content="Hello1!") - ] + history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert prompt_messages[5].content == query @@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): @@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg image_config={ "detail": "high", } - ) + ), ) ] @@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 assert prompt_messages[3].content[1].data == files[0].url @@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg @pytest.fixture def get_chat_model_args(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) prompt_messages = [ ChatModelMessage( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM ), - ChatModelMessage( - text="Hi.", - role=PromptMessageRole.USER - ), - ChatModelMessage( - text="Hello!", - role=PromptMessageRole.ASSISTANT - ) + ChatModelMessage(text="Hi.", role=PromptMessageRole.USER), + ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT), ] - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "I am superman." diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 9de268d7624744..0fd176e65d02f8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -18,27 +18,28 @@ def test_get_prompt(): prompt_messages = [ - SystemPromptMessage(content='System Template'), - UserPromptMessage(content='User Query'), + SystemPromptMessage(content="System Template"), + UserPromptMessage(content="User Query"), ] history_messages = [ - SystemPromptMessage(content='System Prompt 1'), - UserPromptMessage(content='User Prompt 1'), - AssistantPromptMessage(content='Assistant Thought 1'), - ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'), - ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'), - SystemPromptMessage(content='System Prompt 2'), - UserPromptMessage(content='User Prompt 2'), - AssistantPromptMessage(content='Assistant Thought 2'), - ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'), - ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'), - UserPromptMessage(content='User Prompt 3'), - AssistantPromptMessage(content='Assistant Thought 3'), + SystemPromptMessage(content="System Prompt 1"), + UserPromptMessage(content="User Prompt 1"), + AssistantPromptMessage(content="Assistant Thought 1"), + ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"), + ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"), + SystemPromptMessage(content="System Prompt 2"), + UserPromptMessage(content="User Prompt 2"), + AssistantPromptMessage(content="Assistant Thought 2"), + ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"), + ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"), + UserPromptMessage(content="User Prompt 3"), + AssistantPromptMessage(content="Assistant Thought 3"), ] # use message number instead of token for testing def side_effect_get_num_tokens(*args): return len(args[2]) + large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) @@ -46,20 +47,17 @@ def side_effect_get_num_tokens(*args): provider_model_bundle_mock.model_type_instance = large_language_model_mock model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.model = 'openai' + model_config_mock.model = "openai" model_config_mock.credentials = {} model_config_mock.provider_model_bundle = provider_model_bundle_mock - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) transform = AgentHistoryPromptTransform( model_config=model_config_mock, prompt_messages=prompt_messages, history_messages=history_messages, - memory=memory + memory=memory, ) max_token_limit = 5 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 2bcc6f42927edc..89c14463bbfb94 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -12,19 +12,15 @@ def test__calculate_rest_token(): model_schema_mock = MagicMock(spec=AIModelEntity) parameter_rule_mock = MagicMock(spec=ParameterRule) - parameter_rule_mock.name = 'max_tokens' - model_schema_mock.parameter_rules = [ - parameter_rule_mock - ] - model_schema_mock.model_properties = { - ModelPropertyKey.CONTEXT_SIZE: 62 - } + parameter_rule_mock.name = "max_tokens" + model_schema_mock.parameter_rules = [parameter_rule_mock] + model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens.return_value = 6 provider_mock = MagicMock(spec=ProviderEntity) - provider_mock.provider = 'openai' + provider_mock.provider = "openai" provider_configuration_mock = MagicMock(spec=ProviderConfiguration) provider_configuration_mock.provider = provider_mock @@ -35,11 +31,9 @@ def test__calculate_rest_token(): provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.model = 'gpt-4' + model_config_mock.model = "gpt-4" model_config_mock.credentials = {} - model_config_mock.parameters = { - 'max_tokens': 50 - } + model_config_mock.parameters = {"max_tokens": 50} model_config_mock.model_schema = model_schema_mock model_config_mock.provider_model_bundle = provider_model_bundle_mock @@ -49,8 +43,10 @@ def test__calculate_rest_token(): rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) # Validate based on the mock configuration and expected logic - expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] - - model_config_mock.parameters['max_tokens'] - - large_language_model_mock.get_num_tokens.return_value) + expected_rest_tokens = ( + model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters["max_tokens"] + - large_language_model_mock.get_num_tokens.return_value + ) assert rest_tokens == expected_rest_tokens assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 6d6363610bb16a..c32fc2bc34813d 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_baichuan_chat_app_prompt_template_with_pcqm(): @@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_common_completion_app_prompt_template_with_pcq(): @@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_baichuan_completion_app_prompt_template_with_pcq(): @@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - print(prompt_template['prompt_template'].template) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + print(prompt_template["prompt_template"].template) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_q(): @@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] - assert prompt_template['special_variable_keys'] == ['#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"] + assert prompt_template["special_variable_keys"] == ["#query#"] def test_get_common_chat_app_prompt_template_with_cq(): @@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_p(): @@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p(): query_in_prompt=False, with_memory_prompt=False, ) - assert prompt_template['prompt_template'].template == pre_prompt + '\n' - assert prompt_template['custom_variable_keys'] == ['name'] - assert prompt_template['special_variable_keys'] == [] + assert prompt_template["prompt_template"].template == pre_prompt + "\n" + assert prompt_template["custom_variable_keys"] == ["name"] + assert prompt_template["special_variable_keys"] == [] def test__get_chat_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" memory_mock = MagicMock(spec=TokenBufferMemory) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory_mock.get_history_prompt_messages.return_value = history_prompt_messages prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( @@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages(): files=[], context=context, memory=memory_mock, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages(): with_memory_prompt=False, ) - full_inputs = {**inputs, '#context#': context} - real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + full_inputs = {**inputs, "#context#": context} + real_system_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 4 assert prompt_messages[0].content == real_system_prompt @@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( @@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages(): files=[], context=context, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages(): with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( - max_token_limit=2000, - human_prefix=prompt_rules.get("human_prefix", "Human"), - ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") - )} - real_prompt = prompt_template['prompt_template'].format(full_inputs) + prompt_rules = prompt_template["prompt_rules"] + full_inputs = { + **inputs, + "#context#": context, + "#query#": query, + "#histories#": memory.get_history_prompt_text( + max_token_limit=2000, + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), + ), + } + real_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 1 - assert stops == prompt_rules.get('stops') + assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 9e43b23658f41a..8d735cae86d947 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -5,20 +5,15 @@ def test_default_value(): - valid_config = { - 'host': 'localhost', - 'port': 19530, - 'user': 'root', - 'password': 'Milvus' - } + valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: MilvusConfig(**config) - assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required' + assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" config = MilvusConfig(**valid_config) assert config.secure is False - assert config.database == 'default' + assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index a8bba11e16db1f..d5a1d8f436c75a 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -9,19 +9,17 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): url = "https://firecrawl.dev" - api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-' - base_url = 'https://api.firecrawl.dev' - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=base_url) + api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" + base_url = "https://api.firecrawl.dev" + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "maxDepth": 1, "limit": 1, - 'returnOnlyUrls': False, - + "returnOnlyUrls": False, } } mocked_firecrawl = { diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index b231fe479b829b..eea584a2f8edc6 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -8,11 +8,8 @@ extractor = notion_extractor.NotionExtractor( - notion_workspace_id='x', - notion_obj_id='x', - notion_page_type='page', - tenant_id='x', - notion_access_token='x') + notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" +) def _generate_page(page_title: str): @@ -21,16 +18,10 @@ def _generate_page(page_title: str): "id": page_id, "properties": { "Page": { - "type": "title", - "title": [ - { - "type": "text", - "text": {"content": page_title}, - "plain_text": page_title - } - ] + "type": "title", + "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], } - } + }, } @@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str): return { "object": "block", "id": block_id, - "parent": { - "type": "page_id", - "page_id": page_id - }, + "parent": {"type": "page_id", "page_id": page_id}, "type": block_type, "has_children": False, block_type: { @@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str): { "type": "text", "text": {"content": block_text}, - "plain_text": block_text, - }] - } - } + "plain_text": block_text, + } + ] + }, + } def _mock_response(data): @@ -63,7 +52,7 @@ def _mock_response(data): def _remove_multiple_new_lines(text): - while '\n\n' in text: + while "\n\n" in text: text = text.replace("\n\n", "\n") return text.strip() @@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text): def test_notion_page(mocker): texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]) - ], - "next_cursor": None + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]), + ], + "next_cursor": None, } mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" def test_notion_database(mocker): @@ -93,10 +82,10 @@ def test_notion_database(mocker): mocked_notion_database = { "object": "list", "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None + "next_cursor": None, } mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == '\n'.join([f'Page:{i}' for i in page_title_list]) + assert content == "\n".join([f"Page:{i}" for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 3024a54a4d325c..2808b5b0fad1a7 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -10,36 +10,24 @@ @pytest.fixture def lb_model_manager(): load_balancing_configs = [ - ModelLoadBalancingConfiguration( - id='id1', - name='__inherit__', - credentials={} - ), - ModelLoadBalancingConfiguration( - id='id2', - name='first', - credentials={"openai_api_key": "fake_key"} - ), - ModelLoadBalancingConfiguration( - id='id3', - name='second', - credentials={"openai_api_key": "fake_key"} - ) + ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}), + ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}), + ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}), ] lb_model_manager = LBModelManager( - tenant_id='tenant_id', - provider='openai', + tenant_id="tenant_id", + provider="openai", model_type=ModelType.LLM, - model='gpt-4', + model="gpt-4", load_balancing_configs=load_balancing_configs, - managed_credentials={"openai_api_key": "fake_key"} + managed_credentials={"openai_api_key": "fake_key"}, ) lb_model_manager.cooldown = MagicMock(return_value=None) def is_cooldown(config: ModelLoadBalancingConfiguration): - if config.id == 'id1': + if config.id == "id1": return True return False @@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager): assert lb_model_manager.in_cooldown(config3) is False start_index = 0 + def incr(key): nonlocal start_index start_index += 1 return start_index - mocker.patch('redis.Redis.incr', side_effect=incr) - mocker.patch('redis.Redis.set', return_value=None) - mocker.patch('redis.Redis.expire', return_value=None) + mocker.patch("redis.Redis.incr", side_effect=incr) + mocker.patch("redis.Redis.set", return_value=None) + mocker.patch("redis.Redis.expire", return_value=None) config = lb_model_manager.fetch_next() assert config == config2 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 072b6f100f3a16..2f4214a5801de1 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -11,62 +11,62 @@ def test__to_model_settings(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ), LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True - ) + enabled=True, + ), ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 2 - assert result[0].load_balancing_configs[0].name == '__inherit__' - assert result[0].load_balancing_configs[1].name == 'first' + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" def test__to_model_settings_only_one_lb(mocker): @@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ) ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 @@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=False - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ), LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True - ) + enabled=True, + ), ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py index 9addeeadca7bf9..279a6cdbc328e6 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -5,52 +5,52 @@ def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type('unsupported_type') + ToolParameterConverter.get_parameter_type("unsupported_type") def test_cast_parameter_by_type(): # string - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" # secret input - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" # select - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" # boolean - true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] for value in false_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False # number - assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None # unknown - assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' - assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 44b7c852569f7e..8020674ee6e3d9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -11,29 +11,30 @@ def test_execute_answer(): node = AnswerNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'answer', - 'data': { - 'title': '123', - 'type': 'answer', - 'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.' - } - } + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'weather'], 'sunny') - pool.add(['llm', 'text'], 'You are a helpful AI.') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") # Mock db.session.close() db.session.close = MagicMock() @@ -42,4 +43,4 @@ def test_execute_answer(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 87ebcb34e651b7..9535bc2186af72 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -11,134 +11,81 @@ def test_execute_if_else_result_true(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'and', - 'conditions': [ - { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'not_contains'], - 'value': 'ab' - }, - { - 'comparison_operator': 'start with', - 'variable_selector': ['start', 'start_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'end with', - 'variable_selector': ['start', 'end_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is', - 'variable_selector': ['start', 'is'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is not', - 'variable_selector': ['start', 'is_not'], - 'value': 'ab' - }, - { - 'comparison_operator': 'empty', - 'variable_selector': ['start', 'empty'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not empty', - 'variable_selector': ['start', 'not_empty'], - 'value': 'ab' - }, - { - 'comparison_operator': '=', - 'variable_selector': ['start', 'equals'], - 'value': '22' - }, - { - 'comparison_operator': '≠', - 'variable_selector': ['start', 'not_equals'], - 'value': '22' - }, - { - 'comparison_operator': '>', - 'variable_selector': ['start', 'greater_than'], - 'value': '22' - }, - { - 'comparison_operator': '<', - 'variable_selector': ['start', 'less_than'], - 'value': '22' - }, - { - 'comparison_operator': '≥', - 'variable_selector': ['start', 'greater_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': '≤', - 'variable_selector': ['start', 'less_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': 'null', - 'variable_selector': ['start', 'null'] - }, - { - 'comparison_operator': 'not null', - 'variable_selector': ['start', 'not_null'] - }, - ] - } - } + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", + }, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, + { + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", + }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, + { + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", + }, + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ac', 'def']) - pool.add(['start', 'contains'], 'cabcde') - pool.add(['start', 'not_contains'], 'zacde') - pool.add(['start', 'start_with'], 'abc') - pool.add(['start', 'end_with'], 'zzab') - pool.add(['start', 'is'], 'ab') - pool.add(['start', 'is_not'], 'aab') - pool.add(['start', 'empty'], '') - pool.add(['start', 'not_empty'], 'aaa') - pool.add(['start', 'equals'], 22) - pool.add(['start', 'not_equals'], 23) - pool.add(['start', 'greater_than'], 23) - pool.add(['start', 'less_than'], 21) - pool.add(['start', 'greater_than_or_equal'], 22) - pool.add(['start', 'less_than_or_equal'], 21) - pool.add(['start', 'not_null'], '1212') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") # Mock db.session.close() db.session.close = MagicMock() @@ -147,46 +94,47 @@ def test_execute_if_else_result_true(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is True + assert result.outputs["result"] is True def test_execute_if_else_result_false(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'or', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - } - ] - } - } + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['1ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ab', 'def']) + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) # Mock db.session.close() db.session.close = MagicMock() @@ -195,4 +143,4 @@ def test_execute_if_else_result_false(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is False + assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 5df8c1b7639051..e26c7df642776b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -8,41 +8,41 @@ from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode -DEFAULT_NODE_ID = 'node_id' +DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): conversation_variable = StringVariable( id=str(uuid4()), - name='test_conversation_variable', - value='the first value', + name="test_conversation_variable", + value="the first value", ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.OVER_WRITE.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -52,48 +52,48 @@ def test_overwrite_string_variable(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.value == 'the second value' - assert got.to_object() == 'the second value' + assert got.value == "the second value" + assert got.to_object() == "the second value" def test_append_variable_to_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.APPEND.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -103,41 +103,41 @@ def test_append_variable_to_array(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.to_object() == ['the first value', 'the second value'] + assert got.to_object() == ["the first value", "the second value"] def test_clear_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.CLEAR.value, - 'input_variable_selector': [], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -145,6 +145,6 @@ def test_clear_array(): node.run(variable_pool) - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py index bbc372ed61b65a..21c2f0781d85f9 100644 --- a/api/tests/unit_tests/libs/test_pandas.py +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -3,50 +3,46 @@ def test_pandas_csv(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to csv file - csv_file_path = tmp_path.joinpath('example.csv') + csv_file_path = tmp_path.joinpath("example.csv") df1.to_csv(csv_file_path, index=False) # read from csv file - df2 = pd.read_csv(csv_file_path, on_bad_lines='skip') - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + df2 = pd.read_csv(csv_file_path, on_bad_lines="skip") + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to xlsx file - xlsx_file_path = tmp_path.joinpath('example.xlsx') + xlsx_file_path = tmp_path.joinpath("example.xlsx") df1.to_excel(xlsx_file_path, index=False) # read from xlsx file df2 = pd.read_excel(xlsx_file_path) - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data1 = {'col1': [1, 2, 3, 4, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data1) - data2 = {'col1': [6, 7, 8, 9, 10], - 'col2': ['F', 'G', 'H', 'I', 'J']} + data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]} df2 = pd.DataFrame(data2) # write to xlsx file with sheets - xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx') - sheet1 = 'Sheet1' - sheet2 = 'Sheet2' + xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx") + sheet1 = "Sheet1" + sheet2 = "Sheet2" with pd.ExcelWriter(xlsx_file_path) as excel_writer: df1.to_excel(excel_writer, sheet_name=sheet1, index=False) df2.to_excel(excel_writer, sheet_name=sheet2, index=False) @@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): # read from xlsx file with sheets with pd.ExcelFile(xlsx_file_path) as excel_file: df1 = pd.read_excel(excel_file, sheet_name=sheet1) - assert df1[df1.columns[0]].to_list() == data1['col1'] - assert df1[df1.columns[1]].to_list() == data1['col2'] + assert df1[df1.columns[0]].to_list() == data1["col1"] + assert df1[df1.columns[1]].to_list() == data1["col2"] df2 = pd.read_excel(excel_file, sheet_name=sheet2) - assert df2[df2.columns[0]].to_list() == data2['col1'] - assert df2[df2.columns[1]].to_list() == data2['col2'] + assert df2[df2.columns[0]].to_list() == data2["col1"] + assert df2[df2.columns[1]].to_list() == data2["col2"] diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index a979b77d70a285..2dc51252f00e72 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None: private_rsa_key = RSA.import_key(private_key) private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key) - raw_text = 'raw_text' + raw_text = "raw_text" raw_text_bytes = raw_text.encode() # RSA encryption by public key and decryption by private key diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py index 75a534412673b1..b9aee4af5f31c7 100644 --- a/api/tests/unit_tests/libs/test_yarl.py +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -3,21 +3,21 @@ def test_yarl_urls(): - expected_1 = 'https://dify.ai/api' - assert str(URL('https://dify.ai') / 'api') == expected_1 - assert str(URL('https://dify.ai/') / 'api') == expected_1 + expected_1 = "https://dify.ai/api" + assert str(URL("https://dify.ai") / "api") == expected_1 + assert str(URL("https://dify.ai/") / "api") == expected_1 - expected_2 = 'http://dify.ai:12345/api' - assert str(URL('http://dify.ai:12345') / 'api') == expected_2 - assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + expected_2 = "http://dify.ai:12345/api" + assert str(URL("http://dify.ai:12345") / "api") == expected_2 + assert str(URL("http://dify.ai:12345/") / "api") == expected_2 - expected_3 = 'https://dify.ai/api/v1' - assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 - assert str(URL('https://dify.ai') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/api') / 'v1') == expected_3 - assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + expected_3 = "https://dify.ai/api/v1" + assert str(URL("https://dify.ai") / "api" / "v1") == expected_3 + assert str(URL("https://dify.ai") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/api") / "v1") == expected_3 + assert str(URL("https://dify.ai/api/") / "v1") == expected_3 with pytest.raises(ValueError) as e1: - str(URL('https://dify.ai') / '/api') + str(URL("https://dify.ai") / "/api") assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py index 006b99fb7d0935..026912ffbed300 100644 --- a/api/tests/unit_tests/models/test_account.py +++ b/api/tests/unit_tests/models/test_account.py @@ -2,13 +2,13 @@ def test_account_is_privileged_role() -> None: - assert TenantAccountRole.ADMIN == 'admin' - assert TenantAccountRole.OWNER == 'owner' - assert TenantAccountRole.EDITOR == 'editor' - assert TenantAccountRole.NORMAL == 'normal' + assert TenantAccountRole.ADMIN == "admin" + assert TenantAccountRole.OWNER == "owner" + assert TenantAccountRole.EDITOR == "editor" + assert TenantAccountRole.NORMAL == "normal" assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN) assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR) - assert not TenantAccountRole.is_privileged_role('') + assert not TenantAccountRole.is_privileged_role("") diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 9e16010d7ef5a4..7968347decbdda 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -7,19 +7,19 @@ def test_from_variable_and_to_variable(): variable = factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'name': 'name', - 'value_type': SegmentType.OBJECT, - 'value': { - 'key': { - 'key': 'value', + "id": str(uuid4()), + "name": "name", + "value_type": SegmentType.OBJECT, + "value": { + "key": { + "key": "value", } }, } ) conversation_variable = ConversationVariable.from_variable( - app_id='app_id', conversation_id='conversation_id', variable=variable + app_id="app_id", conversation_id="conversation_id", variable=variable ) assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index bea896b83a84fd..40483d7e3a3baa 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -8,30 +8,30 @@ def test_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -42,30 +42,30 @@ def test_environment_variables(): def test_update_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): variables = [variable1, variable2, variable3, variable4] @@ -76,28 +76,28 @@ def test_update_environment_variables(): # Update the name of variable3 and keep the value as it is variables[2] = variable3.model_copy( update={ - 'name': 'new name', - 'value': HIDDEN_VALUE, + "name": "new name", + "value": HIDDEN_VALUE, } ) workflow.environment_variables = variables - assert workflow.environment_variables[2].name == 'new name' + assert workflow.environment_variables[2].name == "new name" assert workflow.environment_variables[2].value == variable3.value def test_to_dict(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) @@ -105,19 +105,19 @@ def test_to_dict(): # Create some EnvironmentVariable instances with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [ - SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}), - StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}), + SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}), + StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}), ] workflow_dict = workflow.to_dict() - assert workflow_dict['environment_variables'][0]['value'] == '' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "" + assert workflow_dict["environment_variables"][1]["value"] == "text" workflow_dict = workflow.to_dict(include_secret=True) - assert workflow_dict['environment_variables'][0]['value'] == 'secret' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "secret" + assert workflow_dict["environment_variables"][1]["value"] == "text" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index a45423bf3988d6..805d92dfc93c1c 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -83,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -105,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -153,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -175,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -207,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["sys", "query"] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -251,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -293,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -308,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -316,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): @@ -337,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -352,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -360,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): @@ -381,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -396,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ]) + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -409,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], list) assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) template = prompt_template.advanced_chat_prompt_template.messages[0].text for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): @@ -431,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -448,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var prompt_type=PromptTemplateEntity.PromptType.ADVANCED, advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" - "Human: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) - ) + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -461,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], dict) template = prompt_template.advanced_completion_prompt_template.prompt for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 1235e559c93ac9..29558a93c242a8 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -8,8 +8,9 @@ @pytest.fixture def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions.yaml").write_text( + dedent( + """\ - first - second # - commented @@ -17,57 +18,54 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: - 9999999999999 - forth - """)) + """ + ) + ) return str(tmp_path) @pytest.fixture def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions_all_commented.yaml").write_text( + dedent( + """\ # - commented1 # - commented2 - - - """)) + """ + ) + ) return str(tmp_path) def test_position_helper(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml') + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") assert len(position_map) == 4 assert position_map == { - 'first': 0, - 'second': 1, - 'third': 2, - 'forth': 3, + "first": 0, + "second": 1, + "third": 2, + "forth": 3, } def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): position_map = get_position_map( - folder_path=prepare_empty_commented_positions_yaml, - file_name='example_positions_all_commented.yaml') + folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml" + ) assert position_map == {} def test_excluded_position_data(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml' - ) - pin_list = ['forth', 'first'] + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] include_set = set() - exclude_set = {'9999999999999'} + exclude_set = {"9999999999999"} - position_map = pin_position_map( - original_position_map=position_map, - pin_list=pin_list - ) + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) data = [ "forth", @@ -90,22 +88,16 @@ def test_excluded_position_data(prepare_example_positions_yaml): ) # assert the result in the correct order - assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2'] + assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"] def test_included_position_data(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml' - ) - pin_list = ['forth', 'first'] - include_set = {'forth', 'first'} + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = {"forth", "first"} exclude_set = {} - position_map = pin_position_map( - original_position_map=position_map, - pin_list=pin_list - ) + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) data = [ "forth", @@ -128,4 +120,4 @@ def test_included_position_data(prepare_example_positions_yaml): ) # assert the result in the correct order - assert sorted_data == ['forth', 'first'] + assert sorted_data == ["forth", "first"] diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index c0452b4e4d803a..95b93651d57f80 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -5,17 +5,18 @@ from core.tools.utils.yaml_utils import load_yaml_file -EXAMPLE_YAML_FILE = 'example_yaml.yaml' -INVALID_YAML_FILE = 'invalid_yaml.yaml' -NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' +EXAMPLE_YAML_FILE = "example_yaml.yaml" +INVALID_YAML_FILE = "invalid_yaml.yaml" +NON_EXISTING_YAML_FILE = "non_existing_file.yaml" @pytest.fixture def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: - Java - C++ empty_key: - """)) + """ + ) + ) return str(file_path) @@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(INVALID_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: - Python - Java - C++ - """)) + """ + ) + ) return str(file_path) def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path='') == {} + assert load_yaml_file(file_path="") == {} with pytest.raises(FileNotFoundError): load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) @@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file(): def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 - assert yaml_data['age'] == 30 - assert yaml_data['gender'] == 'male' - assert yaml_data['address']['city'] == 'Example City' - assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'} - assert yaml_data.get('empty_key') is None - assert yaml_data.get('non_existed_key') is None + assert yaml_data["age"] == 30 + assert yaml_data["gender"] == "male" + assert yaml_data["address"]["city"] == "Example City" + assert set(yaml_data["languages"]) == {"Python", "Java", "C++"} + assert yaml_data.get("empty_key") is None + assert yaml_data.get("non_existed_key") is None def test_load_invalid_yaml_file(prepare_invalid_yaml_file):