Skip to content

Commit

Permalink
Added the ability to add tags to the OAI_CONFIG_LIST, and filter (mic…
Browse files Browse the repository at this point in the history
…rosoft#1226)

* Added the ability to add tags to the OAI_CONFIG_LIST, and filter on them.

* Update openai_utils.py

Co-authored-by: Chi Wang <[email protected]>

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
afourney and sonichi authored Jan 15, 2024
1 parent 63a35e7 commit e6325a4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
2 changes: 1 addition & 1 deletion autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class OpenAIWrapper:
"""A wrapper class for openai client."""

cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None
Expand Down
30 changes: 29 additions & 1 deletion autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ def filter_config(config_list, filter_dict):
filter_dict (dict): A dictionary representing the filter criteria, where each key is a
field name to check within the configuration dictionaries, and the
corresponding value is a list of acceptable values for that field.
If the configuration's field's value is not a list, then a match occurs
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
Expand All @@ -368,6 +373,7 @@ def filter_config(config_list, filter_dict):
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
Expand All @@ -382,6 +388,19 @@ def filter_config(config_list, filter_dict):
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
Note:
Expand All @@ -391,9 +410,18 @@ def filter_config(config_list, filter_dict):
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
"""

def _satisfies(config_value, acceptable_values):
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values

if filter_dict:
config_list = [
config for config in config_list if all(config.get(key) in value for key, value in filter_dict.items())
config
for config in config_list
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
]
return config_list

Expand Down
30 changes: 29 additions & 1 deletion test/oai/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import autogen # noqa: E402
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config

# Example environment variables
ENV_VARS = {
Expand Down Expand Up @@ -48,6 +48,7 @@
},
{
"model": "gpt-35-turbo-v0301",
"tags": ["gpt-3.5-turbo", "gpt35_turbo"],
"api_key": "111113fc7e8a46419bfac511bb301111",
"base_url": "https://1111.openai.azure.com",
"api_type": "azure",
Expand Down Expand Up @@ -342,5 +343,32 @@ def test_get_config_list():
assert len(config_list_with_empty_key) == 2, "The config_list should exclude configurations with empty api_keys."


def test_tags():
config_list = json.loads(JSON_SAMPLE)

target_list = filter_config(config_list, {"model": ["gpt-35-turbo-v0301"]})
assert len(target_list) == 1

list_1 = filter_config(config_list, {"tags": ["gpt35_turbo"]})
assert len(list_1) == 1
assert list_1[0] == target_list[0]

list_2 = filter_config(config_list, {"tags": ["gpt-3.5-turbo"]})
assert len(list_2) == 1
assert list_2[0] == target_list[0]

list_3 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "gpt35_turbo"]})
assert len(list_3) == 1
assert list_3[0] == target_list[0]

# Will still match because there's a non-empty intersection
list_4 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "does_not_exist"]})
assert len(list_4) == 1
assert list_4[0] == target_list[0]

list_5 = filter_config(config_list, {"tags": ["does_not_exist"]})
assert len(list_5) == 0


if __name__ == "__main__":
pytest.main()

0 comments on commit e6325a4

Please sign in to comment.