Skip to content

Commit

Permalink
refactor printing / logging code
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenfreund committed Mar 28, 2024
1 parent 073bbd4 commit bac7731
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 82 deletions.
84 changes: 5 additions & 79 deletions src/chatdbg/chatdbg_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,14 @@
import IPython
from traitlets import TraitError

from chatdbg.util.capture import CaptureInput, CaptureOutput

from .assistant.assistant import Assistant
from .assistant.listeners import BaseAssistantListener
from .util.chatlog import ChatDBGLog
from .util.capture import CaptureInput, CaptureOutput
from .util.config import chatdbg_config
from .util.locals import extract_locals
from .util.log import ChatDBGLog
from .util.printer import ChatDBGPrinter
from .util.prompts import pdb_instructions
from .util.streamwrap import StreamingTextWrapper
from .util.text import (
format_limited,
strip_color,
truncate_proportionally,
word_wrap_except_code_blocks,
)
from .util.text import (format_limited, strip_color, truncate_proportionally)


def load_ipython_extension(ipython):
Expand Down Expand Up @@ -636,7 +629,7 @@ def _make_assistant(self):
functions=functions,
stream=chatdbg_config.stream,
listeners=[
ChatAssistantClient(
ChatDBGPrinter(
self.stdout,
self.prompt,
self._chat_prefix,
Expand Down Expand Up @@ -717,70 +710,3 @@ def slice(self, name):
command = f"slice {name}"
result = self._capture_onecmd(command)
return command, truncate_proportionally(result, top_proportion=0.5)

###############################################################


class ChatAssistantClient(BaseAssistantListener):
def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False):
self.out = out
self.debugger_prompt = debugger_prompt
self.chat_prefix = chat_prefix
self.width = width
self._assistant = None
self._stream = stream

# Call backs

def on_begin_query(self, prompt, user_text):
pass

def on_end_query(self, stats):
pass

def _print(self, text, **kwargs):
print(
textwrap.indent(text, self.chat_prefix, lambda _: True),
file=self.out,
**kwargs,
)

def on_warn(self, text):
self._print(textwrap.indent(text, "*** "))

def on_fail(self, text):
self._print(textwrap.indent(text, "*** "))

def on_begin_stream(self):
self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80)
self._at_start = True

def on_stream_delta(self, text):
if self._at_start:
self._at_start = False
print(
self._stream_wrapper.append("\n(Message) ", False),
end="",
flush=True,
file=self.out,
)
print(
self._stream_wrapper.append(text, False), end="", flush=True, file=self.out
)

def on_end_stream(self):
print(self._stream_wrapper.flush(), end="", flush=True, file=self.out)

def on_response(self, text):
if not self._stream and text != None:
text = word_wrap_except_code_blocks(
text, self.width - len(self.chat_prefix)
)
self._print(text)

def on_function_call(self, call, result):
if result and len(result) > 0:
entry = f"{self.debugger_prompt}{call}\n{result}"
else:
entry = f"{self.debugger_prompt}{call}"
self._print(entry)
5 changes: 2 additions & 3 deletions src/chatdbg/util/chatlog.py → src/chatdbg/util/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

import yaml

from chatdbg.util.capture import CaptureOutput

from ..assistant.listeners import BaseAssistantListener
from ..util.text import word_wrap_except_code_blocks
from ..util.capture import CaptureOutput
from .text import word_wrap_except_code_blocks


class ChatDBGLog(BaseAssistantListener):
Expand Down
70 changes: 70 additions & 0 deletions src/chatdbg/util/printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

import textwrap
from assistant.listeners import BaseAssistantListener
from util.streamwrap import StreamingTextWrapper
from util.text import word_wrap_except_code_blocks


class ChatDBGPrinter(BaseAssistantListener):
def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False):
self.out = out
self.debugger_prompt = debugger_prompt
self.chat_prefix = chat_prefix
self.width = width
self._assistant = None
self._stream = stream

# Call backs

def on_begin_query(self, prompt, user_text):
pass

def on_end_query(self, stats):
pass

def _print(self, text, **kwargs):
print(
textwrap.indent(text, self.chat_prefix, lambda _: True),
file=self.out,
**kwargs,
)

def on_warn(self, text):
self._print(textwrap.indent(text, "*** "))

def on_fail(self, text):
self._print(textwrap.indent(text, "*** "))

def on_begin_stream(self):
self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80)
self._at_start = True

def on_stream_delta(self, text):
if self._at_start:
self._at_start = False
print(
self._stream_wrapper.append("\n(Message) ", False),
end="",
flush=True,
file=self.out,
)
print(
self._stream_wrapper.append(text, False), end="", flush=True, file=self.out
)

def on_end_stream(self):
print(self._stream_wrapper.flush(), end="", flush=True, file=self.out)

def on_response(self, text):
if not self._stream and text != None:
text = word_wrap_except_code_blocks(
text, self.width - len(self.chat_prefix)
)
self._print(text)

def on_function_call(self, call, result):
if result and len(result) > 0:
entry = f"{self.debugger_prompt}{call}\n{result}"
else:
entry = f"{self.debugger_prompt}{call}"
self._print(entry)

0 comments on commit bac7731

Please sign in to comment.