From bac77315ea56a7497e78561fc2d00b11c7ce7c4d Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Thu, 28 Mar 2024 06:03:15 -0400 Subject: [PATCH] refactor printing / logging code --- src/chatdbg/chatdbg_pdb.py | 84 ++----------------------- src/chatdbg/util/{chatlog.py => log.py} | 5 +- src/chatdbg/util/printer.py | 70 +++++++++++++++++++++ 3 files changed, 77 insertions(+), 82 deletions(-) rename src/chatdbg/util/{chatlog.py => log.py} (97%) create mode 100644 src/chatdbg/util/printer.py diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index e80e536..0d81f4a 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -15,21 +15,14 @@ import IPython from traitlets import TraitError -from chatdbg.util.capture import CaptureInput, CaptureOutput - from .assistant.assistant import Assistant -from .assistant.listeners import BaseAssistantListener -from .util.chatlog import ChatDBGLog +from .util.capture import CaptureInput, CaptureOutput from .util.config import chatdbg_config from .util.locals import extract_locals +from .util.log import ChatDBGLog +from .util.printer import ChatDBGPrinter from .util.prompts import pdb_instructions -from .util.streamwrap import StreamingTextWrapper -from .util.text import ( - format_limited, - strip_color, - truncate_proportionally, - word_wrap_except_code_blocks, -) +from .util.text import (format_limited, strip_color, truncate_proportionally) def load_ipython_extension(ipython): @@ -636,7 +629,7 @@ def _make_assistant(self): functions=functions, stream=chatdbg_config.stream, listeners=[ - ChatAssistantClient( + ChatDBGPrinter( self.stdout, self.prompt, self._chat_prefix, @@ -717,70 +710,3 @@ def slice(self, name): command = f"slice {name}" result = self._capture_onecmd(command) return command, truncate_proportionally(result, top_proportion=0.5) - - ############################################################### - - -class ChatAssistantClient(BaseAssistantListener): - def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): - self.out = out - self.debugger_prompt = debugger_prompt - self.chat_prefix = chat_prefix - self.width = width - self._assistant = None - self._stream = stream - - # Call backs - - def on_begin_query(self, prompt, user_text): - pass - - def on_end_query(self, stats): - pass - - def _print(self, text, **kwargs): - print( - textwrap.indent(text, self.chat_prefix, lambda _: True), - file=self.out, - **kwargs, - ) - - def on_warn(self, text): - self._print(textwrap.indent(text, "*** ")) - - def on_fail(self, text): - self._print(textwrap.indent(text, "*** ")) - - def on_begin_stream(self): - self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) - self._at_start = True - - def on_stream_delta(self, text): - if self._at_start: - self._at_start = False - print( - self._stream_wrapper.append("\n(Message) ", False), - end="", - flush=True, - file=self.out, - ) - print( - self._stream_wrapper.append(text, False), end="", flush=True, file=self.out - ) - - def on_end_stream(self): - print(self._stream_wrapper.flush(), end="", flush=True, file=self.out) - - def on_response(self, text): - if not self._stream and text != None: - text = word_wrap_except_code_blocks( - text, self.width - len(self.chat_prefix) - ) - self._print(text) - - def on_function_call(self, call, result): - if result and len(result) > 0: - entry = f"{self.debugger_prompt}{call}\n{result}" - else: - entry = f"{self.debugger_prompt}{call}" - self._print(entry) diff --git a/src/chatdbg/util/chatlog.py b/src/chatdbg/util/log.py similarity index 97% rename from src/chatdbg/util/chatlog.py rename to src/chatdbg/util/log.py index 7e65e79..0f086b0 100644 --- a/src/chatdbg/util/chatlog.py +++ b/src/chatdbg/util/log.py @@ -4,10 +4,9 @@ import yaml -from chatdbg.util.capture import CaptureOutput - from ..assistant.listeners import BaseAssistantListener -from ..util.text import word_wrap_except_code_blocks +from ..util.capture import CaptureOutput +from .text import word_wrap_except_code_blocks class ChatDBGLog(BaseAssistantListener): diff --git a/src/chatdbg/util/printer.py b/src/chatdbg/util/printer.py new file mode 100644 index 0000000..c3a36e5 --- /dev/null +++ b/src/chatdbg/util/printer.py @@ -0,0 +1,70 @@ + +import textwrap +from assistant.listeners import BaseAssistantListener +from util.streamwrap import StreamingTextWrapper +from util.text import word_wrap_except_code_blocks + + +class ChatDBGPrinter(BaseAssistantListener): + def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): + self.out = out + self.debugger_prompt = debugger_prompt + self.chat_prefix = chat_prefix + self.width = width + self._assistant = None + self._stream = stream + + # Call backs + + def on_begin_query(self, prompt, user_text): + pass + + def on_end_query(self, stats): + pass + + def _print(self, text, **kwargs): + print( + textwrap.indent(text, self.chat_prefix, lambda _: True), + file=self.out, + **kwargs, + ) + + def on_warn(self, text): + self._print(textwrap.indent(text, "*** ")) + + def on_fail(self, text): + self._print(textwrap.indent(text, "*** ")) + + def on_begin_stream(self): + self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) + self._at_start = True + + def on_stream_delta(self, text): + if self._at_start: + self._at_start = False + print( + self._stream_wrapper.append("\n(Message) ", False), + end="", + flush=True, + file=self.out, + ) + print( + self._stream_wrapper.append(text, False), end="", flush=True, file=self.out + ) + + def on_end_stream(self): + print(self._stream_wrapper.flush(), end="", flush=True, file=self.out) + + def on_response(self, text): + if not self._stream and text != None: + text = word_wrap_except_code_blocks( + text, self.width - len(self.chat_prefix) + ) + self._print(text) + + def on_function_call(self, call, result): + if result and len(result) > 0: + entry = f"{self.debugger_prompt}{call}\n{result}" + else: + entry = f"{self.debugger_prompt}{call}" + self._print(entry)