Skip to content

Commit

Permalink
updated conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanehachcham committed Oct 26, 2023
1 parent b211085 commit 6932f87
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 31 deletions.
10 changes: 6 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from mentor_mingle.llm_handler import ChatHandler
from mentor_mingle.persona.mentor import Mentor

if __name__ == "__main__":
mentor = Mentor()
handler = ChatHandler(persona=mentor)

handler = ChatHandler()
while True:
user_prompt = input("\nUser: ")
for res in handler.stream_chat(user_prompt):
end_char = "\n" if "." in res else ""
print(res, end=end_char, flush=True)

handler.last_chat = {
"query": user_prompt,
"assistant": "",
}
65 changes: 53 additions & 12 deletions mentor_mingle/llm_handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
import os
import redis
from pathlib import Path
from typing import Generator
from typing import Generator, Optional, Union, Any

import openai
from dotenv import load_dotenv

from .config import Config
from .persona.base import BasePersona
from .helpers.cache import Cache
from .persona.mentor import Mentor
from .utils import find_closest

load_dotenv()
Expand All @@ -17,17 +19,44 @@
class ChatHandler:
"""Handler for chat with GPT-3"""

_memory: Cache = None

def __init__(
self,
persona: BasePersona,
cache_client: Union[redis.Redis, Any] = None,
):
"""Initialize the chat handler"""
openai.api_key = os.getenv("OPENAI_KEY")
self.agent = persona
openai.api_key = os.environ.get("OPENAI_KEY")
self.agent = Mentor()

# Load config
self.model = Config.from_toml(Path(find_closest("config.toml"))).models.gpt3

# Initialize memory
if self._memory is None:
self._memory = Cache(cache_client)

@property
def memory(self) -> Cache:
    """Return the Cache instance backing this handler's conversation memory."""
    return self._memory

@property
def last_chat(self) -> Optional[dict]:
    """Return the most recently cached chat exchange.

    Returns:
        The hash map stored under the "last_chat" cache key — presumably a
        mapping with "query"/"assistant" entries (see the setter); the exact
        empty value when nothing is cached depends on Cache.get_map — TODO confirm.
    """
    return self._memory.get_map("last_chat")

@last_chat.setter
def last_chat(self, value: dict) -> None:
"""Set the memory"""
self._memory.set_map(
"last_chat",
{
"query": value.get(b"query", ""),
"assistant": value.get(b"assistant", ""),
},
)

def stream_chat(self, user_prompt: str) -> Generator[str, None, None]:
"""
Stream a chat with GPT-3
Expand All @@ -38,15 +67,27 @@ def stream_chat(self, user_prompt: str) -> Generator[str, None, None]:
Returns:
None
"""
completion = openai.ChatCompletion.create(
model=self.model.name,
messages=[
{"role": "system", "content": self.agent.persona},
intro_session = [
{"role": "system", "content": self.agent.persona},
{
"role": "user",
"content": f"User: {user_prompt}",
},
{
"role": "system",
"content": self.agent.answer_format,
},
]
if self.last_chat:
intro_session += [
{
"role": "user",
"content": f"User: {user_prompt}" f"\n{self.agent.answer_format}",
},
],
"content": f"{self.last_chat.get(b'query', '')}",
}
]
completion = openai.ChatCompletion.create(
model=self.model.name,
messages=intro_session,
**self.model.config.model_dump(),
)
for chunk in completion:
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

import pytest
import fakeredis

@pytest.fixture(scope="session")
def fake_redis() -> fakeredis.FakeStrictRedis:
    """
    Provide an in-memory fake Redis client, shared across the whole test session.
    Returns: a fakeredis.FakeStrictRedis instance
    """
    client = fakeredis.FakeStrictRedis()
    return client
11 changes: 0 additions & 11 deletions tests/helpers/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import fakeredis
import pytest

from mentor_mingle.helpers.cache import Cache

Expand All @@ -9,15 +7,6 @@ class TestCache:
Test the Cache class.
"""

@pytest.fixture
def fake_redis(self) -> fakeredis.FakeStrictRedis:
"""
Create a fake Redis client.
Returns: FakeRedis client
"""
return fakeredis.FakeStrictRedis()

def test_get(self, fake_redis):
"""
Test the get method of the Cache class.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_llm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ class TestChatHandler:
Test the ChatHandler class.
"""

def test_chat_handler_init(self):
def test_chat_handler_init(self, fake_redis):
"""
Test the stream_chat method of the ChatHandler class.
"""
handler = ChatHandler(Mentor())
handler = ChatHandler(cache_client=fake_redis)
assert isinstance(handler.model, Gpt)
assert isinstance(handler.agent, Mentor)
assert isinstance(handler.model.config, CFGGpt)
Expand All @@ -40,7 +40,7 @@ def mock_response_generator(self, **kwargs) -> Generator[OpenAIObject, None, Non
mock_obj.choices = [choice_mock]
yield mock_obj

def test_stream_chat(self, mocker: MagicMock):
def test_stream_chat(self, mocker: MagicMock, fake_redis):
"""
Test the stream_chat method of the ChatHandler class.
Expand All @@ -51,7 +51,7 @@ def test_stream_chat(self, mocker: MagicMock):
None
"""
# Create a mock instance of your class
llm = ChatHandler(Mentor())
llm = ChatHandler(cache_client=fake_redis)

# Patch the API call to return mock_response
mocker.patch.object(openai.ChatCompletion, "create", side_effect=self.mock_response_generator)
Expand Down

0 comments on commit 6932f87

Please sign in to comment.