Skip to content

Commit

Permalink
Added improved memory handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Xpitfire committed Jun 10, 2023
1 parent 89b6c7e commit de3c0ec
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions symai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ChatBot(Expression):
_symai_chat: str = """This is a conversation between a chatbot ({}:) and a human (User:). It also includes narration Text (Narrator:) describing the next dialog. The chatbot primarily follows the narrative instructions, and then uses the user input to condition on for the generated response.\n"""

def __init__(self, value = None, name: str = 'Symbia', output: Optional[Output] = None, verbose: bool = False):
def __init__(self, value = None, name: str = 'Symbia', output: Optional[Output] = None, verbose: bool = True):
super().__init__(value)
self.verbose: bool = verbose
self.name = name
Expand Down Expand Up @@ -39,7 +39,7 @@ def narrate(self, message: str, context: str = None, end: bool = False, do_recal
value += '\n'.join(self.memory.recall()) # TODO: use vector search DB
value += '\n\nLONG-TERM MEMORY RECALL (Consider only if relevant to the user query!)\n\n'
query = f'{self.last_user_input}\n{message}\n\n'
recall = self.long_term_memory.recall(query) if do_recall else []
recall = self.long_term_memory.recall(query) if do_recall else []
value += '\n'.join(recall)
value += f'\n{self.name}:'

Expand All @@ -55,7 +55,9 @@ def _func(_) -> str:

rsp = f"{self.name}: {model_rsp}"
if self.verbose: print('[DEBUG] model reply: ', model_rsp)
memory = f"{self.last_user_input} >> {model_rsp}" if do_recall else ""
memory = f"{self.last_user_input}" if do_recall else ""
if len(memory) > 0: self.long_term_memory.store(memory)
memory = f"{model_rsp}" if do_recall else ""
if len(memory) > 0: self.long_term_memory.store(memory)
if self.verbose: print('[DEBUG] store new memory reply: ', memory)

Expand Down Expand Up @@ -137,7 +139,7 @@ def forward(self):

elif '[DK]' in ctxt:
thought = self._extract_thought(ctxt)
message = self.narrate(f'{self.name} restates verbatim the message.', context=thought)
message = self.narrate(f'{self.name} restates verbatim the message.', context=thought, do_recall=False)

else:
try:
Expand Down Expand Up @@ -200,7 +202,7 @@ def forward(self):

except Exception as e:
thought = self._extract_thought(ctxt)
message = self.narrate('Symbia apologizes and explains the user what went wrong.', context=str(e))
message = self.narrate('Symbia apologizes and explains the user what went wrong.', context=str(e), do_recall=False)

def _extract_thought(self, msg: str) -> str:
return re.findall(r'\{([^}]+)\}', msg).pop()
Expand Down

0 comments on commit de3c0ec

Please sign in to comment.