Merge pull request BerriAI#3770 from BerriAI/litellm_filter_invalid_params

feat(router.py): filter out deployments which don't support request params w/ 'pre_call_checks=True'
krrishdholakia authored May 22, 2024
2 parents 9081956 + beb6170 commit 5d7d638
Showing 4 changed files with 165 additions and 55 deletions.
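
As a usage sketch of the new behavior (not part of the diff, and assuming OPENAI_API_KEY is set): with enable_pre_call_checks=True on the Router and litellm.drop_params left at its default of False, a request that passes response_format should only be routed to deployments whose model supports that parameter. The two-deployment group below mirrors the new test in test_router.py.

import os

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # supports response_format
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",  # gpt-3.5-turbo-16k does not support response_format
        "litellm_params": {
            "model": "gpt-3.5-turbo-16k",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(model_list=model_list, enable_pre_call_checks=True)

# With the new check, the 16k deployment is filtered out before the call is made,
# because it does not support response_format and drop_params is False.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Reply with a JSON greeting"}],
    response_format={"type": "json_object"},
)
print(response)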
63 changes: 40 additions & 23 deletions litellm/router.py
@@ -376,15 +376,15 @@ def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: di
self.lowesttpm_logger = LowestTPMLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "usage-based-routing-v2":
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
@@ -3207,7 +3207,7 @@ def _pre_call_checks(
model: str,
healthy_deployments: List,
messages: List[Dict[str, str]],
allowed_model_region: Optional[Literal["eu"]] = None,
request_kwargs: Optional[dict] = None,
):
"""
Filter out model in model group, if:
@@ -3299,7 +3299,11 @@ def _pre_call_checks(
continue

## REGION CHECK ##
if allowed_model_region is not None:
if (
request_kwargs is not None
and request_kwargs.get("allowed_model_region") is not None
and request_kwargs["allowed_model_region"] == "eu"
):
if _litellm_params.get("region_name") is not None and isinstance(
_litellm_params["region_name"], str
):
@@ -3313,13 +3317,37 @@ def _pre_call_checks(
else:
verbose_router_logger.debug(
"Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
model_id, allowed_model_region
model_id, request_kwargs.get("allowed_model_region")
)
)
# filter out since region unknown, and user wants to filter for specific region
invalid_model_indices.append(idx)
continue

## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
if request_kwargs is not None and litellm.drop_params == False:
# get supported params
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, litellm_params=LiteLLM_Params(**_litellm_params)
)

supported_openai_params = litellm.get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)

if supported_openai_params is None:
continue
else:
# check the non-default openai params in request kwargs
non_default_params = litellm.utils.get_non_default_params(
passed_params=request_kwargs
)
# check if all params are supported
for k, v in non_default_params.items():
if k not in supported_openai_params:
# if not -> invalid model
invalid_model_indices.append(idx)

if len(invalid_model_indices) == len(_returned_deployments):
"""
- no healthy deployments available b/c context window checks or rate limit error
@@ -3469,25 +3497,14 @@ async def async_get_available_deployment(
if request_kwargs is not None
else None
)

if self.enable_pre_call_checks and messages is not None:
if _allowed_model_region == "eu":
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
allowed_model_region=_allowed_model_region,
)
else:
verbose_router_logger.debug(
"Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
_allowed_model_region
)
)
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
)
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
request_kwargs=request_kwargs,
)

if len(healthy_deployments) == 0:
if _allowed_model_region is None:
40 changes: 39 additions & 1 deletion litellm/tests/test_router.py
@@ -689,6 +689,44 @@ def test_router_context_window_check_pre_call_check_out_group():
pytest.fail(f"Got unexpected exception on router! - {str(e)}")


def test_filter_invalid_params_pre_call_check():
"""
- gpt-3.5-turbo supports 'response_format'
- gpt-3.5-turbo-16k doesn't support 'response_format'
run pre-call check -> assert returned list doesn't include gpt-3.5-turbo-16k
"""
try:
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]

router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True, num_retries=0) # type: ignore

filtered_deployments = router._pre_call_checks(
model="gpt-3.5-turbo",
healthy_deployments=model_list,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
request_kwargs={"response_format": {"type": "json_object"}},
)
assert len(filtered_deployments) == 1
except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {str(e)}")


@pytest.mark.parametrize("allowed_model_region", ["eu", None])
def test_router_region_pre_call_check(allowed_model_region):
"""
@@ -724,7 +762,7 @@ def test_router_region_pre_call_check(allowed_model_region):
model="gpt-3.5-turbo",
healthy_deployments=model_list,
messages=[{"role": "user", "content": "Hey!"}],
allowed_model_region=allowed_model_region,
request_kwargs={"allowed_model_region": allowed_model_region},
)

if allowed_model_region is None:
66 changes: 37 additions & 29 deletions litellm/tests/test_streaming.py
@@ -1730,35 +1730,43 @@ def test_openai_stream_options_call():

def test_openai_stream_options_call_text_completion():
litellm.set_verbose = False
response = litellm.text_completion(
model="gpt-3.5-turbo-instruct",
prompt="say GM - we're going to make it ",
stream=True,
stream_options={"include_usage": True},
max_tokens=10,
)
usage = None
chunks = []
for chunk in response:
print("chunk: ", chunk)
chunks.append(chunk)

last_chunk = chunks[-1]
print("last chunk: ", last_chunk)

"""
Assert that:
- Last Chunk includes Usage
- All chunks prior to last chunk have usage=None
"""

assert last_chunk.usage is not None
assert last_chunk.usage.total_tokens > 0
assert last_chunk.usage.prompt_tokens > 0
assert last_chunk.usage.completion_tokens > 0

# assert all non last chunks have usage=None
assert all(chunk.usage is None for chunk in chunks[:-1])
for idx in range(3):
try:
response = litellm.text_completion(
model="gpt-3.5-turbo-instruct",
prompt="say GM - we're going to make it ",
stream=True,
stream_options={"include_usage": True},
max_tokens=10,
)
usage = None
chunks = []
for chunk in response:
print("chunk: ", chunk)
chunks.append(chunk)

last_chunk = chunks[-1]
print("last chunk: ", last_chunk)

"""
Assert that:
- Last Chunk includes Usage
- All chunks prior to last chunk have usage=None
"""

assert last_chunk.usage is not None
assert last_chunk.usage.total_tokens > 0
assert last_chunk.usage.prompt_tokens > 0
assert last_chunk.usage.completion_tokens > 0

# assert all non last chunks have usage=None
assert all(chunk.usage is None for chunk in chunks[:-1])
break
except Exception as e:
if idx < 2:
pass
else:
raise e


def test_openai_text_completion_call():
51 changes: 49 additions & 2 deletions litellm/utils.py
@@ -5811,7 +5811,7 @@ def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
]:
supported_params += [
supported_params += [ # type: ignore
"functions",
"function_call",
"tools",
@@ -6061,6 +6061,47 @@ def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
return optional_params


def get_non_default_params(passed_params: dict) -> dict:
default_params = {
"functions": None,
"function_call": None,
"temperature": None,
"top_p": None,
"n": None,
"stream": None,
"stream_options": None,
"stop": None,
"max_tokens": None,
"presence_penalty": None,
"frequency_penalty": None,
"logit_bias": None,
"user": None,
"model": None,
"custom_llm_provider": "",
"response_format": None,
"seed": None,
"tools": None,
"tool_choice": None,
"max_retries": None,
"logprobs": None,
"top_logprobs": None,
"extra_headers": None,
}
# filter out those parameters that were passed with non-default values
non_default_params = {
k: v
for k, v in passed_params.items()
if (
k != "model"
and k != "custom_llm_provider"
and k in default_params
and v != default_params[k]
)
}

return non_default_params
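
For reference, a small illustration of the intended filtering (hypothetical input values, not part of the diff): only recognized OpenAI params passed with non-default values are returned, while model, custom_llm_provider, and unrecognized keys are dropped.

from litellm.utils import get_non_default_params

passed_params = {
    "model": "gpt-3.5-turbo-16k",                 # always excluded
    "temperature": 0.2,                           # non-default -> kept
    "response_format": {"type": "json_object"},   # non-default -> kept
    "stream": None,                               # default value -> dropped
    "metadata": {"run": "demo"},                  # not an OpenAI param -> dropped
}
print(get_non_default_params(passed_params=passed_params))
# expected: {'temperature': 0.2, 'response_format': {'type': 'json_object'}}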


def calculate_max_parallel_requests(
max_parallel_requests: Optional[int],
rpm: Optional[int],
@@ -6287,14 +6328,18 @@ def get_first_chars_messages(kwargs: dict) -> str:
return ""


def get_supported_openai_params(model: str, custom_llm_provider: str):
def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]:
"""
Returns the supported openai params for a given model + provider

Example:
```
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
```

Returns:
- List if custom_llm_provider is mapped
- None if unmapped
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
@@ -6534,6 +6579,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXAIConfig().get_supported_openai_params()

return None
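
Since get_supported_openai_params can now return None, callers need to handle the unmapped case; a minimal sketch (hypothetical values, mirroring how the router uses it above):

import litellm

supported = litellm.get_supported_openai_params(
    model="gpt-3.5-turbo-16k", custom_llm_provider="openai"
)
if supported is None:
    # provider/model is unmapped -> skip the param check rather than filter the deployment
    pass
elif "response_format" not in supported:
    # the deployment would be filtered out for a request that passes response_format
    print("response_format not supported by this deployment")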


def get_formatted_prompt(
data: dict,
