diff --git a/src/chatdbg/__main__.py b/src/chatdbg/__main__.py index b070763..f876884 100644 --- a/src/chatdbg/__main__.py +++ b/src/chatdbg/__main__.py @@ -1,7 +1,57 @@ -from .chatdbg_pdb import * import ipdb +from chatdbg.chatdbg_pdb import ChatDBG +from chatdbg.util.config import chatdbg_config +import sys +import getopt + +_usage = """\ +usage: python -m ipdb [-m] [-c command] ... pyfile [arg] ... + +Debug the Python program given by pyfile. + +Initial commands are read from .pdbrc files in your home directory +and in the current directory, if they exist. Commands supplied with +-c are executed after commands from .pdbrc files. + +To let the script run until an exception occurs, use "-c continue". +To let the script run up to a given line X in the debugged file, use +"-c 'until X'" + +Option -m is available only in Python 3.7 and later. + +ChatDBG-specific options may appear anywhere before pyfile: + --debug dump the LLM messages to a chatdbg.log + --log file where to write the log of the debugging session + --model model the LLM model to use. + --stream stream responses from the LLM + +""" def main(): ipdb.__main__._get_debugger_cls = lambda: ChatDBG + + opts, args = getopt.getopt( + sys.argv[1:], "mhc:", ["help", "debug", "log=", "model=", "stream", "command="] + ) + pdb_args = [sys.argv[0]] + for opt, optarg in opts: + if opt in ["-h", "--help"]: + print(_usage) + sys.exit() + elif opt in ["--debug"]: + chatdbg_config.debug = True + elif opt in ["--stream"]: + chatdbg_config.stream = True + elif opt in ["--model"]: + chatdbg_config.model = optarg + elif opt in ["--log"]: + chatdbg_config.model = optarg + elif opt in ["-c", "--command"]: + pdb_args += [opt, optarg] + elif opt in ["-m"]: + pdb_args = [opt] + + sys.argv = pdb_args + args + ipdb.__main__.main() diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 55c1b92..5a67bbc 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -1,259 +1,306 @@ -import atexit -import inspect import json import sys import textwrap import time -import llm_utils -from openai import * -from pydantic import BaseModel +import litellm +import openai +from .listeners import Printer -class Assistant: - """ - An Assistant is a wrapper around OpenAI's assistant API. Example usage: - - assistant = Assistant("Assistant Name", instructions, - model='gpt-4-1106-preview', debug=True) - assistant.add_function(my_func) - response = assistant.run(user_prompt) - - Name can be any name you want. - - If debug is True, it will create a log of all messages and JSON responses in - json.txt. - """ - # TODO: At some point, if we unify the argument parsing, we should just have this take args. +class Assistant: def __init__( - self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, debug=True + self, + instructions, + model="gpt-3.5-turbo-1106", + timeout=30, + listeners=[Printer()], + functions=[], + max_call_response_tokens=4096, + debug=False, + stream=False, ): if debug: - self.json = open(f"json.txt", "a") - else: - self.json = None - - try: - self.client = OpenAI(timeout=timeout) - except OpenAIError: - print( - textwrap.dedent( - """\ - You need an OpenAI key to use this tool. - You can get a key here: https://platform.openai.com/api-keys - Set the environment variable OPENAI_API_KEY to your key value. - """ - ) + log_file = open(f"chatdbg.log", "w") + self._logger = lambda model_call_dict: print( + model_call_dict, file=log_file, flush=True ) - sys.exit(-1) - - self.assistants = self.client.beta.assistants - self.threads = self.client.beta.threads - self.functions = dict() + else: + self._logger = None - self.assistant = self.assistants.create( - name=name, instructions=instructions, model=model - ) + self._clients = listeners - self._log(self.assistant) + self._functions = {} + for f in functions: + self._add_function(f) - atexit.register(self._delete_assistant) + self._model = model + self._timeout = timeout + self._conversation = [{"role": "system", "content": instructions}] + self._max_call_response_tokens = max_call_response_tokens + self._stream = stream - self.thread = self.threads.create() - self._log(self.thread) + self._check_model() + self._broadcast("on_begin_dialog", instructions) - def _delete_assistant(self): - if self.assistant != None: - try: - id = self.assistant.id - response = self.assistants.delete(id) - self._log(response) - assert response.deleted - except OSError: - raise - except Exception as e: - print(f"Assistant {id} was not deleted ({e}).") - print("You can do so at https://platform.openai.com/assistants.") + def close(self): + self._broadcast("on_end_dialog") - def add_function(self, function): + def query(self, prompt: str, user_text): """ - Add a new function to the list of function tools for the assistant. - The function should have the necessary json spec as its pydoc string. + Send a query to the LLM. + - prompt is the prompt to send. + - user_text is what the user typed (which may or not be the same as prompt) """ - function_json = json.loads(function.__doc__) - assert "name" in function_json, "Bad JSON in pydoc for function tool." - try: - name = function_json["name"] - self.functions[name] = function + if self._stream: + return self._streamed_query(prompt, user_text) + else: + return self._batch_query(prompt, user_text) + + def _broadcast(self, method_name, *args): + for client in self._clients: + method = getattr(client, method_name, None) + if callable(method): + method(*args) + + def _check_model(self): + result = litellm.validate_environment(self._model) + missing_keys = result["missing_keys"] + if missing_keys != []: + _, provider, _, _ = litellm.get_llm_provider(self._model) + if provider == "openai": + self._broadcast( + "on_fail", + textwrap.dedent( + f"""\ + You need an OpenAI key to use the {self._model} model. + You can get a key here: https://platform.openai.com/api-keys. + Set the environment variable OPENAI_API_KEY to your key value.""" + ), + ) + sys.exit(1) + else: + self._broadcast( + "on_fail", + textwrap.dedent( + f"""\ + You need to set the following environment variables + to use the {self._model} model: {', '.join(missing_keys)}""" + ), + ) + sys.exit(1) - tools = [ - {"type": "function", "function": json.loads(function.__doc__)} - for function in self.functions.values() - ] + if not litellm.supports_function_calling(self._model): + self._broadcast( + "on_fail", + textwrap.dedent( + f"""\ + The {self._model} model does not support function calls. + You must use a model that does, eg. gpt-4.""" + ), + ) + sys.exit(1) + + def _sandwhich_tokens( + self, text: str, max_tokens: int, top_proportion: float + ) -> str: + model = self._model + if max_tokens == None: + return text + tokens = litellm.encode(model, text) + if len(tokens) <= max_tokens: + return text + else: + total_len = max_tokens - 5 # some slop for the ... + top_len = int(top_proportion * total_len) + bot_len = int((1 - top_proportion) * total_len) + return ( + litellm.decode(model, tokens[0:top_len]) + + " [...] " + + litellm.decode(model, tokens[-bot_len:]) + ) - assistant = self.assistants.update(self.assistant.id, tools=tools) - self._log(assistant) - except OpenAIError as e: - print(f"*** OpenAI Error: {e}") - sys.exit(-1) + def _add_function(self, function): + """ + Add a new function to the list of function tools. + The function should have the necessary json spec as its docstring + """ + schema = json.loads(function.__doc__) + assert "name" in schema, "Bad JSON in docstring for function tool." + self._functions[schema["name"]] = {"function": function, "schema": schema} - def _make_call(self, tool_call): + def _make_call(self, tool_call) -> str: name = tool_call.function.name - args = tool_call.function.arguments - - # There is a sketchy case that happens occasionally because - # the API produces a bad call... try: - args = json.loads(args) - function = self.functions[name] - result = function(**args) + args = json.loads(tool_call.function.arguments) + function = self._functions[name] + call, result = function["function"](**args) + self._broadcast("on_function_call", call, result) except OSError as e: + # function produced some error -- move this to client??? result = f"Error: {e}" - # raise except Exception as e: - result = f"Ill-formed function call ({e})\n" - + result = f"Ill-formed function call: {e}" return result - def _print_messages(self, messages, client_print): - client_print() - for i, m in enumerate(messages): - message_text = m.content[0].text.value - if i == 0: - message_text = "(Message) " + message_text - client_print(message_text) + def _batch_query(self, prompt: str, user_text): + start = time.time() + cost = 0 - def _wait_on_run(self, run, thread, client_print): try: - while run.status == "queued" or run.status == "in_progress": - run = self.threads.runs.retrieve( - thread_id=thread.id, - run_id=run.id, + self._broadcast("on_begin_query", prompt, user_text) + self._conversation.append({"role": "user", "content": prompt}) + + while True: + self._conversation = litellm.utils.trim_messages( + self._conversation, self._model ) - time.sleep(0.5) - return run - finally: - if run.status == "in_progress": - client_print("Cancelling message that's in progress.") - self.threads.runs.cancel(thread_id=thread.id, run_id=run.id) - - def run(self, prompt, client_print=print): - """ - Give the prompt to the assistant and get the response, which may included - intermediate function calls. - All output is printed to the given file. - """ - start_time = time.perf_counter() - try: - if self.assistant == None: - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, - "cost": cost, - } + completion = self._completion() - assert len(prompt) <= 32768 + cost += litellm.completion_cost(completion) - message = self.threads.messages.create( - thread_id=self.thread.id, role="user", content=prompt - ) - self._log(message) + response_message = completion.choices[0].message + self._conversation.append(response_message) - last_printed_message_id = message.id + if response_message.content: + self._broadcast( + "on_response", "(Message) " + response_message.content + ) - run = self.threads.runs.create( - thread_id=self.thread.id, assistant_id=self.assistant.id - ) - self._log(run) + if completion.choices[0].finish_reason == "tool_calls": + self._add_function_results_to_conversation(response_message) + else: + break - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) + elapsed = time.time() - start + stats = { + "cost": cost, + "time": elapsed, + "model": self._model, + "tokens": completion.usage.total_tokens, + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, + } + self._broadcast("on_end_query", stats) + return stats + except openai.OpenAIError as e: + self._broadcast("on_fail", f"Internal Error: {e.__dict__}") + sys.exit(1) + + def _streamed_query(self, prompt: str, user_text): + start = time.time() + cost = 0 - while run.status == "requires_action": - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" + try: + self._broadcast("on_begin_query", prompt, user_text) + self._conversation.append({"role": "user", "content": prompt}) + + while True: + self._conversation = litellm.utils.trim_messages( + self._conversation, self._model + ) + # print("\n".join([str(x) for x in self._conversation])) + + stream = self._completion(stream=True) + + # litellm.stream_chunk_builder is broken for new GPT models + # that have content before calls, so... + + # stream the response, collecting the tool_call parts separately + # from the content + self._broadcast("on_begin_stream") + chunks = [] + tool_chunks = [] + for chunk in stream: + chunks.append(chunk) + if chunk.choices[0].delta.content != None: + self._broadcast( + "on_stream_delta", chunk.choices[0].delta.content + ) + else: + tool_chunks.append(chunk) + self._broadcast("on_end_stream") + + # then compute for the part that litellm gives back. + completion = litellm.stream_chunk_builder( + chunks, messages=self._conversation ) + cost += litellm.completion_cost(completion) - mlist = list(messages) - if len(mlist) > 0: - self._print_messages(mlist, client_print) - last_printed_message_id = mlist[-1].id - client_print() - - outputs = [] - for tool_call in run.required_action.submit_tool_outputs.tool_calls: - output = self._make_call(tool_call) - self._log(output) - outputs += [{"tool_call_id": tool_call.id, "output": output}] - - try: - run = self.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, run_id=run.id, tool_outputs=outputs + # add content to conversation, but if there is no content, then the message + # has only tool calls, and skip this step + response_message = completion.choices[0].message + if response_message.content != None: + self._conversation.append(response_message) + + if response_message.content != None: + self._broadcast( + "on_response", "(Message) " + response_message.content ) - self._log(run) - except OSError as e: - self._log(run, f"FAILED to submit tool call results: {e}") - raise - except Exception as e: - self._log(run, f"FAILED to submit tool call results: {e}") - - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) - - if run.status == "failed": - message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}" - client_print(message) - self._log(run) - sys.exit(-1) - - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" - ) - self._print_messages(messages, client_print) - end_time = time.perf_counter() - elapsed_time = end_time - start_time + if completion.choices[0].finish_reason == "tool_calls": + # create a message with just the tool calls, append that to the conversation, and generate the responses. + tool_completion = litellm.stream_chunk_builder( + tool_chunks, self._conversation + ) - cost = llm_utils.calculate_cost( - run.usage.prompt_tokens, - run.usage.completion_tokens, - self.assistant.model, - ) - client_print() - client_print(f"[Cost: ~${cost:.2f} USD]") - - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, + # this part wasn't counted above... + cost += litellm.completion_cost(tool_completion) + + tool_message = tool_completion.choices[0].message + cost += litellm.completion_cost(tool_completion) + self._conversation.append(tool_message) + self._add_function_results_to_conversation(tool_message) + else: + break + + elapsed = time.time() - start + stats = { "cost": cost, - "time": elapsed_time, - "thread.id": self.thread.id, - "run.id": run.id, - "assistant.id": self.assistant.id, + "time": elapsed, + "model": self._model, + "tokens": completion.usage.total_tokens, + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, } - except OpenAIError as e: - client_print(f"*** OpenAI Error: {e}") - sys.exit(-1) - - def _log(self, obj, title=""): - if self.json != None: - stack = inspect.stack() - caller_frame_record = stack[1] - lineno, function = caller_frame_record[2:4] - loc = f"{function}:{lineno}" - - print("-" * 70, file=self.json) - print(f"{loc} {title}", file=self.json) - if isinstance(obj, BaseModel): - json_obj = json.loads(obj.model_dump_json()) - else: - json_obj = obj - print(f"\n{json.dumps(json_obj, indent=2)}\n", file=self.json) - self.json.flush() - return obj + self._broadcast("on_end_query", stats) + return stats + except openai.OpenAIError as e: + self._broadcast("on_fail", f"Internal Error: {e.__dict__}") + sys.exit(1) + + def _completion(self, stream=False): + return litellm.completion( + model=self._model, + messages=self._conversation, + tools=[ + {"type": "function", "function": f["schema"]} + for f in self._functions.values() + ], + timeout=self._timeout, + logger_fn=self._logger, + stream=stream, + ) + + def _add_function_results_to_conversation(self, response_message): + response_message["role"] = "assistant" + tool_calls = response_message.tool_calls + try: + for tool_call in tool_calls: + function_response = self._make_call(tool_call) + function_response = self._sandwhich_tokens( + function_response, self._max_call_response_tokens, 0.5 + ) + response = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": function_response, + } + self._conversation.append(response) + except Exception as e: + # Warning: potential infinite loop if the LLM keeps sending + # the same bad call. + self._broadcast("on_warn", f"Error processing tool calls: {e}") diff --git a/src/chatdbg/assistant/listeners.py b/src/chatdbg/assistant/listeners.py new file mode 100644 index 0000000..4f70754 --- /dev/null +++ b/src/chatdbg/assistant/listeners.py @@ -0,0 +1,104 @@ +import sys +import textwrap + + +class BaseAssistantListener: + """ + Events that the Assistant generates. Override these for the client. + """ + + # Dialogs capture 1 or more queries. + + def on_begin_dialog(self, instructions): + pass + + def on_end_dialog(self): + pass + + # Events for a single query + + def on_begin_query(self, prompt, user_text): + pass + + def on_response(self, text): + pass + + def on_function_call(self, call, result): + pass + + def on_end_query(self, stats): + pass + + # For clients wishing to stream responses + + def on_begin_stream(self): + pass + + def on_stream_delta(self, text): + pass + + def on_end_stream(self): + pass + + # Notifications of non-fatal / fatal problems + + def on_warn(self, text): + pass + + def on_fail(self, text): + pass + + +class Printer(BaseAssistantListener): + def __init__(self, out=sys.stdout): + self.out = out + + def on_warn(self, text): + print(textwrap.indent(text, "*** "), file=self.out) + + def on_fail(self, text): + print(textwrap.indent(text, "*** "), file=self.out) + sys.exit(1) + + def on_begin_stream(self): + pass + + def on_stream_delta(self, text): + print(text, end="", file=self.out, flush=True) + + def on_end_stream(self): + pass + + def on_begin_query(self, prompt, user_text): + pass + + def on_end_query(self, stats): + pass + + def on_response(self, text): + if text != None: + print(text, file=self.out) + + def on_function_call(self, call, result): + if result and len(result) > 0: + entry = f"{call}\n{result}" + else: + entry = f"{call}" + print(entry, file=self.out) + + +class StreamingPrinter(Printer): + def __init__(self, out=sys.stdout): + super().__init__(out) + + def on_begin_stream(self): + print("", flush=True) + + def on_stream_delta(self, text): + print(text, end="", file=self.out, flush=True) + + def on_end_stream(self): + print("", flush=True) + + def on_response(self, text): + pass diff --git a/src/chatdbg/assistant/test.py b/src/chatdbg/assistant/test.py new file mode 100644 index 0000000..db79c13 --- /dev/null +++ b/src/chatdbg/assistant/test.py @@ -0,0 +1,52 @@ +from .assistant import Assistant +from .listeners import StreamingPrinter, Printer + + +class AssistantTest: + + def __init__(self): + self.a = Assistant( + "You generate text.", + listeners=[StreamingPrinter()], + functions=[self.weather], + stream=True, + ) + + def run(self): + x = self.a.query( + "tell me what model you are before making any function calls. And what's the weather in Boston?", + None, + ) + print(x) + + def weather(self, location, unit="f"): + """ + { + "name": "get_weather", + "description": "Determine weather in my location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + } + """ + return f"weather({location}, {unit})", "Sunny and 72 degrees." + + +t = AssistantTest() +t.run() diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 9993a59..e80e536 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -13,38 +13,32 @@ from pprint import pprint import IPython -import llm_utils from traitlets import TraitError -from chatdbg.ipdb_util.capture import CaptureInput +from chatdbg.util.capture import CaptureInput, CaptureOutput from .assistant.assistant import Assistant -from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper -from .ipdb_util.config import Chat -from .ipdb_util.locals import * -from .ipdb_util.prompts import pdb_instructions -from .ipdb_util.text import * - -_valid_models = [ - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - "gpt-4", # no parallel calls - "gpt-3.5-turbo", # no parallel calls -] - -chatdbg_config: Chat = None +from .assistant.listeners import BaseAssistantListener +from .util.chatlog import ChatDBGLog +from .util.config import chatdbg_config +from .util.locals import extract_locals +from .util.prompts import pdb_instructions +from .util.streamwrap import StreamingTextWrapper +from .util.text import ( + format_limited, + strip_color, + truncate_proportionally, + word_wrap_except_code_blocks, +) def load_ipython_extension(ipython): - # Create an instance of your configuration class with IPython's config global chatdbg_config - from chatdbg.chatdbg_pdb import Chat, ChatDBG + from chatdbg.chatdbg_pdb import ChatDBG + from chatdbg.util.config import ChatDBGConfig, chatdbg_config ipython.InteractiveTB.debugger_cls = ChatDBG - chatdbg_config = Chat(config=ipython.config) + chatdbg_config = ChatDBGConfig(config=ipython.config) print("*** Loaded ChatDBG ***") @@ -84,20 +78,36 @@ def __init__(self, *args, **kwargs): self._chat_prefix = " " self._text_width = 80 self._assistant = None + atexit.register(lambda: self._close_assistant()) + self._history = [] self._error_specific_prompt = "" - global chatdbg_config - if chatdbg_config == None: - chatdbg_config = Chat() - sys.stdin = CaptureInput(sys.stdin) + self._supports_flow = self.can_support_flow() + + self.do_context(chatdbg_config.context) + self.rcLines += ast.literal_eval(chatdbg_config.rc_lines) + + # set this to True ONLY AFTER we have had access to stack frames + self._show_locals = False + + self._log = ChatDBGLog( + log_filename=chatdbg_config.log, + config=chatdbg_config.to_json(), + capture_streams=True, + ) + + def _close_assistant(self): + if self._assistant != None: + self._assistant.close() + + def can_support_flow(self): # Only use flow when we are in jupyter or using stdin in ipython. In both # cases, there will be no python file at the start of argv after the # ipython commands. - self._supports_flow = chatdbg_config.show_slices - if self._supports_flow: + if chatdbg_config.show_slices: if ChatDBGSuper is not IPython.core.debugger.InterruptiblePdb: for arg in sys.argv: if arg.endswith("ipython") or arg.endswith("ipython3"): @@ -105,17 +115,10 @@ def __init__(self, *args, **kwargs): if arg.startswith("-"): continue if Path(arg).suffix in [".py", ".ipy"]: - self._supports_flow = False - break - - self.do_context(chatdbg_config.context) - self.rcLines += ast.literal_eval(chatdbg_config.rc_lines) - - # set this to True ONLY AFTER we have had stack frames - self._show_locals = False - - self._log = ChatDBGLog(chatdbg_config) - atexit.register(lambda: self._log.dump()) + return False + return True + else: + return False def _is_user_frame(self, frame): if not self._is_user_file(frame.f_code.co_filename): @@ -126,11 +129,13 @@ def _is_user_frame(self, frame): def _is_user_file(self, file_name): if file_name.endswith(".pyx"): return False - if file_name == "": + elif file_name == "": return False + for prefix in _user_file_prefixes: if file_name.startswith(prefix): return True + return False def format_stack_trace(self, context=None): @@ -187,9 +192,7 @@ def _hide_lib_frames(self): self._error_specific_prompt += f"The code `{current_line.strip()}` is correct and MUST remain unchanged in your fix.\n" def execRcLines(self): - # do before running rclines -- our stack should be set up by now. - if not chatdbg_config.show_libs: self._hide_lib_frames() self._error_stack_trace = f"The program has the following stack trace:\n```\n{self.format_stack_trace()}\n```\n" @@ -207,21 +210,24 @@ def onecmd(self, line: str) -> bool: # blank -- let super call back to into onecmd return super().onecmd(line) else: - hist_file = CopyingTextIOWrapper(self.stdout) + hist_file = CaptureOutput(self.stdout) self.stdout = hist_file try: - self.was_chat = False + self.was_chat_or_renew = False return super().onecmd(line) finally: self.stdout = hist_file.getfile() - if not line.startswith("config") and not line.startswith("mark"): - output = strip_color(hist_file.getvalue()) - if line not in ["quit", "EOF"]: - self._log.user_command(line, output) - if ( - line not in ["hist", "test_prompt", "c", "continue"] - and not self.was_chat - ): + output = strip_color(hist_file.getvalue()) + if not self.was_chat_or_renew: + self._log.on_function_call(line, output) + if line.split(" ")[0] not in [ + "hist", + "test_prompt", + "c", + "cont", + "continue", + "config", + ]: self._history += [(line, output)] def message(self, msg) -> None: @@ -235,7 +241,6 @@ def error(self, msg) -> None: Override to remove tabs for messages so we can indent them. """ return super().error(str(msg).expandtabs()) - # return super().error('If the name is undefined, be sure you are in the right frame. Use up and down to do that, and then print the variable again'.expandtabs()) def _capture_onecmd(self, line): """ @@ -247,6 +252,7 @@ def _capture_onecmd(self, line): self.stdout = StringIO() super().onecmd(line) result = self.stdout.getvalue().rstrip() + result = strip_color(result) return result finally: self.stdout = stdout @@ -318,9 +324,7 @@ def do_info(self, arg): # didn't find anything if obj == None: self.message(f"No name `{arg}` is visible in the current frame.") - return - - if self._is_user_file(inspect.getfile(obj)): + elif self._is_user_file(inspect.getfile(obj)): self.do_source(x) else: self.do_pydoc(x) @@ -339,6 +343,11 @@ def do_info(self, arg): ) def do_slice(self, arg): + """ + slice + Print the backwards slice for a variable used in the current cell but + defined in an earlier cell. [interactive IPython / Jupyter only] + """ if not self._supports_flow: self.message("*** `slice` is only supported in Jupyter notebooks") return @@ -361,11 +370,6 @@ def do_slice(self, arg): break index -= 1 if _x != None: - # print('found it') - # print(_x) - # print(_x.__dict__) - # print(_x._get_timestamps_for_version(version=-1)) - # print(code(_x)) time_stamps = _x._get_timestamps_for_version(version=-1) time_stamps = [ts for ts in time_stamps if ts.cell_num > -1] result = str( @@ -434,7 +438,7 @@ def _hidden_predicate(self, frame): return False def print_stack_trace(self, context=None, locals=None): - # override to print the skips into stdout... + # override to print the skips into stdout instead of stderr... Colors = self.color_scheme_table.active_colors ColorsNormal = Colors.Normal if context is None: @@ -515,7 +519,7 @@ def _stack_prompt(self): ) stack = ( textwrap.dedent( - f""" + f"""\ This is the current stack. The current frame is indicated by an arrow '>' at the start of the line. ```""" @@ -526,68 +530,76 @@ def _stack_prompt(self): finally: self.stdout = stdout - def _build_prompt(self, arg, conversing): - prompt = "" + def _ip_instructions(self): + return pdb_instructions(self._supports_flow, chatdbg_config.take_the_wheel) - if not conversing: - stack_dump = f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n\n" - prompt = "\n" + stack_dump + self._error_specific_prompt - if len(sys.argv) > 1: - prompt += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n" - input = sys.stdin.get_captured_input() - if len(input) > 0: - prompt += f"\nThis was the program's input :\n```\n{input}```\n" + def _ip_enchriched_stack_trace(self): + return f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n" + + def _ip_error(self): + return self._error_specific_prompt + + def _ip_inputs(self): + inputs = "" + if len(sys.argv) > 1: + inputs += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n" + input = sys.stdin.get_captured_input() + if len(input) > 0: + inputs += f"\nThis was the program's input :\n```\n{input}```\n" + return inputs + def _ip_history(self): if len(self._history) > 0: hist = textwrap.indent(self._capture_onecmd("hist"), "") - self._clear_history() hist = f"\nThis is the history of some pdb commands I ran and the results.\n```\n{hist}\n```\n" - prompt += hist - - if arg == "why": - arg = "Explain the root cause of the error." + return hist + else: + return "" - stack = self._stack_prompt() - prompt += stack + "\n" + arg + def concat_prompt(self, *args): + args = [a for a in args if len(a) > 0] + return "\n".join(args) - return prompt + def _build_prompt(self, arg, conversing): + if not conversing: + return self.concat_prompt( + self._ip_enchriched_stack_trace(), + self._ip_inputs(), + self._ip_error(), + self._ip_history(), + arg, + ) + else: + return self.concat_prompt(self._ip_history(), self._stack_prompt(), arg) def do_chat(self, arg): - """chat/: + """chat Send a chat message. """ - self.was_chat = True + self.was_chat_or_renew = True full_prompt = self._build_prompt(arg, self._assistant != None) + full_prompt = strip_color(full_prompt) + full_prompt = truncate_proportionally(full_prompt) + + self._clear_history() if self._assistant == None: self._make_assistant() - def client_print(line=""): - line = llm_utils.word_wrap_except_code_blocks(line, self._text_width - 10) - self._log.message(line) - line = textwrap.indent(line, self._chat_prefix, lambda _: True) - print(line, file=self.stdout, flush=True) + stats = self._assistant.query(full_prompt, user_text=arg) - full_prompt = strip_color(full_prompt) - full_prompt = truncate_proportionally(full_prompt) + self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") - self._log.push_chat(arg, full_prompt) - stats = self._assistant.run(full_prompt, client_print) - self._log.pop_chat(stats) - - def do_mark(self, arg): - marks = ["Full", "Partial", "Wrong", "None", "?"] - if arg == None or arg == "": - arg = input(f"mark? (one of {marks}): ") - while arg not in marks: - arg = input(f"mark? (one of {marks}): ") - if arg not in marks: - self.error( - f"answer must be in { ['Full', 'Partial', 'Wrong', '?', 'None'] }" - ) - else: - self._log.add_mark(arg) + def do_renew(self, arg): + """renew + End the current chat dialog and prepare to start a new one. + """ + if self._assistant != None: + self._assistant.close() + self._assistant = None + self.was_chat_or_renew = True + self.message(f"Ready to start a new dialog.") def do_config(self, arg): args = arg.split() @@ -607,116 +619,168 @@ def do_config(self, arg): except TraitError as e: self.error(f"{e}") - # Get the documentation and source code (if available) for any function or method visible in the current frame. The argument to info can be the name of the function or an expression of the form `obj.method_name` to see the information for the method_name method of object obj.", - def _make_assistant(self): - def info(value): - """ - { - "name": "info", - "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", - "parameters": { - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "The reference to get the information for." - } - }, - "required": [ "value" ] - } + instruction_prompt = self._ip_instructions() + + if chatdbg_config.take_the_wheel: + functions = [self.debug, self.info] + if self._supports_flow: + functions += [self.slice] + else: + functions = [] + + self._assistant = Assistant( + instruction_prompt, + model=chatdbg_config.model, + debug=chatdbg_config.debug, + functions=functions, + stream=chatdbg_config.stream, + listeners=[ + ChatAssistantClient( + self.stdout, + self.prompt, + self._chat_prefix, + self._text_width, + stream=chatdbg_config.stream, + ), + self._log, + ], + ) + + ### Callbacks for LLM + + def info(self, value): + """ + { + "name": "info", + "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The reference to get the information for." + } + }, + "required": [ "value" ] } - """ - command = f"info {value}" - result = self._capture_onecmd(command) - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) - result = strip_color(result) - self._log.function(command, result) - return truncate_proportionally(result, top_proportion=1) - - def debug(command): - """ - { - "name": "debug", - "description": "Run a pdb command and get the response.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The pdb command to run." - } - }, - "required": [ "command" ] - } + } + """ + command = f"info {value}" + result = self._capture_onecmd(command) + return command, truncate_proportionally(result, top_proportion=1) + + def debug(self, command): + """ + { + "name": "debug", + "description": "Run a pdb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The pdb command to run." + } + }, + "required": [ "command" ] } - """ - cmd = command if command != "list" else "ll" - result = self._capture_onecmd(cmd) + } + """ + cmd = command if command != "list" else "ll" + # old_curframe = self.curframe + result = self._capture_onecmd(cmd) - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) + # help the LLM know where it is... + # if old_curframe != self.curframe: + # result += strip_color(self._stack_prompt()) - result = strip_color(result) - self._log.function(command, result) - - # help the LLM know where it is... - result += strip_color(self._stack_prompt()) - return truncate_proportionally(result, maxlen=8000, top_proportion=0.9) - - def slice(name): - """ - { - "name": "slice", - "description": "Return the code to compute a global variable used in the current frame", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The variable to look at." - } - }, - "required": [ "name" ] - } + return command, truncate_proportionally(result, maxlen=8000, top_proportion=0.9) + + def slice(self, name): + """ + { + "name": "slice", + "description": "Return the code to compute a global variable used in the current frame", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The variable to look at." + } + }, + "required": [ "name" ] } + } + """ + command = f"slice {name}" + result = self._capture_onecmd(command) + return command, truncate_proportionally(result, top_proportion=0.5) - """ - command = f"slice {name}" - result = self._capture_onecmd(command) - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) - result = strip_color(result) - self._log.function(command, result) - return truncate_proportionally(result, top_proportion=0.5) + ############################################################### - self._clear_history() - instruction_prompt = pdb_instructions( - self._supports_flow, chatdbg_config.take_the_wheel + +class ChatAssistantClient(BaseAssistantListener): + def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): + self.out = out + self.debugger_prompt = debugger_prompt + self.chat_prefix = chat_prefix + self.width = width + self._assistant = None + self._stream = stream + + # Call backs + + def on_begin_query(self, prompt, user_text): + pass + + def on_end_query(self, stats): + pass + + def _print(self, text, **kwargs): + print( + textwrap.indent(text, self.chat_prefix, lambda _: True), + file=self.out, + **kwargs, ) - self._log.instructions(instruction_prompt) + def on_warn(self, text): + self._print(textwrap.indent(text, "*** ")) - if not chatdbg_config.model in _valid_models: + def on_fail(self, text): + self._print(textwrap.indent(text, "*** ")) + + def on_begin_stream(self): + self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) + self._at_start = True + + def on_stream_delta(self, text): + if self._at_start: + self._at_start = False print( - f"'{chatdbg_config.model}' is not a valid OpenAI model. Choose from: {_valid_models}." + self._stream_wrapper.append("\n(Message) ", False), + end="", + flush=True, + file=self.out, ) - sys.exit(0) - - self._assistant = Assistant( - "ChatDBG", - instruction_prompt, - model=chatdbg_config.model, - debug=chatdbg_config.debug, + print( + self._stream_wrapper.append(text, False), end="", flush=True, file=self.out ) - if chatdbg_config.take_the_wheel: - self._assistant.add_function(debug) - self._assistant.add_function(info) + def on_end_stream(self): + print(self._stream_wrapper.flush(), end="", flush=True, file=self.out) - if self._supports_flow: - self._assistant.add_function(slice) + def on_response(self, text): + if not self._stream and text != None: + text = word_wrap_except_code_blocks( + text, self.width - len(self.chat_prefix) + ) + self._print(text) + + def on_function_call(self, call, result): + if result and len(result) > 0: + entry = f"{self.debugger_prompt}{call}\n{result}" + else: + entry = f"{self.debugger_prompt}{call}" + self._print(entry) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py deleted file mode 100644 index 3761854..0000000 --- a/src/chatdbg/ipdb_util/chatlog.py +++ /dev/null @@ -1,125 +0,0 @@ -import sys -import uuid -from datetime import datetime -from io import StringIO - -import yaml - -from .config import Chat - - -class CopyingTextIOWrapper: - """ - File wrapper that will stash a copy of everything written. - """ - - def __init__(self, file): - self.file = file - self.buffer = StringIO() - - def write(self, data): - self.buffer.write(data) - return self.file.write(data) - - def getvalue(self): - return self.buffer.getvalue() - - def getfile(self): - return self.file - - def __getattr__(self, attr): - # Delegate attribute access to the file object - return getattr(self.file, attr) - - -class ChatDBGLog: - def __init__(self, config: Chat): - self.steps = [] - self.meta = { - "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "command_line": " ".join(sys.argv), - "uid": str(uuid.uuid4()), - "config": config.to_json(), - "mark": "?", - } - self.log = config.log - self._instructions = "" - self.stdout_wrapper = CopyingTextIOWrapper(sys.stdout) - self.stderr_wrapper = CopyingTextIOWrapper(sys.stderr) - sys.stdout = self.stdout_wrapper - sys.stderr = self.stdout_wrapper - self.chat_step = None - self.mark = "?" - - def add_mark(self, value): - if value not in ["Fix", "Partial", "None", "?"]: - print(f"answer must be in { ['Fix', 'Partial', 'None', '?'] }") - else: - self.meta["mark"] = value - - def total(self, key): - return sum( - [x["stats"][key] for x in self.steps if x["output"]["type"] == "chat"] - ) - - def dump(self): - self.meta["total_tokens"] = self.total("tokens") - self.meta["total_time"] = self.total("time") - self.meta["total_cost"] = self.total("cost") - - full_json = [ - { - "meta": self.meta, - "steps": self.steps, - "instructions": self._instructions, - "stdout": self.stdout_wrapper.getvalue(), - "stderr": self.stderr_wrapper.getvalue(), - } - ] - - print(f"*** Write ChatDBG log to {self.log}") - with open(self.log, "a") as file: - yaml.dump(full_json, file, default_flow_style=False) - - def instructions(self, instructions): - self._instructions = instructions - - def user_command(self, line, output): - if self.chat_step != None: - x = self.chat_step - self.chat_step = None - else: - x = {"input": line, "output": {"type": "text", "output": output}} - self.steps.append(x) - - def push_chat(self, line, full_prompt): - self.chat_step = { - "input": line, - "full_prompt": full_prompt, - "output": {"type": "chat", "outputs": []}, - "stats": {"tokens": 0, "cost": 0, "time": 0}, - } - - def pop_chat(self, stats): - self.chat_step["stats"] = stats - - def message(self, text): - self.chat_step["output"]["outputs"].append({"type": "text", "output": text}) - - def function(self, line, output): - x = { - "type": "call", - "input": line, - "output": {"type": "text", "output": output}, - } - self.chat_step["output"]["outputs"].append(x) - - -# Custom representer for literal scalar representation -def literal_presenter(dumper, data): - if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - -yaml.add_representer(str, literal_presenter) diff --git a/src/chatdbg/ipdb_util/capture.py b/src/chatdbg/util/capture.py similarity index 67% rename from src/chatdbg/ipdb_util/capture.py rename to src/chatdbg/util/capture.py index 21c75ba..c4fcf53 100644 --- a/src/chatdbg/ipdb_util/capture.py +++ b/src/chatdbg/util/capture.py @@ -30,3 +30,27 @@ def read(self, *args, **kwargs): def get_captured_input(self): return self.capture_buffer.getvalue() + + +class CaptureOutput: + """ + File wrapper that will stash a copy of everything written. + """ + + def __init__(self, file): + self.file = file + self.buffer = StringIO() + + def write(self, data): + self.buffer.write(data) + return self.file.write(data) + + def getvalue(self): + return self.buffer.getvalue() + + def getfile(self): + return self.file + + def __getattr__(self, attr): + # Delegate attribute access to the file object + return getattr(self.file, attr) diff --git a/src/chatdbg/util/chatlog.py b/src/chatdbg/util/chatlog.py new file mode 100644 index 0000000..7e65e79 --- /dev/null +++ b/src/chatdbg/util/chatlog.py @@ -0,0 +1,149 @@ +import sys +import uuid +from datetime import datetime + +import yaml + +from chatdbg.util.capture import CaptureOutput + +from ..assistant.listeners import BaseAssistantListener +from ..util.text import word_wrap_except_code_blocks + + +class ChatDBGLog(BaseAssistantListener): + + def __init__(self, log_filename, config, capture_streams=True): + self._log_filename = log_filename + self.config = config + if capture_streams: + self._stdout_wrapper = CaptureOutput(sys.stdout) + self._stderr_wrapper = CaptureOutput(sys.stderr) + sys.stdout = self._stdout_wrapper + sys.stderr = self._stdout_wrapper + else: + self._stderr_wrapper = None + self._stderr_wrapper = None + + self._log = self._make_log() + self._current_chat = None + + def _make_log(self): + meta = { + "time": datetime.now(), + "command_line": " ".join(sys.argv), + "uid": str(uuid.uuid4()), + "config": self.config, + } + return { + "steps": [], + "meta": meta, + "instructions": None, + "stdout": self._stdout_wrapper.getvalue(), + "stderr": self._stderr_wrapper.getvalue(), + } + + def _dump(self): + log = self._log + + def total(key): + return sum( + x["stats"][key] + for x in log["steps"] + if x["output"]["type"] == "chat" and "stats" in x["output"] + ) + + log["meta"]["total_tokens"] = total("tokens") + log["meta"]["total_time"] = total("time") + log["meta"]["total_cost"] = total("cost") + + print(f"*** Writing ChatDBG dialog log to {self._log_filename}") + + with open(self._log_filename, "a") as file: + + def literal_presenter(dumper, data): + if "\n" in data: + return dumper.represent_scalar( + "tag:yaml.org,2002:str", data, style="|" + ) + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + yaml.add_representer(str, literal_presenter) + yaml.dump([log], file, default_flow_style=False, indent=2) + + def on_begin_dialog(self, instructions): + log = self._log + assert log != None + log["instructions"] = instructions + + def on_end_dialog(self): + if self._log != None: + self._dump() + self._log = self._make_log() + + def on_begin_query(self, prompt, extra): + log = self._log + assert log != None + assert self._current_chat == None + self._current_chat = { + "input": extra, + "prompt": prompt, + "output": {"type": "chat", "outputs": []}, + } + + def on_end_query(self, stats): + log = self._log + assert log != None + assert self._current_chat != None + log["steps"] += [self._current_chat] + log["stats"] = stats + self._current_chat = None + + def _post(self, text, kind): + log = self._log + assert log != None + if self._current_chat != None: + self._current_chat["output"]["outputs"].append( + {"type": "text", "output": f"*** {kind}: {text}"} + ) + else: + log["steps"].append( + { + "type": "call", + "input": f"*** {kind}", + "output": {"type": "text", "output": text}, + } + ) + + def on_warn(self, text): + self._post(text, "Warning") + + def on_fail(self, text): + self._post(text, "Failure") + + def on_response(self, text): + log = self._log + assert log != None + assert self._current_chat != None + text = word_wrap_except_code_blocks(text) + self._current_chat["output"]["outputs"].append({"type": "text", "output": text}) + + def on_function_call(self, call, result): + log = self._log + assert log != None + if self._current_chat != None: + self._current_chat["output"]["outputs"].append( + { + "type": "call", + "input": call, + "output": {"type": "text", "output": result}, + } + ) + else: + log["steps"].append( + { + "type": "call", + "input": call, + "output": {"type": "text", "output": result}, + } + ) diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/util/config.py similarity index 50% rename from src/chatdbg/ipdb_util/config.py rename to src/chatdbg/util/config.py index 60cce10..df08178 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/util/config.py @@ -1,10 +1,10 @@ import os -from traitlets import Bool, Int, Unicode +from traitlets import Bool, Int, TraitError, Unicode from traitlets.config import Configurable -def chat_get_env(option_name, default_value): +def _chatdbg_get_env(option_name, default_value): env_name = "CHATDBG_" + option_name.upper() v = os.getenv(env_name, str(default_value)) if type(default_value) == int: @@ -15,35 +15,50 @@ def chat_get_env(option_name, default_value): return v -class Chat(Configurable): +class ChatDBGConfig(Configurable): model = Unicode( - chat_get_env("model", "gpt-4-1106-preview"), help="The OpenAI model" + _chatdbg_get_env("model", "gpt-4-1106-preview"), help="The LLM model" ).tag(config=True) - # model = Unicode(default_value='gpt-3.5-turbo-1106', help="The OpenAI model").tag(config=True) - debug = Bool(chat_get_env("debug", False), help="Log OpenAI calls").tag(config=True) - log = Unicode(chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) - tag = Unicode(chat_get_env("tag", ""), help="Any extra info for log file").tag( + + debug = Bool(_chatdbg_get_env("debug", False), help="Log LLM calls").tag( + config=True + ) + + log = Unicode(_chatdbg_get_env("log", "log.yaml"), help="The log file").tag( + config=True + ) + + tag = Unicode(_chatdbg_get_env("tag", ""), help="Any extra info for log file").tag( config=True ) rc_lines = Unicode( - chat_get_env("rc_lines", "[]"), help="lines to run at startup" + _chatdbg_get_env("rc_lines", "[]"), help="lines to run at startup" ).tag(config=True) context = Int( - chat_get_env("context", 5), + _chatdbg_get_env("context", 10), help="lines of source code to show when displaying stacktrace information", ).tag(config=True) + show_locals = Bool( - chat_get_env("show_locals", True), help="show local var values in stacktrace" + _chatdbg_get_env("show_locals", True), + help="show local var values in stacktrace", ).tag(config=True) + show_libs = Bool( - chat_get_env("show_libs", False), help="show library frames in stacktrace" + _chatdbg_get_env("show_libs", False), help="show library frames in stacktrace" ).tag(config=True) + show_slices = Bool( - chat_get_env("show_slices", True), help="support the `slice` command" + _chatdbg_get_env("show_slices", True), help="support the `slice` command" ).tag(config=True) + take_the_wheel = Bool( - chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" + _chatdbg_get_env("take_the_wheel", True), help="Let LLM take the wheel" + ).tag(config=True) + + stream = Bool( + _chatdbg_get_env("stream", False), help="Stream the response at it arrives" ).tag(config=True) def to_json(self): @@ -59,4 +74,8 @@ def to_json(self): "show_libs": self.show_libs, "show_slices": self.show_slices, "take_the_wheel": self.take_the_wheel, + "stream": self.stream, } + + +chatdbg_config: ChatDBGConfig = ChatDBGConfig() diff --git a/src/chatdbg/ipdb_util/locals.py b/src/chatdbg/util/locals.py similarity index 100% rename from src/chatdbg/ipdb_util/locals.py rename to src/chatdbg/util/locals.py diff --git a/src/chatdbg/ipdb_util/prompts.py b/src/chatdbg/util/prompts.py similarity index 99% rename from src/chatdbg/ipdb_util/prompts.py rename to src/chatdbg/util/prompts.py index 72ba2b3..f2c72bd 100644 --- a/src/chatdbg/ipdb_util/prompts.py +++ b/src/chatdbg/util/prompts.py @@ -1,3 +1,5 @@ +import os + _intro = f"""\ You are a debugging assistant. You will be given a Python stack trace for an error and answer questions related to the root cause of the error. @@ -12,6 +14,7 @@ contribute to the error. """ + _info_function = """\ Call the `info` function to get the documentation and source code for any variable, function, package, class, method reference, field reference, or diff --git a/src/chatdbg/util/streamwrap.py b/src/chatdbg/util/streamwrap.py new file mode 100644 index 0000000..0f95e08 --- /dev/null +++ b/src/chatdbg/util/streamwrap.py @@ -0,0 +1,48 @@ +import textwrap +import re +import sys +from .text import word_wrap_except_code_blocks + + +class StreamingTextWrapper: + + def __init__(self, indent=" ", width=80): + self._buffer = "" # the raw text so far + self._wrapped = "" # the successfully wrapped text do far + self._pending = ( + "" # the part after the last space in buffer -- has not been wrapped yet + ) + self._indent = indent + self._width = width - len(indent) + + def append(self, text, flush=False): + if flush: + self._buffer += self._pending + text + self._pending = "" + else: + text_bits = re.split(r"(\s+)", self._pending + text) + self._pending = text_bits[-1] + self._buffer += "".join(text_bits[0:-1]) + + wrapped = word_wrap_except_code_blocks(self._buffer) + wrapped = textwrap.indent(wrapped, self._indent, lambda _: True) + wrapped_delta = wrapped[len(self._wrapped) :] + self._wrapped = wrapped + return wrapped_delta + + def flush(self): + if len(self._buffer) > 0: + result = self.append("\n", flush=True) + else: + result = self.append("", flush=True) + self._buffer = "" + self._wrapped = "" + return result + + +if __name__ == "__main__": + s = StreamingTextWrapper(3, 20) + for x in sys.argv[1:]: + y = s.append(" " + x) + print(y, end="", flush=True) + print(s.flush()) diff --git a/src/chatdbg/ipdb_util/text.py b/src/chatdbg/util/text.py similarity index 82% rename from src/chatdbg/ipdb_util/text.py rename to src/chatdbg/util/text.py index 538293d..9d31e0d 100644 --- a/src/chatdbg/ipdb_util/text.py +++ b/src/chatdbg/util/text.py @@ -3,6 +3,7 @@ import inspect import numbers import numpy as np +import textwrap def make_arrow(pad): @@ -116,3 +117,28 @@ def truncate_proportionally(text, maxlen=32000, top_proportion=0.5): post = max(0, maxlen - 3 - pre) return text[:pre] + "..." + text[len(text) - post :] return text + + +def word_wrap_except_code_blocks(text: str, width: int = 80) -> str: + """ + Wraps text except for code blocks for nice terminal formatting. + + Splits the text into paragraphs and wraps each paragraph, + except for paragraphs that are inside of code blocks denoted + by ` ``` `. Returns the updated text. + + Args: + text (str): The text to wrap. + width (int): The width of the lines to wrap at, passed to `textwrap.fill`. + + Returns: + The wrapped text. + """ + blocks = text.split("```") + for i in range(len(blocks)): + if i % 2 == 0: + paras = blocks[i].split("\n") + wrapped = [textwrap.fill(para, width=width) for para in paras] + blocks[i] = "\n".join(wrapped) + + return "```".join(blocks)