Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Walter-jin committed Aug 29, 2024
1 parent 9f9c4f2 commit abc7e81
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ pricing:
input: '0.015'
output: '0.015'
unit: '0.0001'
currency: USD
currency: USD
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/oci/oci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ provider_credential_schema:
required: true
placeholder:
zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')) )
en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')))
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0
- cohere.embed-multilingual-v3.0
110 changes: 37 additions & 73 deletions api/tests/integration_tests/model_runtime/oci/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,33 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='cohere.command-r-plus',
credentials={
'oci_config_content': 'invalid_key',
'oci_key_content': 'invalid_key'
}
model="cohere.command-r-plus",
credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
)

model.validate_credentials(
model='cohere.command-r-plus',
model="cohere.command-r-plus",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
}
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
)


def test_invoke_model():
model = OCILargeLanguageModel()

response = model.invoke(
model='cohere.command-r-plus',
model="cohere.command-r-plus",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
stop=['How'],
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
Expand All @@ -67,23 +57,15 @@ def test_invoke_stream_model():
model = OCILargeLanguageModel()

response = model.invoke(
model='meta.llama-3-70b-instruct',
model="meta.llama-3-70b-instruct",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
Expand All @@ -99,45 +81,29 @@ def test_invoke_model_with_function():
model = OCILargeLanguageModel()

response = model.invoke(
model='cohere.command-r-plus',
model="cohere.command-r-plus",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=False,
user="abc-123",
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)

assert isinstance(response, LLMResult)
Expand All @@ -148,19 +114,17 @@ def test_get_num_tokens():
model = OCILargeLanguageModel()

num_tokens = model.get_num_tokens(
model='cohere.command-r-plus',
model="cohere.command-r-plus",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 18
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ def test_validate_provider_credentials():
provider = OCIGENAIProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,58 +12,47 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='cohere.embed-multilingual-v3.0',
credentials={
'oci_config_content': 'invalid_key',
'oci_key_content': 'invalid_key'
}
model="cohere.embed-multilingual-v3.0",
credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
)

model.validate_credentials(
model='cohere.embed-multilingual-v3.0',
model="cohere.embed-multilingual-v3.0",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
}
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
)


def test_invoke_model():
model = OCITextEmbeddingModel()

result = model.invoke(
model='cohere.embed-multilingual-v3.0',
model="cohere.embed-multilingual-v3.0",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 4
#assert result.usage.total_tokens == 811
# assert result.usage.total_tokens == 811


def test_get_num_tokens():
model = OCITextEmbeddingModel()

num_tokens = model.get_num_tokens(
model='cohere.embed-multilingual-v3.0',
model="cohere.embed-multilingual-v3.0",
credentials={
'oci_config_content': os.environ.get('OCI_CONFIG_CONTENT'),
'oci_key_content': os.environ.get('OCI_KEY_CONTENT')
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2

0 comments on commit abc7e81

Please sign in to comment.