diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 1bdfd835d1e2..fff480120337 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -52,7 +52,7 @@ class OpenAIWrapper: """A wrapper class for openai client.""" cache_path_root: str = ".cache" - extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"} + extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"} openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) total_usage_summary: Optional[Dict[str, Any]] = None actual_usage_summary: Optional[Dict[str, Any]] = None diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index 3927ee3691cc..66332e4f909f 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -356,6 +356,11 @@ def filter_config(config_list, filter_dict): filter_dict (dict): A dictionary representing the filter criteria, where each key is a field name to check within the configuration dictionaries, and the corresponding value is a list of acceptable values for that field. + If the configuration's field's value is not a list, then a match occurs + when it is found in the list of acceptable values. If the configuration's + field's value is a list, then a match occurs if there is a non-empty + intersection with the acceptable values. + Returns: list of dict: A list of configuration dictionaries that meet all the criteria specified @@ -368,6 +373,7 @@ def filter_config(config_list, filter_dict): {'model': 'gpt-3.5-turbo'}, {'model': 'gpt-4'}, {'model': 'gpt-3.5-turbo', 'api_type': 'azure'}, + {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}, ] # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model @@ -382,6 +388,19 @@ def filter_config(config_list, filter_dict): # The resulting `filtered_configs` will be: # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}] + + + # Define a filter to select a given tag + filter_criteria = { + 'tags': ['gpt35_turbo'], + } + + # Apply the filter to the configuration list + filtered_configs = filter_config(configs, filter_criteria) + + # The resulting `filtered_configs` will be: + # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}] + ``` Note: @@ -391,9 +410,18 @@ def filter_config(config_list, filter_dict): - If the list of acceptable values for a key in `filter_dict` includes None, then configuration dictionaries that do not have that key will also be considered a match. """ + + def _satisfies(config_value, acceptable_values): + if isinstance(config_value, list): + return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection + else: + return config_value in acceptable_values + if filter_dict: config_list = [ - config for config in config_list if all(config.get(key) in value for key, value in filter_dict.items()) + config + for config in config_list + if all(_satisfies(config.get(key), value) for key, value in filter_dict.items()) ] return config_list diff --git a/test/oai/test_utils.py b/test/oai/test_utils.py index ab8d2544f717..57a70c0ffee3 100644 --- a/test/oai/test_utils.py +++ b/test/oai/test_utils.py @@ -8,7 +8,7 @@ import pytest import autogen # noqa: E402 -from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION +from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config # Example environment variables ENV_VARS = { @@ -48,6 +48,7 @@ }, { "model": "gpt-35-turbo-v0301", + "tags": ["gpt-3.5-turbo", "gpt35_turbo"], "api_key": "111113fc7e8a46419bfac511bb301111", "base_url": "https://1111.openai.azure.com", "api_type": "azure", @@ -342,5 +343,32 @@ def test_get_config_list(): assert len(config_list_with_empty_key) == 2, "The config_list should exclude configurations with empty api_keys." +def test_tags(): + config_list = json.loads(JSON_SAMPLE) + + target_list = filter_config(config_list, {"model": ["gpt-35-turbo-v0301"]}) + assert len(target_list) == 1 + + list_1 = filter_config(config_list, {"tags": ["gpt35_turbo"]}) + assert len(list_1) == 1 + assert list_1[0] == target_list[0] + + list_2 = filter_config(config_list, {"tags": ["gpt-3.5-turbo"]}) + assert len(list_2) == 1 + assert list_2[0] == target_list[0] + + list_3 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "gpt35_turbo"]}) + assert len(list_3) == 1 + assert list_3[0] == target_list[0] + + # Will still match because there's a non-empty intersection + list_4 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "does_not_exist"]}) + assert len(list_4) == 1 + assert list_4[0] == target_list[0] + + list_5 = filter_config(config_list, {"tags": ["does_not_exist"]}) + assert len(list_5) == 0 + + if __name__ == "__main__": pytest.main()