Merge branch 'main' into litellm_filter_invalid_params
krrishdholakia authored May 22, 2024
2 parents 0001b32 + 9081956 commit beb6170
Showing 21 changed files with 649 additions and 343 deletions.
13 changes: 6 additions & 7 deletions docs/my-website/docs/completion/input.md
@@ -41,25 +41,24 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea

| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|Anthropic|||||| | | | | | | | | |||
|Anthropic|||||| | | | | | | |||||
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅
|OpenAI||||||||||||||||||||
|Azure OpenAI||||||||||||||||| | ||
|Replicate |||||| | | | | |
|Anyscale |||||
|Cohere||||||||| | |
|Huggingface||||||| | | | |
|Openrouter|||||||||||
|AI21||||||||| | |
|VertexAI||| || | | | | | |
|Bedrock|||||| | | | | |
|Openrouter||||||||||| | | | || | | | |
|AI21||||||||| | |
|VertexAI||| || | | | | | | | | | || | |
|Bedrock|||||| | | | | | | | | | ✅ (for anthropic) | |
|Sagemaker||||||| | | | |
|TogetherAI|||||| | | | | ||
|AlephAlpha||||||| | | | |
|Palm||||||| | | | |
|NLP Cloud|||||| | | | | |
|Petals||| || | | | | | |
|Ollama|||||| | || | |
|Ollama|||||| | || | | | || | |
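
For reference, the matrix above can be queried at runtime. A minimal sketch, not part of this diff, assuming `litellm` is installed and `get_supported_openai_params` accepts the keyword arguments shown here:

```python
import litellm

# List the OpenAI-compatible params LiteLLM maps for a given provider/model.
supported = litellm.get_supported_openai_params(
    model="claude-3-opus-20240229",
    custom_llm_provider="anthropic",
)
print(supported)  # expected to include "tools" and, after this change, "tool_choice"
```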

:::note

1 change: 1 addition & 0 deletions docs/my-website/docs/enterprise.md
@@ -15,6 +15,7 @@ This covers:
- **Custom SLAs**
- [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
- [**JWT-Auth**](../docs/proxy/token_auth.md)
- [**Invite Team Members to access `/spend` Routes**](../docs/proxy/cost_tracking#allowing-non-proxy-admins-to-access-spend-endpoints)


## [COMING SOON] AWS Marketplace Support
30 changes: 30 additions & 0 deletions docs/my-website/docs/proxy/cost_tracking.md
@@ -125,6 +125,36 @@ Output from script

</Tabs>

## Allowing Non-Proxy Admins to access `/spend` endpoints

Use this when you want non-proxy admins to access `/spend` endpoints.

:::info

Schedule a [meeting with us to get your Enterprise License](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

### Create Key
Create a key with `permissions={"get_spend_routes": true}`:
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"permissions": {"get_spend_routes": true}
}'
```

### Use the generated key on `/spend` endpoints

Access the spend routes with the newly generated key:
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30' \
-H 'Authorization: Bearer sk-H16BKvrSNConSsBYLGc_7A'
```
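
The same two steps in Python, as a rough sketch (assumes the proxy URL and master key from the curl examples above, and that `/key/generate` returns the new key under a `key` field):

```python
import requests

PROXY_URL = "http://localhost:4000"  # placeholder, matches the curl examples
MASTER_KEY = "sk-1234"               # placeholder master key

# 1. Generate a key that is allowed to call the /spend routes.
resp = requests.post(
    f"{PROXY_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"permissions": {"get_spend_routes": True}},
)
spend_key = resp.json()["key"]

# 2. Pull a spend report with the generated key.
report = requests.get(
    f"{PROXY_URL}/global/spend/report",
    headers={"Authorization": f"Bearer {spend_key}"},
    params={"start_date": "2024-04-01", "end_date": "2024-06-30"},
)
print(report.json())
```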



## Reset Team, API Key Spend - MASTER KEY ONLY

1 change: 0 additions & 1 deletion litellm/integrations/langfuse.py
@@ -421,7 +421,6 @@ def _log_langfuse_v2(
if "cache_hit" in kwargs:
if kwargs["cache_hit"] is None:
kwargs["cache_hit"] = False
tags.append(f"cache_hit:{kwargs['cache_hit']}")
clean_metadata["cache_hit"] = kwargs["cache_hit"]
if existing_trace_id is None:
trace_params.update({"tags": tags})
12 changes: 12 additions & 0 deletions litellm/llms/anthropic.py
@@ -10,6 +10,7 @@
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice


class AnthropicConstants(Enum):
@@ -102,6 +103,17 @@ def map_openai_params(self, non_default_params: dict, optional_params: dict):
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}

if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
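
For context, the new `tool_choice` branch translates OpenAI-style values into Anthropic's format. A standalone sketch of that mapping (the helper name is illustrative, not a litellm API):

```python
from typing import Optional, Union

def map_openai_tool_choice_to_anthropic(value: Union[str, dict]) -> Optional[dict]:
    """Mirror of the mapping added above, as plain Python."""
    if value == "auto":
        return {"type": "auto"}
    if value == "required":
        return {"type": "any"}
    if isinstance(value, dict):
        # OpenAI shape: {"type": "function", "function": {"name": "..."}}
        return {"type": "tool", "name": value["function"]["name"]}
    return None  # anything else (e.g. "none") is left unset

print(map_openai_tool_choice_to_anthropic(
    {"type": "function", "function": {"name": "get_current_weather"}}
))
# -> {'type': 'tool', 'name': 'get_current_weather'}
```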
1 change: 1 addition & 0 deletions litellm/llms/cohere.py
@@ -117,6 +117,7 @@ def get_config(cls):

def validate_environment(api_key):
headers = {
"Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
1 change: 1 addition & 0 deletions litellm/llms/cohere_chat.py
@@ -112,6 +112,7 @@ def get_config(cls):

def validate_environment(api_key):
headers = {
"Request-Source":"unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
2 changes: 1 addition & 1 deletion litellm/main.py
@@ -486,7 +486,7 @@ def completion(
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -157,6 +157,7 @@ class LiteLLMRoutes(enum.Enum):
"/global/spend/end_users",
"/global/spend/models",
"/global/predict/spend/logs",
"/global/spend/report",
]

public_routes: List = [
17 changes: 16 additions & 1 deletion litellm/proxy/proxy_server.py
@@ -1211,6 +1211,13 @@ async def user_api_key_auth(
_has_user_setup_sso()
and route in LiteLLMRoutes.sso_only_routes.value
):
pass
elif (
route in LiteLLMRoutes.global_spend_tracking_routes.value
and getattr(valid_token, "permissions", None) is not None
and "get_spend_routes" in getattr(valid_token, "permissions", None)
):

pass
else:
user_role = "unknown"
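
In short, the new branch admits a non-admin key on spend-tracking routes when its stored permissions include `get_spend_routes`. A condensed restatement (hypothetical helper name, not part of the diff):

```python
from typing import List, Optional

def key_can_access_spend_route(
    route: str, permissions: Optional[dict], spend_routes: List[str]
) -> bool:
    # Matches the elif above: the route must be a spend-tracking route and the
    # key's permissions dict must contain the "get_spend_routes" flag.
    return (
        route in spend_routes
        and permissions is not None
        and "get_spend_routes" in permissions
    )
```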
@@ -2967,7 +2974,7 @@ async def generate_key_helper_fn(
organization_id: Optional[str] = None,
table_name: Optional[Literal["key", "user"]] = None,
):
global prisma_client, custom_db_client, user_api_key_cache, litellm_proxy_admin_name
global prisma_client, custom_db_client, user_api_key_cache, litellm_proxy_admin_name, premium_user

if prisma_client is None and custom_db_client is None:
raise Exception(
@@ -3062,6 +3069,14 @@
if isinstance(saved_token["metadata"], str):
saved_token["metadata"] = json.loads(saved_token["metadata"])
if isinstance(saved_token["permissions"], str):
if (
"get_spend_routes" in saved_token["permissions"]
and premium_user != True
):
raise Exception(
"get_spend_routes permission is only available for LiteLLM Enterprise users"
)

saved_token["permissions"] = json.loads(saved_token["permissions"])
if isinstance(saved_token["model_max_budget"], str):
saved_token["model_max_budget"] = json.loads(
83 changes: 79 additions & 4 deletions litellm/router_strategy/lowest_latency.py
@@ -83,15 +83,29 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time):
precise_minute = f"{current_date}-{current_hour}-{current_minute}"

response_ms: timedelta = end_time - start_time
time_to_first_token_response_time: Optional[timedelta] = None

if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
# only log ttft for streaming request
time_to_first_token_response_time = (
kwargs.get("completion_start_time", end_time) - start_time
)

final_value = response_ms
time_to_first_token: Optional[float] = None
total_tokens = 0

if isinstance(response_obj, ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
total_tokens = response_obj.usage.total_tokens
final_value = float(response_ms.total_seconds() / completion_tokens)

if time_to_first_token_response_time is not None:
time_to_first_token = float(
time_to_first_token_response_time.total_seconds()
/ completion_tokens
)

# ------------
# Update usage
# ------------
@@ -112,6 +126,24 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time):
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]

## Time to first token
if time_to_first_token is not None:
if (
len(request_count_dict[id].get("time_to_first_token", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault(
"time_to_first_token", []
).append(time_to_first_token)
else:
request_count_dict[id][
"time_to_first_token"
] = request_count_dict[id]["time_to_first_token"][
: self.routing_args.max_latency_list_size - 1
] + [
time_to_first_token
]

if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}

@@ -226,6 +258,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
{model_group}_map: {
id: {
"latency": [..]
"time_to_first_token": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
@@ -239,15 +272,27 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
precise_minute = f"{current_date}-{current_hour}-{current_minute}"

response_ms: timedelta = end_time - start_time
time_to_first_token_response_time: Optional[timedelta] = None
if kwargs.get("stream", None) is not None and kwargs["stream"] == True:
# only log ttft for streaming request
time_to_first_token_response_time = (
kwargs.get("completion_start_time", end_time) - start_time
)

final_value = response_ms
total_tokens = 0
time_to_first_token: Optional[float] = None

if isinstance(response_obj, ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
total_tokens = response_obj.usage.total_tokens
final_value = float(response_ms.total_seconds() / completion_tokens)

if time_to_first_token_response_time is not None:
time_to_first_token = float(
time_to_first_token_response_time.total_seconds()
/ completion_tokens
)
# ------------
# Update usage
# ------------
@@ -268,6 +313,24 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]

## Time to first token
if time_to_first_token is not None:
if (
len(request_count_dict[id].get("time_to_first_token", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault(
"time_to_first_token", []
).append(time_to_first_token)
else:
request_count_dict[id][
"time_to_first_token"
] = request_count_dict[id]["time_to_first_token"][
: self.routing_args.max_latency_list_size - 1
] + [
time_to_first_token
]

if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}

@@ -370,14 +433,25 @@ def get_available_deployments(
or float("inf")
)
item_latency = item_map.get("latency", [])
item_ttft_latency = item_map.get("time_to_first_token", [])
item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

# get average latency
# get average latency or average ttft (depending on streaming/non-streaming)
total: float = 0.0
for _call_latency in item_latency:
if isinstance(_call_latency, float):
total += _call_latency
if (
request_kwargs is not None
and request_kwargs.get("stream", None) is not None
and request_kwargs["stream"] == True
and len(item_ttft_latency) > 0
):
for _call_latency in item_ttft_latency:
if isinstance(_call_latency, float):
total += _call_latency
else:
for _call_latency in item_latency:
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total / len(item_latency)

# -------------- #
@@ -413,6 +487,7 @@

# Find deployments within buffer of lowest latency
buffer = self.routing_args.lowest_latency_buffer * lowest_latency

valid_deployments = [
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
]
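
Taken together, the changes above record a per-token time-to-first-token (TTFT) for streaming calls and prefer that signal over total latency when routing streaming requests. A rough sketch of the bookkeeping (names like `MAX_LIST_SIZE` and `record_ttft` are illustrative, not litellm APIs):

```python
from datetime import datetime
from typing import Dict, List

MAX_LIST_SIZE = 10  # stands in for routing_args.max_latency_list_size

def record_ttft(
    stats: Dict[str, dict],
    deployment_id: str,
    start_time: datetime,
    completion_start_time: datetime,
    completion_tokens: int,
) -> None:
    # Per-token TTFT, matching the diff's total_seconds() / completion_tokens.
    ttft = (completion_start_time - start_time).total_seconds() / completion_tokens
    history: List[float] = stats.setdefault(deployment_id, {}).setdefault(
        "time_to_first_token", []
    )
    if len(history) < MAX_LIST_SIZE:
        history.append(ttft)
    else:
        # Bounded list: keep the first MAX_LIST_SIZE - 1 entries, append the new value.
        stats[deployment_id]["time_to_first_token"] = history[: MAX_LIST_SIZE - 1] + [ttft]
```

`get_available_deployments` then averages `time_to_first_token` instead of `latency` when the incoming request has `stream=True` and TTFT samples exist.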
2 changes: 1 addition & 1 deletion litellm/tests/test_alangfuse.py
@@ -536,7 +536,7 @@ def test_langfuse_logging_function_calling():
# test_langfuse_logging_function_calling()


@pytest.mark.skip(reason="skip b/c langfuse changed their api")
@pytest.mark.skip(reason="Need to address this on main")
def test_aaalangfuse_existing_trace_id():
"""
When existing trace id is passed, don't set trace params -> prevents overwriting the trace
7 changes: 5 additions & 2 deletions litellm/tests/test_completion.py
@@ -278,9 +278,12 @@ def test_completion_claude_3_function_call():
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice={"type": "tool", "name": "get_weather"},
extra_headers={"anthropic-beta": "tools-2024-05-16"},
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
},
)

# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
25 changes: 12 additions & 13 deletions litellm/tests/test_custom_callback_input.py
@@ -38,19 +38,17 @@ class CompletionCustomHandler(
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
self.states: List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
] = []

@@ -269,6 +267,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
assert isinstance(kwargs["litellm_params"]["api_base"], str)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["completion_start_time"], datetime)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
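
The new `completion_start_time` assertion above pairs with the router's TTFT tracking; a minimal sketch of a user-side callback consuming it (assumes the `CustomLogger` success hook exercised by this test):

```python
from litellm.integrations.custom_logger import CustomLogger

class TTFTLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # completion_start_time is asserted above to be a datetime on success events.
        ttft = (kwargs["completion_start_time"] - start_time).total_seconds()
        total = (end_time - start_time).total_seconds()
        print(f"time to first token: {ttft:.3f}s, total: {total:.3f}s")

# Registered the usual way, e.g. litellm.callbacks = [TTFTLogger()]
```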