diff --git a/main.py b/main.py
index ddf4076..ee1911c 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,16 @@
 from mentor_mingle.llm_handler import ChatHandler
-from mentor_mingle.persona.mentor import Mentor
 
 if __name__ == "__main__":
-    mentor = Mentor()
-    handler = ChatHandler(persona=mentor)
-
+    handler = ChatHandler()
     while True:
         user_prompt = input("\nUser: ")
+        answer = ""
         for res in handler.stream_chat(user_prompt):
             end_char = "\n" if "." in res else ""
             print(res, end=end_char, flush=True)
+            answer += res
+
+        # Persist the full exchange (str keys; the setter stores them in redis)
+        handler.last_chat = {
+            "query": user_prompt,
+            "assistant": answer,
+        }
diff --git a/mentor_mingle/llm_handler.py b/mentor_mingle/llm_handler.py
index 1d010cc..44992f9 100644
--- a/mentor_mingle/llm_handler.py
+++ b/mentor_mingle/llm_handler.py
@@ -1,13 +1,15 @@
 import logging
 import os
 from pathlib import Path
-from typing import Generator
+from typing import Generator, Optional
 
 import openai
+import redis
 from dotenv import load_dotenv
 
 from .config import Config
-from .persona.base import BasePersona
+from .helpers.cache import Cache
+from .persona.mentor import Mentor
 from .utils import find_closest
 
 load_dotenv()
@@ -17,17 +19,44 @@
 class ChatHandler:
     """Handler for chat with GPT-3"""
 
+    _memory: Optional[Cache] = None
+
     def __init__(
         self,
-        persona: BasePersona,
+        cache_client: Optional[redis.Redis] = None,
     ):
         """Initialize the chat handler"""
-        openai.api_key = os.getenv("OPENAI_KEY")
-        self.agent = persona
+        openai.api_key = os.environ.get("OPENAI_KEY")
+        self.agent = Mentor()
 
         # Load config
         self.model = Config.from_toml(Path(find_closest("config.toml"))).models.gpt3
 
+        # Initialize memory
+        if self._memory is None:
+            self._memory = Cache(cache_client)
+
+    @property
+    def memory(self) -> Cache:
+        """Get the memory"""
+        return self._memory
+
+    @property
+    def last_chat(self) -> Optional[dict]:
+        """Get the last stored exchange (redis hash: bytes keys/values on read)"""
+        return self._memory.get_map("last_chat")
+
+    @last_chat.setter
+    def last_chat(self, value: dict) -> None:
+        """Store the last exchange; callers pass str-keyed dicts"""
+        self._memory.set_map(
+            "last_chat",
+            {
+                "query": value.get("query", ""),
+                "assistant": value.get("assistant", ""),
+            },
+        )
+
     def stream_chat(self, user_prompt: str) -> Generator[str, None, None]:
         """
         Stream a chat with GPT-3
@@ -38,15 +67,27 @@ def stream_chat(self, user_prompt: str) -> Generator[str, None, None]:
         Returns:
             None
         """
-        completion = openai.ChatCompletion.create(
-            model=self.model.name,
-            messages=[
-                {"role": "system", "content": self.agent.persona},
+        intro_session = [
+            {"role": "system", "content": self.agent.persona},
+            {
+                "role": "user",
+                "content": f"User: {user_prompt}",
+            },
+            {
+                "role": "system",
+                "content": self.agent.answer_format,
+            },
+        ]
+        if self.last_chat:
+            intro_session += [
                 {
                     "role": "user",
-                    "content": f"User: {user_prompt}" f"\n{self.agent.answer_format}",
-                },
-            ],
+                    "content": self.last_chat.get(b"query", b"").decode("utf-8"),
+                }
+            ]
+        completion = openai.ChatCompletion.create(
+            model=self.model.name,
+            messages=intro_session,
             **self.model.config.model_dump(),
         )
         for chunk in completion:
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..a6e569b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,12 @@
+import fakeredis
+import pytest
+
+
+@pytest.fixture(scope="session")
+def fake_redis() -> fakeredis.FakeStrictRedis:
+    """
+    Create a fake Redis client.
+
+    Returns: FakeRedis client
+    """
+    return fakeredis.FakeStrictRedis()
diff --git a/tests/helpers/test_cache.py b/tests/helpers/test_cache.py
index 6ae21ea..3c7f343 100644
--- a/tests/helpers/test_cache.py
+++ b/tests/helpers/test_cache.py
@@ -1,5 +1,3 @@
-import fakeredis
-import pytest
 
 from mentor_mingle.helpers.cache import Cache
 
@@ -9,15 +7,6 @@ class TestCache:
     """
     Test the Cache class.
     """
-    @pytest.fixture
-    def fake_redis(self) -> fakeredis.FakeStrictRedis:
-        """
-        Create a fake Redis client.
-
-        Returns: FakeRedis client
-        """
-        return fakeredis.FakeStrictRedis()
-
     def test_get(self, fake_redis):
         """
         Test the get method of the Cache class.
diff --git a/tests/test_llm_handler.py b/tests/test_llm_handler.py
index 1077b5a..f5c5439 100644
--- a/tests/test_llm_handler.py
+++ b/tests/test_llm_handler.py
@@ -13,11 +13,11 @@ class TestChatHandler:
     """
     Test the ChatHandler class.
    """
-    def test_chat_handler_init(self):
+    def test_chat_handler_init(self, fake_redis):
         """
         Test the stream_chat method of the ChatHandler class.
         """
-        handler = ChatHandler(Mentor())
+        handler = ChatHandler(cache_client=fake_redis)
         assert isinstance(handler.model, Gpt)
         assert isinstance(handler.agent, Mentor)
         assert isinstance(handler.model.config, CFGGpt)
@@ -40,7 +40,7 @@ def mock_response_generator(self, **kwargs) -> Generator[OpenAIObject, None, Non
             mock_obj.choices = [choice_mock]
             yield mock_obj
 
-    def test_stream_chat(self, mocker: MagicMock):
+    def test_stream_chat(self, mocker: MagicMock, fake_redis):
         """
         Test the stream_chat method of the ChatHandler class.
 
@@ -51,7 +51,7 @@ def test_stream_chat(self, mocker: MagicMock):
 
         """
         # Create a mock instance of your class
-        llm = ChatHandler(Mentor())
+        llm = ChatHandler(cache_client=fake_redis)
 
         # Patch the API call to return mock_response
         mocker.patch.object(openai.ChatCompletion, "create", side_effect=self.mock_response_generator)