Skip to content

Commit

Permalink
feat: Add TTFT support in OpenAI chat generator (#8444)
Browse files Browse the repository at this point in the history
* feat: Add TTFT support in OpenAI generators

* pylint fixes

* correct disable

---------

Co-authored-by: Vladimir Blagojevic <[email protected]>
  • Loading branch information
LastRemote and vblagoje authored Oct 31, 2024
1 parent 294a67e commit 2595e68
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
14 changes: 12 additions & 2 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import json
import os
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

from openai import OpenAI, Stream
Expand Down Expand Up @@ -63,7 +64,7 @@ class OpenAIChatGenerator:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-4o-mini",
Expand Down Expand Up @@ -222,11 +223,15 @@ def run(
raise ValueError("Cannot stream multiple responses, please set n=1.")
chunks: List[StreamingChunk] = []
chunk = None
_first_token = True

# pylint: disable=not-an-iterable
for chunk in chat_completion:
if chunk.choices and streaming_callback:
chunk_delta: StreamingChunk = self._build_chunk(chunk)
if _first_token:
_first_token = False
chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
chunks.append(chunk_delta)
streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
Expand Down Expand Up @@ -280,7 +285,12 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessa
payload["function"]["arguments"] += delta.arguments or ""
complete_response = ChatMessage.from_assistant(json.dumps(payloads))
else:
complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
total_content = ""
total_meta = {}
for streaming_chunk in chunks:
total_content += streaming_chunk.content
total_meta.update(streaming_chunk.meta)
complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
complete_response.meta.update(
{
"model": chunk.model,
Expand Down
6 changes: 6 additions & 0 deletions releasenotes/notes/openai-ttft-42b1ad551b542930.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Add TTFT (Time-to-First-Token) support for OpenAI generators. This
captures the time taken to generate the first token from the model and
can be used to analyze the latency of the application.
11 changes: 10 additions & 1 deletion test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from unittest.mock import patch

import pytest
from openai import OpenAIError
Expand Down Expand Up @@ -219,7 +220,8 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk

def test_run_with_streaming_callback_in_run_method(self, chat_messages, mock_chat_completion_chunk):
@patch("haystack.components.generators.chat.openai.datetime")
def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
streaming_callback_called = False

def streaming_callback(chunk: StreamingChunk) -> None:
Expand All @@ -240,6 +242,13 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk

assert hasattr(response["replies"][0], "meta")
assert isinstance(response["replies"][0].meta, dict)
assert (
response["replies"][0].meta["completion_start_time"]
== mock_datetime.now.return_value.isoformat.return_value
)

def test_check_abnormal_completions(self, caplog):
caplog.set_level(logging.INFO)
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
Expand Down

0 comments on commit 2595e68

Please sign in to comment.