From aa97ef04e8fcb34d2662fdf99a3ac29dc823e34a Mon Sep 17 00:00:00 2001 From: walter from vm Date: Thu, 29 Aug 2024 02:21:50 +0000 Subject: [PATCH] format --- .../oci/llm/meta.llama-3-70b-instruct.yaml | 2 +- .../model_providers/oci/oci.yaml | 2 +- .../oci/text_embedding/_position.yaml | 2 +- .../model_runtime/oci/test_llm.py | 110 ++++++------------ .../model_runtime/oci/test_provider.py | 8 +- .../model_runtime/oci/test_text_embedding.py | 43 +++---- 6 files changed, 59 insertions(+), 108 deletions(-) diff --git a/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml index 8a031ba3faa1d4..dd5be107c07570 100644 --- a/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml @@ -48,4 +48,4 @@ pricing: input: '0.015' output: '0.015' unit: '0.0001' - currency: USD \ No newline at end of file + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/oci.yaml b/api/core/model_runtime/model_providers/oci/oci.yaml index 977f3ffeebccb6..f2f23e18f12073 100644 --- a/api/core/model_runtime/model_providers/oci/oci.yaml +++ b/api/core/model_runtime/model_providers/oci/oci.yaml @@ -39,4 +39,4 @@ provider_credential_schema: required: true placeholder: zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8'))) - en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')) ) \ No newline at end of file + en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8'))) diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml index 56fd741e4eeb70..149f1e3797850f 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml +++ b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml @@ -2,4 +2,4 @@ - cohere.embed-english-light-v3.0 - cohere.embed-english-v3.0 - cohere.embed-multilingual-light-v3.0 -- cohere.embed-multilingual-v3.0 \ No newline at end of file +- cohere.embed-multilingual-v3.0 diff --git a/api/tests/integration_tests/model_runtime/oci/test_llm.py b/api/tests/integration_tests/model_runtime/oci/test_llm.py index be56789ace9826..531f26a32e657c 100644 --- a/api/tests/integration_tests/model_runtime/oci/test_llm.py +++ b/api/tests/integration_tests/model_runtime/oci/test_llm.py @@ -20,19 +20,16 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='cohere.command-r-plus', - credentials={ - 'oci_config_content': 'invalid_key', - 'oci_key_content': 'invalid_key' - } + model="cohere.command-r-plus", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, ) model.validate_credentials( - model='cohere.command-r-plus', + model="cohere.command-r-plus", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') - } + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, ) @@ -40,23 +37,16 @@ def test_invoke_model(): model = OCILargeLanguageModel() response = model.invoke( - model='cohere.command-r-plus', + model="cohere.command-r-plus", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -67,23 +57,15 @@ def test_invoke_stream_model(): model = OCILargeLanguageModel() response = model.invoke( - model='meta.llama-3-70b-instruct', + model="meta.llama-3-70b-instruct", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -99,45 +81,29 @@ def test_invoke_model_with_function(): model = OCILargeLanguageModel() response = model.invoke( - model='cohere.command-r-plus', + model="cohere.command-r-plus", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=False, user="abc-123", tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, LLMResult) @@ -148,19 +114,17 @@ def test_get_num_tokens(): model = OCILargeLanguageModel() num_tokens = model.get_num_tokens( - model='cohere.command-r-plus', + model="cohere.command-r-plus", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/oci/test_provider.py b/api/tests/integration_tests/model_runtime/oci/test_provider.py index 657585227d57e0..2c7107c7ccfe45 100644 --- a/api/tests/integration_tests/model_runtime/oci/test_provider.py +++ b/api/tests/integration_tests/model_runtime/oci/test_provider.py @@ -10,13 +10,11 @@ def test_validate_provider_credentials(): provider = OCIGENAIProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), } ) diff --git a/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py index 8487fc1eb23283..032c5c681a7aeb 100644 --- a/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/oci/test_text_embedding.py @@ -12,19 +12,16 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='cohere.embed-multilingual-v3.0', - credentials={ - 'oci_config_content': 'invalid_key', - 'oci_key_content': 'invalid_key' - } + model="cohere.embed-multilingual-v3.0", + credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"}, ) model.validate_credentials( - model='cohere.embed-multilingual-v3.0', + model="cohere.embed-multilingual-v3.0", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') - } + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), + }, ) @@ -32,38 +29,30 @@ def test_invoke_model(): model = OCITextEmbeddingModel() result = model.invoke( - model='cohere.embed-multilingual-v3.0', + model="cohere.embed-multilingual-v3.0", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 4 - #assert result.usage.total_tokens == 811 + # assert result.usage.total_tokens == 811 def test_get_num_tokens(): model = OCITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='cohere.embed-multilingual-v3.0', + model="cohere.embed-multilingual-v3.0", credentials={ - 'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'), - 'oci_key_content': os.environ.get('OCI_KEY_CONTENT') + "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"), + "oci_key_content": os.environ.get("OCI_KEY_CONTENT"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2