From 5906b64cea248c3cd41d9beb70b976b4ca1be4ef Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Fri, 22 Mar 2024 15:19:15 -0400 Subject: [PATCH 01/17] stipud --- src/chatdbg/assistant/assistant-old.py | 243 ++++++++++++++++++++ src/chatdbg/assistant/assistant.py | 306 ++++++++++++------------- src/chatdbg/chatdbg_pdb.py | 50 ++-- src/chatdbg/ipdb_util/logging.py | 11 +- src/chatdbg/ipdb_util/printer.py | 53 +++++ src/chatdbg/ipdb_util/streamwrap.py | 43 ++++ 6 files changed, 517 insertions(+), 189 deletions(-) create mode 100644 src/chatdbg/assistant/assistant-old.py create mode 100644 src/chatdbg/ipdb_util/printer.py create mode 100644 src/chatdbg/ipdb_util/streamwrap.py diff --git a/src/chatdbg/assistant/assistant-old.py b/src/chatdbg/assistant/assistant-old.py new file mode 100644 index 0000000..4366aa9 --- /dev/null +++ b/src/chatdbg/assistant/assistant-old.py @@ -0,0 +1,243 @@ +import atexit +import inspect +import json +import time +import sys + +import llm_utils +from openai import * +from pydantic import BaseModel + + +class Assistant: + """ + An Assistant is a wrapper around OpenAI's assistant API. Example usage: + + assistant = Assistant("Assistant Name", instructions, + model='gpt-4-1106-preview', debug=True) + assistant.add_function(my_func) + response = assistant.run(user_prompt) + + Name can be any name you want. + + If debug is True, it will create a log of all messages and JSON responses in + json.txt. + """ + + def __init__( + self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, debug=True + ): + if debug: + self.json = open(f"json.txt", "a") + else: + self.json = None + + try: + self.client = OpenAI(timeout=timeout) + except OpenAIError: + print("*** You need an OpenAI key to use this tool.") + print("*** You can get a key here: https://platform.openai.com/api-keys") + print("*** Set the environment variable OPENAI_API_KEY to your key value.") + sys.exit(-1) + + self.assistants = self.client.beta.assistants + self.threads = self.client.beta.threads + self.functions = dict() + + self.assistant = self.assistants.create( + name=name, instructions=instructions, model=model + ) + self.thread = self.threads.create() + + atexit.register(self._delete_assistant) + + def _delete_assistant(self): + if self.assistant != None: + try: + id = self.assistant.id + response = self.assistants.delete(id) + assert response.deleted + except OSError: + raise + except Exception as e: + print(f"*** Assistant {id} was not deleted ({e}).") + print("*** You can do so at https://platform.openai.com/assistants.") + + def add_function(self, function): + """ + Add a new function to the list of function tools for the assistant. + The function should have the necessary json spec as its pydoc string. + """ + function_json = json.loads(function.__doc__) + try: + name = function_json["name"] + self.functions[name] = function + + tools = [ + {"type": "function", "function": json.loads(function.__doc__)} + for function in self.functions.values() + ] + + self.assistants.update(self.assistant.id, tools=tools) + except OpenAIError as e: + print(f"*** OpenAI Error: {e}") + sys.exit(-1) + + def _make_call(self, tool_call): + name = tool_call.function.name + args = tool_call.function.arguments + + # There is a sketchy case that happens occasionally because + # the API produces a bad call... 
+ try: + args = json.loads(args) + function = self.functions[name] + result = function(**args) + except OSError as e: + result = f"Error: {e}" + except Exception as e: + result = f"Ill-formed function call: {e}" + + return result + + def _print_messages(self, messages, client_print): + client_print() + for i, m in enumerate(messages): + message_text = m.content[0].text.value + if i == 0: + message_text = "(Message) " + message_text + client_print(message_text) + + def _wait_on_run(self, run, thread, client_print): + try: + while run.status == "queued" or run.status == "in_progress": + run = self.threads.runs.retrieve( + thread_id=thread.id, + run_id=run.id, + ) + time.sleep(0.5) + return run + finally: + if run.status == "in_progress": + client_print("Cancelling message that's in progress.") + self.threads.runs.cancel(thread_id=thread.id, run_id=run.id) + + def run(self, prompt, client_print=print): + """ + Give the prompt to the assistant and get the response, which may included + intermediate function calls. + All output is printed to the given file. + """ + start_time = time.perf_counter() + + try: + if self.assistant == None: + return { + "tokens": run.usage.total_tokens, + "prompt_tokens": run.usage.prompt_tokens, + "completion_tokens": run.usage.completion_tokens, + "model": self.assistant.model, + "cost": cost, + } + + assert len(prompt) <= 32768 + + message = self.threads.messages.create( + thread_id=self.thread.id, role="user", content=prompt + ) + self._log(message) + + last_printed_message_id = message.id + + run = self.threads.runs.create( + thread_id=self.thread.id, assistant_id=self.assistant.id + ) + self._log(run) + + run = self._wait_on_run(run, self.thread, client_print) + self._log(run) + + while run.status == "requires_action": + messages = self.threads.messages.list( + thread_id=self.thread.id, after=last_printed_message_id, order="asc" + ) + + mlist = list(messages) + if len(mlist) > 0: + self._print_messages(mlist, client_print) + last_printed_message_id = mlist[-1].id + client_print() + + outputs = [] + for tool_call in run.required_action.submit_tool_outputs.tool_calls: + output = self._make_call(tool_call) + self._log(output) + outputs += [{"tool_call_id": tool_call.id, "output": output}] + + try: + run = self.threads.runs.submit_tool_outputs( + thread_id=self.thread.id, run_id=run.id, tool_outputs=outputs + ) + self._log(run) + except OSError as e: + self._log(run, f"FAILED to submit tool call results: {e}") + raise + except Exception as e: + self._log(run, f"FAILED to submit tool call results: {e}") + + run = self._wait_on_run(run, self.thread, client_print) + self._log(run) + + if run.status == "failed": + message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}" + client_print(message) + self._log(run) + sys.exit(-1) + + messages = self.threads.messages.list( + thread_id=self.thread.id, after=last_printed_message_id, order="asc" + ) + self._print_messages(messages, client_print) + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + cost = llm_utils.calculate_cost( + run.usage.prompt_tokens, + run.usage.completion_tokens, + self.assistant.model, + ) + client_print() + client_print(f"[Cost: ~${cost:.2f} USD]") + + return { + "tokens": run.usage.total_tokens, + "prompt_tokens": run.usage.prompt_tokens, + "completion_tokens": run.usage.completion_tokens, + "model": self.assistant.model, + "cost": cost, + "time": elapsed_time, + "thread.id": self.thread.id, + "run.id": run.id, + "assistant.id": self.assistant.id, 
+ } + except OpenAIError as e: + client_print(f"*** OpenAI Error: {e}") + sys.exit(-1) + + def _log(self, obj, title=""): + if self.json != None: + stack = inspect.stack() + caller_frame_record = stack[1] + lineno, function = caller_frame_record[2:4] + loc = f"{function}:{lineno}" + + print("-" * 70, file=self.json) + print(f"{loc} {title}", file=self.json) + if isinstance(obj, BaseModel): + json_obj = json.loads(obj.model_dump_json()) + else: + json_obj = obj + print(f"\n{json.dumps(json_obj, indent=2)}\n", file=self.json) + self.json.flush() + return obj diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 3bfe65d..205306e 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -1,7 +1,6 @@ import atexit -import inspect -import json import textwrap +import json import time import sys @@ -9,6 +8,25 @@ from openai import * from pydantic import BaseModel +class AssistantPrinter: + def text_delta(self, text=''): + print(text, flush=True, end='') + + def text_message(self, text=''): + print(text, flush=True) + + def log(self, json_obj): + pass + + def fail(self, message='Failed'): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) + sys.exit(1) + + def warn(self, message='Warning'): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) + class Assistant: """ @@ -25,28 +43,17 @@ class Assistant: json.txt. """ - # TODO: At some point, if we unify the argument parsing, we should just have this take args. def __init__( - self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, debug=True - ): - if debug: - self.json = open(f"json.txt", "a") - else: - self.json = None - + self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, + printer = AssistantPrinter()): + self.printer = printer try: self.client = OpenAI(timeout=timeout) except OpenAIError: - print( - textwrap.dedent( - """\ - You need an OpenAI key to use this tool. - You can get a key here: https://platform.openai.com/api-keys - Set the environment variable OPENAI_API_KEY to your key value. - """ - ) - ) - sys.exit(-1) + self.printer.fail("""\ + You need an OpenAI key to use this tool. + You can get a key here: https://platform.openai.com/api-keys. + Set the environment variable OPENAI_API_KEY to your key value.""") self.assistants = self.client.beta.assistants self.threads = self.client.beta.threads @@ -55,26 +62,20 @@ def __init__( self.assistant = self.assistants.create( name=name, instructions=instructions, model=model ) - - self._log(self.assistant) + self.thread = self.threads.create() atexit.register(self._delete_assistant) - self.thread = self.threads.create() - self._log(self.thread) - def _delete_assistant(self): if self.assistant != None: try: id = self.assistant.id response = self.assistants.delete(id) - self._log(response) assert response.deleted except OSError: raise except Exception as e: - print(f"Assistant {id} was not deleted ({e}).") - print("You can do so at https://platform.openai.com/assistants.") + self.printer.warn(f"Assistant {id} was not deleted ({e}). You can do so at https://platform.openai.com/assistants.") def add_function(self, function): """ @@ -82,7 +83,6 @@ def add_function(self, function): The function should have the necessary json spec as its pydoc string. """ function_json = json.loads(function.__doc__) - assert "name" in function_json, "Bad JSON in pydoc for function tool." 
try: name = function_json["name"] self.functions[name] = function @@ -92,11 +92,9 @@ def add_function(self, function): for function in self.functions.values() ] - assistant = self.assistants.update(self.assistant.id, tools=tools) - self._log(assistant) + self.assistants.update(self.assistant.id, tools=tools) except OpenAIError as e: - print(f"*** OpenAI Error: {e}") - sys.exit(-1) + self.printer.fail(f"*** OpenAI Error: {e}") def _make_call(self, tool_call): name = tool_call.function.name @@ -110,150 +108,128 @@ def _make_call(self, tool_call): result = function(**args) except OSError as e: result = f"Error: {e}" - # raise except Exception as e: - result = f"Ill-formed function call ({e})\n" + result = f"Ill-formed function call: {e}" return result - def _print_messages(self, messages, client_print): - client_print() - for i, m in enumerate(messages): - message_text = m.content[0].text.value - if i == 0: - message_text = "(Message) " + message_text - client_print(message_text) - - def _wait_on_run(self, run, thread, client_print): - try: - while run.status == "queued" or run.status == "in_progress": - run = self.threads.runs.retrieve( - thread_id=thread.id, - run_id=run.id, - ) - time.sleep(0.5) - return run - finally: - if run.status == "in_progress": - client_print("Cancelling message that's in progress.") - self.threads.runs.cancel(thread_id=thread.id, run_id=run.id) - - def run(self, prompt, client_print=print): + def drain_stream(self, stream): + run = None + for event in stream: + self.printer.log(event) + if event.event == 'thread.run.completed': + run = event.data + if event.event == 'thread.message.delta': + self.printer.text_delta(event.data.delta.content[0].text.value) + if event.event == 'thread.message.completed': + self.printer.text_message(event.data.content[0].text.value) + elif event.event == 'thread.run.requires_action': + r = event.data + if r.status == "requires_action": + outputs = [] + for tool_call in r.required_action.submit_tool_outputs.tool_calls: + output = self._make_call(tool_call) + outputs += [{"tool_call_id": tool_call.id, "output": output}] + + try: + new_stream = self.threads.runs.submit_tool_outputs( + thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True + ) + return self.drain_stream(new_stream) + except OSError as e: + raise + except Exception as e: + # silent failure because the tool call submit biffed. Not muchw e can do + pass + elif event.event == 'thread.run.failed': + run = event.data + self.printer.fail(f"*** Internal Failure ({run.last_error.code}): {run.last_error.message}") + elif event.event == 'error': + self.printer.fail(f"*** Internal Failure:** {event.data}") + return run + + def run(self, prompt): """ Give the prompt to the assistant and get the response, which may included intermediate function calls. All output is printed to the given file. 
""" - start_time = time.perf_counter() - - try: - if self.assistant == None: - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, - "cost": cost, - } - - assert len(prompt) <= 32768 - - message = self.threads.messages.create( - thread_id=self.thread.id, role="user", content=prompt - ) - self._log(message) - - last_printed_message_id = message.id - - run = self.threads.runs.create( - thread_id=self.thread.id, assistant_id=self.assistant.id - ) - self._log(run) - - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) - - while run.status == "requires_action": - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" - ) - - mlist = list(messages) - if len(mlist) > 0: - self._print_messages(mlist, client_print) - last_printed_message_id = mlist[-1].id - client_print() - - outputs = [] - for tool_call in run.required_action.submit_tool_outputs.tool_calls: - output = self._make_call(tool_call) - self._log(output) - outputs += [{"tool_call_id": tool_call.id, "output": output}] - - try: - run = self.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, run_id=run.id, tool_outputs=outputs - ) - self._log(run) - except OSError as e: - self._log(run, f"FAILED to submit tool call results: {e}") - raise - except Exception as e: - self._log(run, f"FAILED to submit tool call results: {e}") - - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) - - if run.status == "failed": - message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}" - client_print(message) - self._log(run) - sys.exit(-1) - - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" - ) - self._print_messages(messages, client_print) - - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - cost = llm_utils.calculate_cost( - run.usage.prompt_tokens, - run.usage.completion_tokens, - self.assistant.model, - ) - client_print() - client_print(f"[Cost: ~${cost:.2f} USD]") + if self.assistant == None: return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, + "tokens": 0, + "prompt_tokens": 0, + "completion_tokens": 0, "model": self.assistant.model, - "cost": cost, - "time": elapsed_time, - "thread.id": self.thread.id, - "run.id": run.id, - "assistant.id": self.assistant.id, + "cost": 0, } - except OpenAIError as e: - client_print(f"*** OpenAI Error: {e}") - sys.exit(-1) - def _log(self, obj, title=""): - if self.json != None: - stack = inspect.stack() - caller_frame_record = stack[1] - lineno, function = caller_frame_record[2:4] - loc = f"{function}:{lineno}" + start_time = time.perf_counter() + + assert len(prompt) <= 32768 + self.threads.messages.create( + thread_id=self.thread.id, role="user", content=prompt + ) + with self.threads.runs.create( + thread_id=self.thread.id, + assistant_id=self.assistant.id, + stream=True + ) as stream: + run = self.drain_stream(stream) + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + cost = llm_utils.calculate_cost( + run.usage.prompt_tokens, + run.usage.completion_tokens, + self.assistant.model, + ) + return { + "tokens": run.usage.total_tokens, + "prompt_tokens": run.usage.prompt_tokens, + "completion_tokens": run.usage.completion_tokens, + "model": self.assistant.model, + 
"cost": cost, + "time": elapsed_time, + "thread.id": self.thread.id, + "thread": self.thread, + "run.id": run.id, + "run": run, + "assistant.id": self.assistant.id, + } + + +if __name__ == '__main__': + def weather(location): + """ + { + "name": "get_weather", + "description": "Determine weather in my location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + } + """ + return "Sunny and 72 degrees." - print("-" * 70, file=self.json) - print(f"{loc} {title}", file=self.json) - if isinstance(obj, BaseModel): - json_obj = json.loads(obj.model_dump_json()) - else: - json_obj = obj - print(f"\n{json.dumps(json_obj, indent=2)}\n", file=self.json) - self.json.flush() - return obj + a = Assistant("Test", "You generate text.") + a.add_function(weather) + x = a.run("What's the weather in Boston?") + print(x) \ No newline at end of file diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index cae7bfe..6682eea 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -229,13 +229,12 @@ def message(self, msg) -> None: Override to remove tabs for messages so we can indent them. """ return super().message(str(msg).expandtabs()) - + def error(self, msg) -> None: """ Override to remove tabs for messages so we can indent them. """ return super().error(str(msg).expandtabs()) - # return super().error('If the name is undefined, be sure you are in the right frame. Use up and down to do that, and then print the variable again'.expandtabs()) def _capture_onecmd(self, line): """ @@ -563,31 +562,15 @@ def do_chat(self, arg): if self._assistant == None: self._make_assistant() - def client_print(line=""): - line = llm_utils.word_wrap_except_code_blocks(line, self._text_width - 10) - self._log.message(line) - line = textwrap.indent(line, self._chat_prefix, lambda _: True) - print(line, file=self.stdout, flush=True) full_prompt = strip_color(full_prompt) full_prompt = truncate_proportionally(full_prompt) self._log.push_chat(arg, full_prompt) - stats = self._assistant.run(full_prompt, client_print) + stats = self._assistant.run(full_prompt) self._log.pop_chat(stats) - def do_mark(self, arg): - marks = ["Full", "Partial", "Wrong", "None", "?"] - if arg == None or arg == "": - arg = input(f"mark? (one of {marks}): ") - while arg not in marks: - arg = input(f"mark? 
(one of {marks}): ") - if arg not in marks: - self.error( - f"answer must be in { ['Full', 'Partial', 'Wrong', '?', 'None'] }" - ) - else: - self._log.add_mark(arg) + self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") def do_config(self, arg): args = arg.split() @@ -711,7 +694,7 @@ def slice(name): "ChatDBG", instruction_prompt, model=chatdbg_config.model, - debug=chatdbg_config.debug, + printer=self ) if chatdbg_config.take_the_wheel: @@ -720,3 +703,28 @@ def slice(name): if self._supports_flow: self._assistant.add_function(slice) + + + ############################################################### + + def text_delta(self, text=''): + print(text, file=self.stdout, flush=True, end='') + + def text_message(self, text=''): + line = llm_utils.word_wrap_except_code_blocks(text, self._text_width - 10) + self._log.message(line) + line = textwrap.indent(line, self._chat_prefix, lambda _: True) + print(line, file=self.stdout, flush=True) + + def log(self, json_obj): + if chatdbg_config.debug: + self._log.log(json_obj) + + def fail(self, message='Failed'): + self.print(file=self.stdout) + self.print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) + sys.exit(1) + + def warn(self, message='Warning'): + self.print(file=self.stdout) + self.print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) diff --git a/src/chatdbg/ipdb_util/logging.py b/src/chatdbg/ipdb_util/logging.py index 80626aa..28ab1e6 100644 --- a/src/chatdbg/ipdb_util/logging.py +++ b/src/chatdbg/ipdb_util/logging.py @@ -40,13 +40,14 @@ def __init__(self, config: Chat): "config": config.to_json(), "mark": "?", } - self.log = config.log + self.log_file = config.log self._instructions = "" self.stdout_wrapper = CopyingTextIOWrapper(sys.stdout) self.stderr_wrapper = CopyingTextIOWrapper(sys.stderr) sys.stdout = self.stdout_wrapper sys.stderr = self.stdout_wrapper self.chat_step = None + self.events = [ ] self.mark = "?" 
def add_mark(self, value): @@ -72,13 +73,17 @@ def dump(self): "instructions": self._instructions, "stdout": self.stdout_wrapper.getvalue(), "stderr": self.stderr_wrapper.getvalue(), + "events" : self.events } ] - print(f"*** Write ChatDBG log to {self.log}") - with open(self.log, "a") as file: + print(f"*** Write ChatDBG log to {self.log_file}") + with open(self.log_file, "a") as file: yaml.dump(full_json, file, default_flow_style=False) + def log(self, event_json): + self.events += [ event_json ] + def instructions(self, instructions): self._instructions = instructions diff --git a/src/chatdbg/ipdb_util/printer.py b/src/chatdbg/ipdb_util/printer.py new file mode 100644 index 0000000..39afd3a --- /dev/null +++ b/src/chatdbg/ipdb_util/printer.py @@ -0,0 +1,53 @@ +import textwrap +from ..assistant.assistant import AssistantPrinter +from ..chatdbg_pdb import ChatDBG +import sys + +class Printer(AssistantPrinter): + + def __init__(self, message, error, log): + self._message = message + self._error = error + self._log = log + + def text_delta(self, text): + print(text, flush=True, end=None) + + def text_message(self, text): + print(text, flush=True) + + def log(self, json_obj): + pass + + def fail(self, message): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) + sys.exit(1) + + def warn(self, message): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) + +class StreamingPrinter(AssistantPrinter): + + def __init__(self, message, error): + self.message = message + self.error = error + + def text_delta(self, text): + print(text, flush=True, end=None) + + def text_message(self, text): + print(text, flush=True) + + def log(self, json_obj): + pass + + def fail(self, message): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) + sys.exit(1) + + def warn(self, message): + print() + print(textwrap.wrap(message, width=70, initial_indent='*** ')) diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py new file mode 100644 index 0000000..52816f5 --- /dev/null +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -0,0 +1,43 @@ +import textwrap +import re +import sys +import llm_utils + + +class StreamTextWrapper: + + def __init__(self, indent=3, width=70): + self.buffer = '' + self.wrapped = '' + self.pending = '' + self.indent = indent + self.width = width + + def add(self, text, flush=False): + unwrapped = self.buffer + if flush: + unwrapped += self.pending + text + self.pending = '' + else: + text_bits = re.split(r'(\s+)', self.pending + text) + self.pending = text_bits[-1] + unwrapped += (''.join(text_bits[0:-1])) + + # print('---', unwrapped, '---', self.pending) + wrapped = word_wrap_except_code_blocks(unwrapped, self.width) + wrapped = textwrap.indent(wrapped, ' ' * self.indent, lambda _: True) + printable_part = wrapped[len(self.wrapped):] + self.wrapped = wrapped + return printable_part + + def flush(self): + return self.add('', flush=True) + + + +if __name__ == '__main__': + s = StreamTextWrapper(3,20) + for x in sys.argv[1:]: + y = s.add(' ' + x) + print(y, end='', flush=True) + print(s.flush()) From 66bd66f5310af679aa0b2889255a865d9af28877 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Sat, 23 Mar 2024 12:46:23 -0400 Subject: [PATCH 02/17] streaming text --- src/chatdbg/assistant/assistant.py | 28 ++++++++++++----- src/chatdbg/chatdbg_pdb.py | 30 ++++++++++++------- .../ipdb_util/{logging.py => chatlog.py} | 0 src/chatdbg/ipdb_util/printer.py | 8 ++--- src/chatdbg/ipdb_util/streamwrap.py | 15 
+++++----- 5 files changed, 51 insertions(+), 30 deletions(-) rename src/chatdbg/ipdb_util/{logging.py => chatlog.py} (100%) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 205306e..d310104 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -9,10 +9,16 @@ from pydantic import BaseModel class AssistantPrinter: - def text_delta(self, text=''): + def begin_stream(self): + print('\n', flush=True) + + def stream(self, text=''): print(text, flush=True, end='') - def text_message(self, text=''): + def end_stream(self): + print('\n', flush=True) + + def complete_message(self, text=''): print(text, flush=True) def log(self, json_obj): @@ -94,7 +100,7 @@ def add_function(self, function): self.assistants.update(self.assistant.id, tools=tools) except OpenAIError as e: - self.printer.fail(f"*** OpenAI Error: {e}") + self.printer.fail(f"OpenAI Error: {e}") def _make_call(self, tool_call): name = tool_call.function.name @@ -110,19 +116,23 @@ def _make_call(self, tool_call): result = f"Error: {e}" except Exception as e: result = f"Ill-formed function call: {e}" - return result def drain_stream(self, stream): run = None for event in stream: + # print(event) self.printer.log(event) + if event.event == 'thread.run.created': + #self.printer.begin_stream() + pass if event.event == 'thread.run.completed': + self.printer.end_stream() run = event.data if event.event == 'thread.message.delta': - self.printer.text_delta(event.data.delta.content[0].text.value) + self.printer.stream(event.data.delta.content[0].text.value) if event.event == 'thread.message.completed': - self.printer.text_message(event.data.content[0].text.value) + self.printer.complete_message(event.data.content[0].text.value) elif event.event == 'thread.run.requires_action': r = event.data if r.status == "requires_action": @@ -135,6 +145,8 @@ def drain_stream(self, stream): new_stream = self.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True ) + self.printer.end_stream() + # self.printer.begin_stream() return self.drain_stream(new_stream) except OSError as e: raise @@ -143,9 +155,9 @@ def drain_stream(self, stream): pass elif event.event == 'thread.run.failed': run = event.data - self.printer.fail(f"*** Internal Failure ({run.last_error.code}): {run.last_error.message}") + self.printer.fail(f"Internal Failure ({run.last_error.code}): {run.last_error.message}") elif event.event == 'error': - self.printer.fail(f"*** Internal Failure:** {event.data}") + self.printer.fail(f"Internal Failure:** {event.data}") return run def run(self, prompt): diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 6682eea..679e9e3 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -20,10 +20,11 @@ from .assistant.assistant import Assistant from .ipdb_util.config import Chat -from .ipdb_util.logging import ChatDBGLog, CopyingTextIOWrapper +from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper from .ipdb_util.prompts import pdb_instructions from .ipdb_util.text import * from .ipdb_util.locals import * +from .ipdb_util.streamwrap import StreamTextWrapper _valid_models = [ "gpt-4-turbo-preview", @@ -92,6 +93,7 @@ def __init__(self, *args, **kwargs): chatdbg_config = Chat() sys.stdin = CaptureInput(sys.stdin) + self.wrapper = StreamTextWrapper(indent=self._chat_prefix, width=self._text_width) # Only use flow when we are in jupyter or using stdin in ipython. 
In both # cases, there will be no python file at the start of argv after the @@ -707,24 +709,32 @@ def slice(name): ############################################################### - def text_delta(self, text=''): - print(text, file=self.stdout, flush=True, end='') + def begin_stream(self): + print(file=self.stdout) - def text_message(self, text=''): + def stream(self, text=''): + print(self.wrapper.add(text), file=self.stdout, flush=True, end='') + + def end_stream(self): + print(file=self.stdout) + + def complete_message(self, text=''): + print(self.wrapper.add('', flush=True), file=self.stdout, flush=True, end='') line = llm_utils.word_wrap_except_code_blocks(text, self._text_width - 10) self._log.message(line) - line = textwrap.indent(line, self._chat_prefix, lambda _: True) - print(line, file=self.stdout, flush=True) + # line = textwrap.indent(line, self._chat_prefix, lambda _: True) + # print(line, file=self.stdout, flush=True) def log(self, json_obj): if chatdbg_config.debug: self._log.log(json_obj) def fail(self, message='Failed'): - self.print(file=self.stdout) - self.print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) + print(file=self.stdout) + print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) + raise Exception() sys.exit(1) def warn(self, message='Warning'): - self.print(file=self.stdout) - self.print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) + print(file=self.stdout) + print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) diff --git a/src/chatdbg/ipdb_util/logging.py b/src/chatdbg/ipdb_util/chatlog.py similarity index 100% rename from src/chatdbg/ipdb_util/logging.py rename to src/chatdbg/ipdb_util/chatlog.py diff --git a/src/chatdbg/ipdb_util/printer.py b/src/chatdbg/ipdb_util/printer.py index 39afd3a..f0d2f95 100644 --- a/src/chatdbg/ipdb_util/printer.py +++ b/src/chatdbg/ipdb_util/printer.py @@ -10,10 +10,10 @@ def __init__(self, message, error, log): self._error = error self._log = log - def text_delta(self, text): + def stream(self, text): print(text, flush=True, end=None) - def text_message(self, text): + def message(self, text): print(text, flush=True) def log(self, json_obj): @@ -34,10 +34,10 @@ def __init__(self, message, error): self.message = message self.error = error - def text_delta(self, text): + def stream(self, text): print(text, flush=True, end=None) - def text_message(self, text): + def message(self, text): print(text, flush=True) def log(self, json_obj): diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index 52816f5..69a0f42 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -1,12 +1,12 @@ import textwrap import re import sys -import llm_utils +from llm_utils import word_wrap_except_code_blocks class StreamTextWrapper: - def __init__(self, indent=3, width=70): + def __init__(self, indent=' ', width=70): self.buffer = '' self.wrapped = '' self.pending = '' @@ -14,18 +14,17 @@ def __init__(self, indent=3, width=70): self.width = width def add(self, text, flush=False): - unwrapped = self.buffer if flush: - unwrapped += self.pending + text + self.buffer += self.pending + text self.pending = '' else: text_bits = re.split(r'(\s+)', self.pending + text) self.pending = text_bits[-1] - unwrapped += (''.join(text_bits[0:-1])) + self.buffer += (''.join(text_bits[0:-1])) - # print('---', unwrapped, '---', self.pending) - wrapped = word_wrap_except_code_blocks(unwrapped, 
self.width) - wrapped = textwrap.indent(wrapped, ' ' * self.indent, lambda _: True) + # print('---', self.buffer, '---', self.pending) + wrapped = word_wrap_except_code_blocks(self.buffer, self.width) + wrapped = textwrap.indent(wrapped, self.indent, lambda _: True) printable_part = wrapped[len(self.wrapped):] self.wrapped = wrapped return printable_part From badb6d84e99a9f541392fb401a3b1a7830033b01 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Sat, 23 Mar 2024 15:01:09 -0400 Subject: [PATCH 03/17] stream --- src/chatdbg/assistant/assistant.py | 5 ++--- src/chatdbg/chatdbg_pdb.py | 9 +++++++-- src/chatdbg/ipdb_util/streamwrap.py | 8 +++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index d310104..342e751 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -137,16 +137,15 @@ def drain_stream(self, stream): r = event.data if r.status == "requires_action": outputs = [] + self.printer.end_stream() for tool_call in r.required_action.submit_tool_outputs.tool_calls: output = self._make_call(tool_call) outputs += [{"tool_call_id": tool_call.id, "output": output}] - + self.printer.begin_stream() try: new_stream = self.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True ) - self.printer.end_stream() - # self.printer.begin_stream() return self.drain_stream(new_stream) except OSError as e: raise diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 679e9e3..42b6cf1 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -253,8 +253,12 @@ def _capture_onecmd(self, line): self.stdout = stdout self.lastcmd = lastcmd - def _format_history_entry(self, entry, indent=""): + def _format_history_entry(self, entry, indent="", prompt_color=None): line, output = entry + if prompt_color == None: + prompt = self.prompt + else: + prompt = t if output: entry = f"{self.prompt}{line}\n{output}" else: @@ -716,7 +720,8 @@ def stream(self, text=''): print(self.wrapper.add(text), file=self.stdout, flush=True, end='') def end_stream(self): - print(file=self.stdout) + result = self.wrapper.flush() + print(result, file=self.stdout) def complete_message(self, text=''): print(self.wrapper.add('', flush=True), file=self.stdout, flush=True, end='') diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index 69a0f42..f8faf2a 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -30,7 +30,13 @@ def add(self, text, flush=False): return printable_part def flush(self): - return self.add('', flush=True) + # if self.pending == '': + # return None + # else: + result = self.add('', flush=True) + self.buffer = '' + self.wrapped = '' + return result From a38f22164fe8a367af6611b786dcb31cddab3aa7 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Sun, 24 Mar 2024 10:35:31 -0400 Subject: [PATCH 04/17] moo --- src/chatdbg/assistant/assistant.py | 629 +++++++++++++++++++---------- src/chatdbg/chatdbg_pdb.py | 47 ++- src/chatdbg/ipdb_util/config.py | 4 + 3 files changed, 453 insertions(+), 227 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 342e751..97acc48 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -1,246 +1,457 @@ -import atexit -import textwrap import json import time -import sys +from typing import Callable -import llm_utils -from openai 
import * -from pydantic import BaseModel +import litellm +import openai -class AssistantPrinter: - def begin_stream(self): - print('\n', flush=True) +def sandwhich_tokens(model, text: str, max_tokens: int, top_proportion: float) -> str: + if max_tokens == None: + return text + tokens = litellm.encode(model, text) + if len(tokens) <= max_tokens: + return text + else: + total_len = max_tokens - 5 # some slop for the ... + top_len = int(top_proportion * total_len) + bot_len = int((1-top_proportion) * total_len) + return litellm.decode(model, tokens[0:top_len]) + " [...] " + litellm.decode(model, tokens[-bot_len:]) - def stream(self, text=''): - print(text, flush=True, end='') - def end_stream(self): - print('\n', flush=True) - - def complete_message(self, text=''): - print(text, flush=True) - - def log(self, json_obj): +class AssistantClient: + def warn(text): + pass + def fail(text): + pass + def response(text): + pass + def function_call(name, args, result): pass - def fail(self, message='Failed'): - print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) - sys.exit(1) - - def warn(self, message='Warning'): - print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) - - -class Assistant: - """ - An Assistant is a wrapper around OpenAI's assistant API. Example usage: - - assistant = Assistant("Assistant Name", instructions, - model='gpt-4-1106-preview', debug=True) - assistant.add_function(my_func) - response = assistant.run(user_prompt) - - Name can be any name you want. - If debug is True, it will create a log of all messages and JSON responses in - json.txt. - """ +class Assistant: def __init__( - self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, - printer = AssistantPrinter()): - self.printer = printer - try: - self.client = OpenAI(timeout=timeout) - except OpenAIError: - self.printer.fail("""\ - You need an OpenAI key to use this tool. - You can get a key here: https://platform.openai.com/api-keys. - Set the environment variable OPENAI_API_KEY to your key value.""") - - self.assistants = self.client.beta.assistants - self.threads = self.client.beta.threads - self.functions = dict() - - self.assistant = self.assistants.create( - name=name, instructions=instructions, model=model - ) - self.thread = self.threads.create() - - atexit.register(self._delete_assistant) - - def _delete_assistant(self): - if self.assistant != None: - try: - id = self.assistant.id - response = self.assistants.delete(id) - assert response.deleted - except OSError: - raise - except Exception as e: - self.printer.warn(f"Assistant {id} was not deleted ({e}). You can do so at https://platform.openai.com/assistants.") + self, + instructions, + model="gpt-4", + timeout=30, + max_call_response_tokens=None, + debug=False, + ): + if debug: + log_file = open(f"chatdbg.log", "w") + self._logger = lambda model_call_dict: print(model_call_dict, file=log_file, flush=True) + else: + self._logger = None + + self._functions = {} + self._model = model + self._timeout = timeout + self._conversation = [{"role": "system", "content": instructions}] + self._max_call_response_tokens = max_call_response_tokens def add_function(self, function): """ - Add a new function to the list of function tools for the assistant. - The function should have the necessary json spec as its pydoc string. + Add a new function to the list of function tools. + The function should have the necessary json spec as its docstring. 
""" - function_json = json.loads(function.__doc__) - try: - name = function_json["name"] - self.functions[name] = function - - tools = [ - {"type": "function", "function": json.loads(function.__doc__)} - for function in self.functions.values() - ] - - self.assistants.update(self.assistant.id, tools=tools) - except OpenAIError as e: - self.printer.fail(f"OpenAI Error: {e}") + schema = json.loads(function.__doc__) + assert "name" in schema, "Bad JSON in docstring for function tool." + self._functions[schema["name"]] = { + "function": function, + "schema": schema, + } - def _make_call(self, tool_call): + def _make_call(self, tool_call) -> str: name = tool_call.function.name - args = tool_call.function.arguments - - # There is a sketchy case that happens occasionally because - # the API produces a bad call... + args = json.loads(tool_call.function.arguments) try: args = json.loads(args) function = self.functions[name] result = function(**args) except OSError as e: + # function produced some error -- move this to client??? result = f"Error: {e}" except Exception as e: result = f"Ill-formed function call: {e}" return result - def drain_stream(self, stream): - run = None - for event in stream: - # print(event) - self.printer.log(event) - if event.event == 'thread.run.created': - #self.printer.begin_stream() - pass - if event.event == 'thread.run.completed': - self.printer.end_stream() - run = event.data - if event.event == 'thread.message.delta': - self.printer.stream(event.data.delta.content[0].text.value) - if event.event == 'thread.message.completed': - self.printer.complete_message(event.data.content[0].text.value) - elif event.event == 'thread.run.requires_action': - r = event.data - if r.status == "requires_action": - outputs = [] - self.printer.end_stream() - for tool_call in r.required_action.submit_tool_outputs.tool_calls: - output = self._make_call(tool_call) - outputs += [{"tool_call_id": tool_call.id, "output": output}] - self.printer.begin_stream() + + def query( + self, + prompt: str, + client: XXX + ) -> None: + start = time.time() + cost = 0 + + try: + self._conversation.append({"role": "user", "content": prompt}) + + while True: + self._conversation = litellm.trim_messages(self._conversation, self._model) + completion = litellm.completion( + model=self._model, + messages=self._conversation, + tools=[ + {"type": "function", "function": f["schema"]} + for f in self._functions.values() + ], + timeout=self._timeout, + logger_fn=self._logger + ) + + cost += litellm.completion_cost(completion) + + choice = completion.choices[0] + + if choice.finish_reason == "tool_calls": + responses = [] try: - new_stream = self.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True - ) - return self.drain_stream(new_stream) - except OSError as e: - raise + for tool_call in choice.message.tool_calls: + function_response = self._make_call(tool_call) + function_response = sandwhich_tokens(self._model, + function_response, + self._max_call_response_tokens, + 0.5) + response = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": function_response, + } + responses.append(response) + self._conversation.append(choice.message) + self._conversation.extend(responses) except Exception as e: - # silent failure because the tool call submit biffed. 
Not muchw e can do - pass - elif event.event == 'thread.run.failed': - run = event.data - self.printer.fail(f"Internal Failure ({run.last_error.code}): {run.last_error.message}") - elif event.event == 'error': - self.printer.fail(f"Internal Failure:** {event.data}") - return run - - def run(self, prompt): - """ - Give the prompt to the assistant and get the response, which may included - intermediate function calls. - All output is printed to the given file. - """ + # Warning: potential infinite loop. + client.warn(f"Error processing tool calls: {e}") + elif choice.finish_reason == "stop": + break + else: + client.fail(f"Completation reason not supported: {choice.finish_reason}") + return - if self.assistant == None: + elapsed = time.time() - start return { - "tokens": 0, - "prompt_tokens": 0, - "completion_tokens": 0, - "model": self.assistant.model, - "cost": 0, + "cost": cost, + "time": elapsed, + "model": self._model, + "tokens": litellm.token_counter(self._conversation), } + except openai.OpenAIError as e: + client.fail(f"Internal Error: {e}") + - start_time = time.perf_counter() +# import atexit +# import textwrap +# import json +# import time +# import sys + +# import llm_utils +# from openai import * +# from openai import AssistantEventHandler +# from pydantic import BaseModel + +# class AssistantPrinter: +# def begin_stream(self): +# print('\n', flush=True) + +# def stream(self, text=''): +# print(text, flush=True, end='') + +# def end_stream(self): +# print('\n', flush=True) + +# def complete_message(self, text=''): +# print(text, flush=True) + +# def log(self, json_obj): +# pass + +# def fail(self, message='Failed'): +# print() +# print(textwrap.wrap(message, width=70, initial_indent='*** ')) -- wrap then indent +# sys.exit(1) + +# def warn(self, message='Warning'): +# print() +# print(textwrap.wrap(message, width=70, initial_indent='*** ')) -- wrap then indent - assert len(prompt) <= 32768 - self.threads.messages.create( - thread_id=self.thread.id, role="user", content=prompt - ) - with self.threads.runs.create( - thread_id=self.thread.id, - assistant_id=self.assistant.id, - stream=True - ) as stream: - run = self.drain_stream(stream) - - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - cost = llm_utils.calculate_cost( - run.usage.prompt_tokens, - run.usage.completion_tokens, - self.assistant.model, - ) - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, - "cost": cost, - "time": elapsed_time, - "thread.id": self.thread.id, - "thread": self.thread, - "run.id": run.id, - "run": run, - "assistant.id": self.assistant.id, - } +# class Assistant: +# """ +# An Assistant is a wrapper around OpenAI's assistant API. Example usage: + +# assistant = Assistant("Assistant Name", instructions, +# model='gpt-4-1106-preview', debug=True) +# assistant.add_function(my_func) +# response = assistant.run(user_prompt) + +# Name can be any name you want. + +# If debug is True, it will create a log of all messages and JSON responses in +# json.txt. +# """ + +# def __init__( +# self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, +# printer = AssistantPrinter()): +# self.printer = printer +# try: +# self.client = OpenAI(timeout=timeout) +# except OpenAIError: +# self.printer.fail("""\ +# You need an OpenAI key to use this tool. +# You can get a key here: https://platform.openai.com/api-keys. 
+# Set the environment variable OPENAI_API_KEY to your key value.""") + +# self.assistants = self.client.beta.assistants +# self.threads = self.client.beta.threads +# self.functions = dict() + +# self.assistant = self.assistants.create( +# name=name, instructions=instructions, model=model +# ) +# self.thread = self.threads.create() + +# atexit.register(self._delete_assistant) + +# def _delete_assistant(self): +# if self.assistant != None: +# try: +# id = self.assistant.id +# response = self.assistants.delete(id) +# assert response.deleted +# except OSError: +# raise +# except Exception as e: +# self.printer.warn(f"Assistant {id} was not deleted ({e}). You can do so at https://platform.openai.com/assistants.") + +# def add_function(self, function): +# """ +# Add a new function to the list of function tools for the assistant. +# The function should have the necessary json spec as its pydoc string. +# """ +# function_json = json.loads(function.__doc__) +# try: +# name = function_json["name"] +# self.functions[name] = function + +# tools = [ +# {"type": "function", "function": json.loads(function.__doc__)} +# for function in self.functions.values() +# ] + +# self.assistants.update(self.assistant.id, tools=tools) +# except OpenAIError as e: +# self.printer.fail(f"OpenAI Error: {e}") + +# def _make_call(self, tool_call): +# name = tool_call.function.name +# args = tool_call.function.arguments + +# # There is a sketchy case that happens occasionally because +# # the API produces a bad call... +# try: +# args = json.loads(args) +# function = self.functions[name] +# result = function(**args) +# except OSError as e: +# result = f"Error: {e}" +# except Exception as e: +# result = f"Ill-formed function call: {e}" +# return result + +# def drain_stream(self, stream): +# run = None +# for event in stream: +# if event.event not in [ 'thread.message.delta', 'thread.run.step.delta' ]: +# print(event.event) +# self.printer.log(event) +# if event.event in [ 'thread.run.created', 'thread.run.in_progress' ]: +# run = event.data +# elif event.event == 'thread.run.completed': +# self.printer.end_stream() +# return event.data +# elif event.event == 'thread.message.delta': +# self.printer.stream(event.data.delta.content[0].text.value) +# elif event.event == 'thread.message.completed': +# self.printer.complete_message(event.data.content[0].text.value) +# elif event.event == 'thread.run.requires_action': +# r = event.data +# if r.status == "requires_action": +# outputs = [] +# self.printer.end_stream() +# for tool_call in r.required_action.submit_tool_outputs.tool_calls: +# output = self._make_call(tool_call) +# outputs += [{"tool_call_id": tool_call.id, "output": output}] +# self.printer.begin_stream() +# try: +# with self.threads.runs.submit_tool_outputs( +# thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True +# ) as new_stream: +# _ = self.drain_stream(new_stream) +# except OSError as e: +# raise +# except Exception as e: +# # silent failure because the tool call submit biffed. Not muchw e can do +# pass +# elif event.event == 'thread.run.failed': +# run = event.data +# self.printer.fail(f"Internal Failure ({run.last_error.code}): {run.last_error.message}") +# elif event.event == 'error': +# self.printer.fail(f"Internal Failure:** {event.data}") +# print('***', run) +# return run + +# def run(self, prompt): +# """ +# Give the prompt to the assistant and get the response, which may included +# intermediate function calls. +# All output is printed to the given file. 
+# """ + +# if self.assistant == None: +# return { +# "tokens": 0, +# "prompt_tokens": 0, +# "completion_tokens": 0, +# "model": self.assistant.model, +# "cost": 0, +# } + +# start_time = time.perf_counter() + +# assert len(prompt) <= 32768 +# self.threads.messages.create( +# thread_id=self.thread.id, role="user", content=prompt +# ) + +# class EventHandler(AssistantEventHandler): +# def on_event(self, event): +# print(event.event) + +# with self.threads.runs.create_and_stream( +# thread_id=self.thread.id, +# assistant_id=self.assistant.id, +# # stream=True +# event_handler=EventHandler(), +# ) as stream: +# self.drain_stream(stream) + +# end_time = time.perf_counter() +# elapsed_time = end_time - start_time + +# cost = llm_utils.calculate_cost( +# run.usage.prompt_tokens, +# run.usage.completion_tokens, +# self.assistant.model, +# ) +# return { +# "tokens": run.usage.total_tokens, +# "prompt_tokens": run.usage.prompt_tokens, +# "completion_tokens": run.usage.completion_tokens, +# "model": self.assistant.model, +# "cost": cost, +# "time": elapsed_time, +# "thread.id": self.thread.id, +# "thread": self.thread, +# "run.id": run.id, +# "run": run, +# "assistant.id": self.assistant.id, +# } + +# return {} + +# if __name__ == '__main__': +# def weather(location): +# """ +# { +# "name": "get_weather", +# "description": "Determine weather in my location", +# "parameters": { +# "type": "object", +# "properties": { +# "location": { +# "type": "string", +# "description": "The city and state e.g. San Francisco, CA" +# }, +# "unit": { +# "type": "string", +# "enum": [ +# "c", +# "f" +# ] +# } +# }, +# "required": [ +# "location" +# ] +# } +# } +# """ +# return "Sunny and 72 degrees." + +# a = Assistant("Test", "You generate text.") +# a.add_function(weather) +# x = a.run("What's the weather in Boston?") +# print(x) + + + # def _print_message( + # self, message, indent, append_message: Callable[[str], None], wrap=120 + # ) -> None: + # def _format_message(indent) -> str: + # tool_calls = None + # if "tool_calls" in message: + # tool_calls = message["tool_calls"] + # elif hasattr(message, "tool_calls"): + # tool_calls = message.tool_calls + + # content = None + # if "content" in message: + # content = message["content"] + # elif hasattr(message, "content"): + # content = message.content + + # assert content != None or tool_calls != None + + # # The longest role string is 'assistant'. + # max_role_length = 9 + # # We add 3 characters for the brackets and space. 
+ # subindent = indent + max_role_length + 3 + + # role = message["role"].upper() + # role_indent = max_role_length - len(role) + + # output = "" + + # if content != None: + # content = llm_utils.word_wrap_except_code_blocks( + # content, wrap - len(role) - indent - 3 + # ) + # first, *rest = content.split("\n") + # output += f"{' ' * indent}[{role}]{' ' * role_indent} {first}\n" + # for line in rest: + # output += f"{' ' * subindent}{line}\n" + + # if tool_calls != None: + # if content != None: + # output += f"{' ' * subindent} Function calls:\n" + # else: + # output += ( + # f"{' ' * indent}[{role}]{' ' * role_indent} Function calls:\n" + # ) + # for tool_call in tool_calls: + # arguments = json.loads(tool_call.function.arguments) + # output += f"{' ' * (subindent + 4)}{tool_call.function.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})\n" + # return output + + # append_message(_format_message(indent)) + # if self._log: + # print(_format_message(0), file=self._log) -if __name__ == '__main__': - def weather(location): - """ - { - "name": "get_weather", - "description": "Determine weather in my location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } - }, - "required": [ - "location" - ] - } - } - """ - return "Sunny and 72 degrees." - a = Assistant("Test", "You generate text.") - a.add_function(weather) - x = a.run("What's the weather in Boston?") - print(x) \ No newline at end of file diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 42b6cf1..fda234c 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -93,7 +93,6 @@ def __init__(self, *args, **kwargs): chatdbg_config = Chat() sys.stdin = CaptureInput(sys.stdin) - self.wrapper = StreamTextWrapper(indent=self._chat_prefix, width=self._text_width) # Only use flow when we are in jupyter or using stdin in ipython. 
In both # cases, there will be no python file at the start of argv after the @@ -253,12 +252,8 @@ def _capture_onecmd(self, line): self.stdout = stdout self.lastcmd = lastcmd - def _format_history_entry(self, entry, indent="", prompt_color=None): + def _format_history_entry(self, entry, indent=""): line, output = entry - if prompt_color == None: - prompt = self.prompt - else: - prompt = t if output: entry = f"{self.prompt}{line}\n{output}" else: @@ -700,7 +695,8 @@ def slice(name): "ChatDBG", instruction_prompt, model=chatdbg_config.model, - printer=self + printer=ChatAssistantOutput(self.stdout, self._log, self._chat_prefix, + self._text_width, chatdbg_config.stream_response) ) if chatdbg_config.take_the_wheel: @@ -713,31 +709,46 @@ def slice(name): ############################################################### + +class ChatAssistantOutput: + def __init__(self, stdout, prefix, width, chat_log, stream_response): + self.stdout = stdout + self.chat_log = chat_log + self.prefix = prefix + self.width = width + if stream_response and False: + self.streamer = StreamTextWrapper(indent=self.prefix, width=self.width) + else: + self.streamer = None + def begin_stream(self): - print(file=self.stdout) + if self.streamer: + print(file=self.stdout) def stream(self, text=''): - print(self.wrapper.add(text), file=self.stdout, flush=True, end='') + if self.streamer: + print(self.streamer.add(text), file=self.stdout, flush=True, end='') def end_stream(self): - result = self.wrapper.flush() - print(result, file=self.stdout) + if self.streamer: + print(self.streamer.flush(), file=self.stdout) def complete_message(self, text=''): - print(self.wrapper.add('', flush=True), file=self.stdout, flush=True, end='') - line = llm_utils.word_wrap_except_code_blocks(text, self._text_width - 10) - self._log.message(line) - # line = textwrap.indent(line, self._chat_prefix, lambda _: True) - # print(line, file=self.stdout, flush=True) + line = llm_utils.word_wrap_except_code_blocks(text, self.width - 5) + self.log.message(line) + if self.streamer: + print(self.streamer.add('', flush=True), file=self.stdout, flush=True, end='') + else: + line = textwrap.indent(line, self.prefix, lambda _: True) + print(line, file=self.stdout, flush=True) def log(self, json_obj): if chatdbg_config.debug: - self._log.log(json_obj) + self.chat_log.log(json_obj) def fail(self, message='Failed'): print(file=self.stdout) print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) - raise Exception() sys.exit(1) def warn(self, message='Warning'): diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py index 2b4e228..ae845ca 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/ipdb_util/config.py @@ -45,6 +45,9 @@ class Chat(Configurable): take_the_wheel = Bool( chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" ).tag(config=True) + stream_response = Bool( + chat_get_env("stream_response", True), help="Stream the response at it arrives" + ).tag(config=True) def to_json(self): """Serialize the object to a JSON string.""" @@ -59,4 +62,5 @@ def to_json(self): "show_libs": self.show_libs, "show_slices": self.show_slices, "take_the_wheel": self.take_the_wheel, + "stream_response": self.stream_response, } From 6766e6b69aab9f72b2df25ee45d75613352e3392 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Sun, 24 Mar 2024 10:43:21 -0400 Subject: [PATCH 05/17] moo --- src/chatdbg/assistant/assistant.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 97acc48..962a2e6 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -54,13 +54,16 @@ def __init__( def add_function(self, function): """ Add a new function to the list of function tools. - The function should have the necessary json spec as its docstring. + The function should have the necessary json spec as its docstring, with + this format: + "schema": function schema, + "format": format to print call, """ schema = json.loads(function.__doc__) assert "name" in schema, "Bad JSON in docstring for function tool." self._functions[schema["name"]] = { "function": function, - "schema": schema, + "schema": schema } def _make_call(self, tool_call) -> str: @@ -95,7 +98,7 @@ def query( model=self._model, messages=self._conversation, tools=[ - {"type": "function", "function": f["schema"]} + {"type": "function", "function": f["schema"]["schema"]} for f in self._functions.values() ], timeout=self._timeout, From d41f232ea3e6e887fac99639f8ea2ca85f858beb Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Mon, 25 Mar 2024 10:14:34 -0400 Subject: [PATCH 06/17] logging in progress --- src/chatdbg/assistant/assistant.py | 247 +++++++++++++++---- src/chatdbg/chatdbg_pdb.py | 375 ++++++++++++++++------------- src/chatdbg/ipdb_util/chatlog.py | 171 +++++++------ 3 files changed, 496 insertions(+), 297 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 962a2e6..815a117 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -1,42 +1,98 @@ import json import time from typing import Callable - +import sys import litellm import openai +import textwrap + +import abc +import textwrap +import sys + +class AbstractAssistantClient(abc.ABC): + + @abc.abstractmethod + def begin_dialog(self, instructions): + pass + + @abc.abstractmethod + def end_dialog(self): + pass + + @abc.abstractmethod + def begin_query(self, user_prompt): + pass + + @abc.abstractmethod + def end_query(self, stats): + pass + + @abc.abstractmethod + def warn(self, text): + pass + + @abc.abstractmethod + def fail(self, text): + pass -def sandwhich_tokens(model, text: str, max_tokens: int, top_proportion: float) -> str: - if max_tokens == None: - return text - tokens = litellm.encode(model, text) - if len(tokens) <= max_tokens: - return text - else: - total_len = max_tokens - 5 # some slop for the ... - top_len = int(top_proportion * total_len) - bot_len = int((1-top_proportion) * total_len) - return litellm.decode(model, tokens[0:top_len]) + " [...] 
" + litellm.decode(model, tokens[-bot_len:]) - - -class AssistantClient: - def warn(text): + @abc.abstractmethod + def stream(self, event, text): pass - def fail(text): + + @abc.abstractmethod + def response(self, text): pass - def response(text): + + @abc.abstractmethod + def function_call(self, call, result): pass - def function_call(name, args, result): + + +class PrintintAssistantClient(AbstractAssistantClient): + def __init__(self, out=sys.stdout): + self.out = out + + def warn(self, text): + print(textwrap.indent(text, '*** '), file=self.out) + + def fail(self, text): + print(textwrap.indent(text, '*** '), file=self.out) + sys.exit(1) + + def stream(self, event, text): + # begin / none, step / delta , complete / full pass + def begin_query(self, user_prompt): + pass + + def end_query(self, stats): + pass + + def response(self, text): + if text != None: + print(text, file=self.out) + + def function_call(self, call, result): + if result and len(result) > 0: + entry = f"{call}\n{result}" + else: + entry = f"{call}" + print(entry, file=self.out) + + class Assistant: def __init__( self, instructions, - model="gpt-4", + model="gpt-3.5-turbo-1106", timeout=30, - max_call_response_tokens=None, + clients = [ PrintintAssistantClient() ], + functions=[], + max_call_response_tokens=4096, debug=False, ): if debug: @@ -45,34 +101,86 @@ def __init__( else: self._logger = None + self._clients = clients + self._functions = {} + for f in functions: + self._add_function(f) + self._model = model self._timeout = timeout self._conversation = [{"role": "system", "content": instructions}] self._max_call_response_tokens = max_call_response_tokens - def add_function(self, function): + self.check_model() + + def broadcast(self, method_name, *args, **kwargs): + for client in self._clients: + method = getattr(client, method_name, None) + if callable(method): + method(*args, **kwargs) + + def check_model(self): + result = litellm.validate_environment(self._model) + missing_keys = result["missing_keys"] + if missing_keys != []: + _, provider, _, _ = litellm.get_llm_provider(self._model) + if provider == 'openai': + self.broadcast('fail', textwrap.dedent(f"""\ + You need an OpenAI key to use the {self._model} model. + You can get a key here: https://platform.openai.com/api-keys. + Set the environment variable OPENAI_API_KEY to your key value.""")) + sys.exit(1) + else: + self.broadcast('fail', textwrap.dedent(f"""\ + You need to set the following environment variables + to use the {self._model} model: {', '.join(missing_keys)}""")) + sys.exit(1) + + if not litellm.supports_function_calling(self._model): + self.broadcast('fail', textwrap.dedent(f"""\ + The {self._model} model does not support function calls. + You must use a model that does, eg. gpt-4.""")) + sys.exit(1) + + def sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) -> str: + model = self._model + if max_tokens == None: + return text + tokens = litellm.encode(model, text) + if len(tokens) <= max_tokens: + return text + else: + total_len = max_tokens - 5 # some slop for the ... + top_len = int(top_proportion * total_len) + bot_len = int((1-top_proportion) * total_len) + return litellm.decode(model, tokens[0:top_len]) + " [...] " + litellm.decode(model, tokens[-bot_len:]) + + + + def _add_function(self, function): """ Add a new function to the list of function tools. 
The function should have the necessary json spec as its docstring, with this format: - "schema": function schema, - "format": format to print call, + "schema": function schema + "format": format to print call """ schema = json.loads(function.__doc__) - assert "name" in schema, "Bad JSON in docstring for function tool." - self._functions[schema["name"]] = { + assert "name" in schema['schema'], "Bad JSON in docstring for function tool." + self._functions[schema['schema']["name"]] = { "function": function, "schema": schema } def _make_call(self, tool_call) -> str: name = tool_call.function.name - args = json.loads(tool_call.function.arguments) try: - args = json.loads(args) - function = self.functions[name] - result = function(**args) + args = json.loads(tool_call.function.arguments) + function = self._functions[name] + call = function["schema"]["format"].format_map(args) + result = function["function"](**args) + self.broadcast('function_call', call, result) except OSError as e: # function produced some error -- move this to client??? result = f"Error: {e}" @@ -83,17 +191,16 @@ def _make_call(self, tool_call) -> str: def query( self, - prompt: str, - client: XXX - ) -> None: + prompt: str + ): start = time.time() cost = 0 try: self._conversation.append({"role": "user", "content": prompt}) - + while True: - self._conversation = litellm.trim_messages(self._conversation, self._model) + self._conversation = litellm.utils.trim_messages(self._conversation, self._model) completion = litellm.completion( model=self._model, messages=self._conversation, @@ -107,14 +214,18 @@ def query( cost += litellm.completion_cost(completion) - choice = completion.choices[0] + response_message = completion.choices[0].message + self._conversation.append(response_message) + + if response_message.content: + self.broadcast('response', '(Message) ' + response_message.content) - if choice.finish_reason == "tool_calls": - responses = [] + if completion.choices[0].finish_reason == 'tool_calls': + tool_calls = response_message.tool_calls try: - for tool_call in choice.message.tool_calls: + for tool_call in tool_calls: function_response = self._make_call(tool_call) - function_response = sandwhich_tokens(self._model, + function_response = self.sandwhich_tokens( function_response, self._max_call_response_tokens, 0.5) @@ -124,27 +235,63 @@ def query( "name": tool_call.function.name, "content": function_response, } - responses.append(response) - self._conversation.append(choice.message) - self._conversation.extend(responses) + self._conversation.append(response) + self.broadcast('response', '') except Exception as e: # Warning: potential infinite loop. 
- client.warn(f"Error processing tool calls: {e}") - elif choice.finish_reason == "stop": - break + self.broadcast('warn', f"Error processing tool calls: {e}") else: - client.fail(f"Completation reason not supported: {choice.finish_reason}") - return + break elapsed = time.time() - start return { "cost": cost, "time": elapsed, "model": self._model, - "tokens": litellm.token_counter(self._conversation), + "tokens": completion.usage.total_tokens, + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, } except openai.OpenAIError as e: - client.fail(f"Internal Error: {e}") + self.broadcast('fail', f"Internal Error: {e.__dict__}") + sys.exit(1) + +if __name__ == '__main__': + def weather(location): + """ + { + "schema":{ + "name": "get_weather", + "description": "Determine weather in my location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + }, + "format": "(ChatDBG) weather in {location}" + } + """ + return "Sunny and 72 degrees." + + a = Assistant("You generate text.") + a.add_function(weather) + x = a.query("What's the weather in Boston?") + print(x) # import atexit diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index fda234c..c00eaa9 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -1,14 +1,13 @@ -import linecache import ast -import atexit import inspect +import linecache import os import pdb import pydoc import sys import textwrap import traceback -from io import StringIO, TextIOWrapper +from io import StringIO from pathlib import Path from pprint import pprint @@ -18,29 +17,18 @@ from chatdbg.ipdb_util.capture import CaptureInput -from .assistant.assistant import Assistant -from .ipdb_util.config import Chat +from .assistant.assistant import Assistant, AbstractAssistantClient from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper -from .ipdb_util.prompts import pdb_instructions -from .ipdb_util.text import * +from .ipdb_util.config import Chat from .ipdb_util.locals import * +from .ipdb_util.prompts import pdb_instructions from .ipdb_util.streamwrap import StreamTextWrapper - -_valid_models = [ - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - "gpt-4", # no parallel calls - "gpt-3.5-turbo", # no parallel calls -] +from .ipdb_util.text import * chatdbg_config: Chat = None def load_ipython_extension(ipython): - # Create an instance of your configuration class with IPython's config global chatdbg_config from chatdbg.chatdbg_pdb import Chat, ChatDBG @@ -112,11 +100,10 @@ def __init__(self, *args, **kwargs): self.do_context(chatdbg_config.context) self.rcLines += ast.literal_eval(chatdbg_config.rc_lines) - # set this to True ONLY AFTER we have had stack frames + # set this to True ONLY AFTER we have had access to stack frames self._show_locals = False self._log = ChatDBGLog(chatdbg_config) - atexit.register(lambda: self._log.dump()) def _is_user_frame(self, frame): if not self._is_user_file(frame.f_code.co_filename): @@ -188,9 +175,7 @@ def _hide_lib_frames(self): self._error_specific_prompt += f"The code `{current_line.strip()}` is correct and MUST remain unchanged in your fix.\n" def execRcLines(self): - # do before running rclines -- our stack should be set up by now. 
- if not chatdbg_config.show_libs: self._hide_lib_frames() self._error_stack_trace = f"The program has the following stack trace:\n```\n{self.format_stack_trace()}\n```\n" @@ -215,15 +200,13 @@ def onecmd(self, line: str) -> bool: return super().onecmd(line) finally: self.stdout = hist_file.getfile() - if not line.startswith("config") and not line.startswith("mark"): - output = strip_color(hist_file.getvalue()) - if line not in ["quit", "EOF"]: - self._log.user_command(line, output) - if ( - line not in ["hist", "test_prompt", "c", "continue"] - and not self.was_chat - ): - self._history += [(line, output)] + output = strip_color(hist_file.getvalue()) + self._log.user_command(line, output) + if ( + line.split(' ')[0] not in ["hist", "test_prompt", "c", "cont", "continue", "config"] + and not self.was_chat + ): + self._history += [(line, output)] def message(self, msg) -> None: """ @@ -247,6 +230,7 @@ def _capture_onecmd(self, line): self.stdout = StringIO() super().onecmd(line) result = self.stdout.getvalue().rstrip() + result = strip_color(result) return result finally: self.stdout = stdout @@ -361,11 +345,6 @@ def do_slice(self, arg): break index -= 1 if _x != None: - # print('found it') - # print(_x) - # print(_x.__dict__) - # print(_x._get_timestamps_for_version(version=-1)) - # print(code(_x)) time_stamps = _x._get_timestamps_for_version(version=-1) time_stamps = [ts for ts in time_stamps if ts.cell_num > -1] result = str( @@ -434,7 +413,7 @@ def _hidden_predicate(self, frame): return False def print_stack_trace(self, context=None, locals=None): - # override to print the skips into stdout... + # override to print the skips into stdout instead of stderr... Colors = self.color_scheme_table.active_colors ColorsNormal = Colors.Normal if context is None: @@ -515,7 +494,7 @@ def _stack_prompt(self): ) stack = ( textwrap.dedent( - f""" + f"""\ This is the current stack. The current frame is indicated by an arrow '>' at the start of the line. ```""" @@ -526,31 +505,48 @@ def _stack_prompt(self): finally: self.stdout = stdout - def _build_prompt(self, arg, conversing): - prompt = "" - - if not conversing: - stack_dump = f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n\n" - prompt = "\n" + stack_dump + self._error_specific_prompt - if len(sys.argv) > 1: - prompt += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n" - input = sys.stdin.get_captured_input() - if len(input) > 0: - prompt += f"\nThis was the program's input :\n```\n{input}```\n" - + def _ip_instructions(self): + return pdb_instructions( + self._supports_flow, chatdbg_config.take_the_wheel + ) + def _ip_enchriched_stack_trace(self): + return f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n" + + def _ip_error(self): + return self._error_specific_prompt + + def _ip_inputs(self): + inputs = '' + if len(sys.argv) > 1: + inputs += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n" + input = sys.stdin.get_captured_input() + if len(input) > 0: + inputs += f"\nThis was the program's input :\n```\n{input}```\n" + return inputs + + def _ip_history(self): if len(self._history) > 0: hist = textwrap.indent(self._capture_onecmd("hist"), "") - self._clear_history() hist = f"\nThis is the history of some pdb commands I ran and the results.\n```\n{hist}\n```\n" - prompt += hist - - if arg == "why": - arg = "Explain the root cause of the error." 
+ return hist + else: + return '' - stack = self._stack_prompt() - prompt += stack + "\n" + arg + def concat_prompt(self, *args): + args = [a for a in args if len(a) > 0] + return "\n".join(args) - return prompt + def _build_prompt(self, arg, conversing): + if not conversing: + return self.concat_prompt( self._ip_enchriched_stack_trace(), + self._ip_inputs(), + self._ip_error(), + self._ip_history(), + arg) + else: + return self.concat_prompt(self._ip_history(), + self._stack_prompt(), + arg) def do_chat(self, arg): """chat/: @@ -559,16 +555,16 @@ def do_chat(self, arg): self.was_chat = True full_prompt = self._build_prompt(arg, self._assistant != None) + full_prompt = strip_color(full_prompt) + full_prompt = truncate_proportionally(full_prompt) + + self._clear_history() if self._assistant == None: self._make_assistant() - - full_prompt = strip_color(full_prompt) - full_prompt = truncate_proportionally(full_prompt) - self._log.push_chat(arg, full_prompt) - stats = self._assistant.run(full_prompt) + stats = self._assistant.query(full_prompt) self._log.pop_chat(stats) self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") @@ -591,166 +587,201 @@ def do_config(self, arg): except TraitError as e: self.error(f"{e}") - # Get the documentation and source code (if available) for any function or method visible in the current frame. The argument to info can be the name of the function or an expression of the form `obj.method_name` to see the information for the method_name method of object obj.", - def _make_assistant(self): def info(value): """ { - "name": "info", - "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", - "parameters": { - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "The reference to get the information for." - } - }, - "required": [ "value" ] - } + "schema": { + "name": "info", + "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The reference to get the information for." + } + }, + "required": [ "value" ] + } + }, + "format": "info {value}" } """ command = f"info {value}" result = self._capture_onecmd(command) - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) - result = strip_color(result) self._log.function(command, result) return truncate_proportionally(result, top_proportion=1) def debug(command): """ { - "name": "debug", - "description": "Run a pdb command and get the response.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The pdb command to run." - } - }, - "required": [ "command" ] - } + "schema" : { + "name": "debug", + "description": "Run a pdb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The pdb command to run." 
+ } + }, + "required": [ "command" ] + } + }, + "format": "{command}" } """ cmd = command if command != "list" else "ll" + # old_curframe = self.curframe result = self._capture_onecmd(cmd) - - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) - - result = strip_color(result) self._log.function(command, result) # help the LLM know where it is... - result += strip_color(self._stack_prompt()) + # if old_curframe != self.curframe: + # result += strip_color(self._stack_prompt()) + return truncate_proportionally(result, maxlen=8000, top_proportion=0.9) def slice(name): """ { - "name": "slice", - "description": "Return the code to compute a global variable used in the current frame", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The variable to look at." - } - }, - "required": [ "name" ] - } + "schema": { + "name": "slice", + "description": "Return the code to compute a global variable used in the current frame", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The variable to look at." + } + }, + "required": [ "name" ] + } + }, + "format": "slice {name}" } - """ command = f"slice {name}" result = self._capture_onecmd(command) - self.message( - self._format_history_entry((command, result), indent=self._chat_prefix) - ) - result = strip_color(result) self._log.function(command, result) return truncate_proportionally(result, top_proportion=0.5) - self._clear_history() - instruction_prompt = pdb_instructions( - self._supports_flow, chatdbg_config.take_the_wheel - ) - + instruction_prompt = self._ip_instructions() self._log.instructions(instruction_prompt) - if not chatdbg_config.model in _valid_models: - print( - f"'{chatdbg_config.model}' is not a valid OpenAI model. Choose from: {_valid_models}." 
- ) - sys.exit(0) + if chatdbg_config.take_the_wheel: + functions = [ debug, info ] + if self._supports_flow: + functions += [ slice ] + else: + functions = [] self._assistant = Assistant( - "ChatDBG", instruction_prompt, model=chatdbg_config.model, - printer=ChatAssistantOutput(self.stdout, self._log, self._chat_prefix, - self._text_width, chatdbg_config.stream_response) + debug=chatdbg_config.debug, + functions=functions, + clients=[ ChatAssistantClient(self.stdout, + self.prompt, + self._chat_prefix, + self._text_width) ] ) - if chatdbg_config.take_the_wheel: - self._assistant.add_function(debug) - self._assistant.add_function(info) - - if self._supports_flow: - self._assistant.add_function(slice) ############################################################### -class ChatAssistantOutput: - def __init__(self, stdout, prefix, width, chat_log, stream_response): - self.stdout = stdout - self.chat_log = chat_log - self.prefix = prefix + +class ChatAssistantClient(AbstractAssistantClient): + def __init__(self, out, debugger_prompt, chat_prefix, width): + self.out = out + self.debugger_prompt = debugger_prompt + self.chat_prefix = chat_prefix self.width = width - if stream_response and False: - self.streamer = StreamTextWrapper(indent=self.prefix, width=self.width) - else: - self.streamer = None - - def begin_stream(self): - if self.streamer: - print(file=self.stdout) + self._assistant = None + + # Call backs - def stream(self, text=''): - if self.streamer: - print(self.streamer.add(text), file=self.stdout, flush=True, end='') + def begin_query(self, user_prompt): + pass - def end_stream(self): - if self.streamer: - print(self.streamer.flush(), file=self.stdout) + def end_query(self, stats): + pass - def complete_message(self, text=''): - line = llm_utils.word_wrap_except_code_blocks(text, self.width - 5) - self.log.message(line) - if self.streamer: - print(self.streamer.add('', flush=True), file=self.stdout, flush=True, end='') - else: - line = textwrap.indent(line, self.prefix, lambda _: True) - print(line, file=self.stdout, flush=True) + def _print(self, text): + print(textwrap.indent(text, self.chat_prefix, lambda _: True), file=self.out) - def log(self, json_obj): - if chatdbg_config.debug: - self.chat_log.log(json_obj) + def warn(self, text): + self._print(textwrap.indent(text, '*** ')) - def fail(self, message='Failed'): - print(file=self.stdout) - print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) + def fail(self, text): + self._print(textwrap.indent(text, '*** ')) sys.exit(1) + + def stream(self, event, text): + # begin / none, step / delta , complete / full + pass + + def response(self, text): + if text != None: + text = llm_utils.utils.word_wrap_except_code_blocks(text, self.width-len(self.chat_prefix)) + self._print(text) + + def function_call(self, call, result): + if result and len(result) > 0: + entry = f"{self.debugger_prompt}{call}\n{result}" + else: + entry = f"{self.debugger_prompt}{call}" + self._print(entry) + + + + +# class ChatAssistantOutput: +# def __init__(self, stdout, prefix, width, chat_log, stream_response): +# self.stdout = stdout +# self.chat_log = chat_log +# self.prefix = prefix +# self.width = width +# if stream_response and False: +# self.streamer = StreamTextWrapper(indent=self.prefix, width=self.width) +# else: +# self.streamer = None + +# def begin_stream(self): +# if self.streamer: +# print(file=self.stdout) + +# def stream(self, text=''): +# if self.streamer: +# print(self.streamer.add(text), file=self.stdout, flush=True, 
end='') + +# def end_stream(self): +# if self.streamer: +# print(self.streamer.flush(), file=self.stdout) + +# def complete_message(self, text=''): +# line = llm_utils.word_wrap_except_code_blocks(text, self.width - 5) +# self.log.message(line) +# if self.streamer: +# print(self.streamer.add('', flush=True), file=self.stdout, flush=True, end='') +# else: +# line = textwrap.indent(line, self.prefix, lambda _: True) +# print(line, file=self.stdout, flush=True) + +# def log(self, json_obj): +# if chatdbg_config.debug: +# self.chat_log.log(json_obj) + +# def fail(self, message='Failed'): +# print(file=self.stdout) +# print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) +# sys.exit(1) - def warn(self, message='Warning'): - print(file=self.stdout) - print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) +# def warn(self, message='Warning'): +# print(file=self.stdout) +# print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index 28ab1e6..4fabf04 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -1,10 +1,57 @@ +import atexit from io import StringIO from datetime import datetime import uuid import sys import yaml -from .config import Chat - +from pydantic import BaseModel +from typing import List, Union, Optional + +class Output(BaseModel): + type: str + output: Optional[str] = None + +class FunctionOutput(Output): + type: str = "call" + +class TextOutput(Output): + type: str = "text" + +class ChatOutput(BaseModel): + type: str = "chat" + outputs: List[Output] = [] + +class Stats(BaseModel): + tokens: int = 0 + cost: float = 0.0 + time: float = 0.0 + + class Config: + extra = 'allow' + + +class Step(BaseModel): + input: str + output: Union[TextOutput, ChatOutput] + full_prompt: Optional[str] = None + stats: Optional[Stats] = None + +class Meta(BaseModel): + time: datetime + command_line: str + uid: str + config: str + mark: str = "?" + total_tokens: int = 0 + total_time: float = 0.0 + total_cost: float = 0.0 + +class Log(BaseModel): + meta: Meta + steps: List[Step] + instructions: str + stdout: Optional[str] + stderr: Optional[str] class CopyingTextIOWrapper: """ @@ -31,98 +78,72 @@ def __getattr__(self, attr): class ChatDBGLog: - def __init__(self, config: Chat): + def __init__(self, log_filename, config, capture_streams=True): + self.meta = Meta( + time=datetime.now(), + command_line=" ".join(sys.argv), + uid=str(uuid.uuid4()), + config=config + ) self.steps = [] - self.meta = { - "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "command_line": " ".join(sys.argv), - "uid": str(uuid.uuid4()), - "config": config.to_json(), - "mark": "?", - } - self.log_file = config.log + self.log_filename = log_filename self._instructions = "" - self.stdout_wrapper = CopyingTextIOWrapper(sys.stdout) - self.stderr_wrapper = CopyingTextIOWrapper(sys.stderr) - sys.stdout = self.stdout_wrapper - sys.stderr = self.stdout_wrapper - self.chat_step = None - self.events = [ ] - self.mark = "?" 
- - def add_mark(self, value): - if value not in ["Fix", "Partial", "None", "?"]: - print(f"answer must be in { ['Fix', 'Partial', 'None', '?'] }") + if capture_streams: + self.stdout_wrapper = CopyingTextIOWrapper(sys.stdout) + self.stderr_wrapper = CopyingTextIOWrapper(sys.stderr) + sys.stdout = self.stdout_wrapper + sys.stderr = self.stdout_wrapper else: - self.meta["mark"] = value + self.stderr_wrapper = None + self.stderr_wrapper = None + self.chat_step = None + atexit.register(lambda: self.dump()) def total(self, key): return sum( - [x["stats"][key] for x in self.steps if x["output"]["type"] == "chat"] + getattr(x.stats, key) for x in self.steps if x.output.type == "chat" and x.stats is not None ) def dump(self): - self.meta["total_tokens"] = self.total("tokens") - self.meta["total_time"] = self.total("time") - self.meta["total_cost"] = self.total("cost") - - full_json = [ - { - "meta": self.meta, - "steps": self.steps, - "instructions": self._instructions, - "stdout": self.stdout_wrapper.getvalue(), - "stderr": self.stderr_wrapper.getvalue(), - "events" : self.events - } - ] - - print(f"*** Write ChatDBG log to {self.log_file}") - with open(self.log_file, "a") as file: - yaml.dump(full_json, file, default_flow_style=False) - - def log(self, event_json): - self.events += [ event_json ] + self.meta.total_tokens = self.total("tokens") + self.meta.total_time = self.total("time") + self.meta.total_cost = self.total("cost") + + full_log = Log( + meta=self.meta, + steps=self.steps, + instructions=self._instructions, + stdout=self.stdout_wrapper.getvalue(), + stderr=self.stderr_wrapper.getvalue(), + events=self.events + ) + + print(f"*** Write ChatDBG log to {self.log_filename}") + with open(self.log_filename, "a") as file: + def literal_presenter(dumper, data): + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + + yaml.add_representer(str, literal_presenter) + yaml.dump(full_log, file, default_flow_style=False, indent=4) def instructions(self, instructions): self._instructions = instructions def user_command(self, line, output): - if self.chat_step != None: + if self.chat_step is not None: x = self.chat_step self.chat_step = None else: - x = {"input": line, "output": {"type": "text", "output": output}} + x = Step(input=line, output=TextOutput(output=output)) self.steps.append(x) def push_chat(self, line, full_prompt): - self.chat_step = { - "input": line, - "full_prompt": full_prompt, - "output": {"type": "chat", "outputs": []}, - "stats": {"tokens": 0, "cost": 0, "time": 0}, - } + self.chat_step = Step( + input=line, + full_prompt=full_prompt, + output=ChatOutput() + ) def pop_chat(self, stats): - self.chat_step["stats"] = stats - - def message(self, text): - self.chat_step["output"]["outputs"].append({"type": "text", "output": text}) - - def function(self, line, output): - x = { - "type": "call", - "input": line, - "output": {"type": "text", "output": output}, - } - self.chat_step["output"]["outputs"].append(x) - - -# Custom representer for literal scalar representation -def literal_presenter(dumper, data): - if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - -yaml.add_representer(str, literal_presenter) + if self.chat_step is not None: + self.chat_step.stats = Stats(**stats) From 603a62b1e0866a60d01b54d3de50850f528803ef Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Mon, 25 Mar 2024 16:12:02 -0400 Subject: [PATCH 07/17] more logging --- 
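The log dump above leans on a small PyYAML idiom; a minimal, self-contained sketch (assuming only PyYAML is installed) is:

import yaml

def literal_presenter(dumper, data):
    # use literal block style ('|') only for multi-line strings
    if "\n" in data:
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
    return dumper.represent_scalar("tag:yaml.org,2002:str", data)

yaml.add_representer(str, literal_presenter)
print(yaml.dump({"output": "line one\nline two"}, default_flow_style=False))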
src/chatdbg/assistant/assistant.py | 54 +++--- src/chatdbg/chatdbg_pdb.py | 27 +-- src/chatdbg/ipdb_util/chatlog.py | 273 ++++++++++++++++++----------- 3 files changed, 209 insertions(+), 145 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 815a117..7cc0e2d 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -6,45 +6,35 @@ import openai import textwrap -import abc import textwrap import sys -class AbstractAssistantClient(abc.ABC): +class AbstractAssistantClient: - @abc.abstractmethod def begin_dialog(self, instructions): pass - @abc.abstractmethod def end_dialog(self): pass - @abc.abstractmethod - def begin_query(self, user_prompt): + def begin_query(self, prompt, **kwargs): pass - @abc.abstractmethod def end_query(self, stats): pass - @abc.abstractmethod def warn(self, text): pass - @abc.abstractmethod def fail(self, text): pass - @abc.abstractmethod def stream(self, event, text): pass - @abc.abstractmethod def response(self, text): pass - @abc.abstractmethod def function_call(self, call, result): pass @@ -64,7 +54,7 @@ def stream(self, event, text): # begin / none, step / delta , complete / full pass - def begin_query(self, user_prompt): + def begin_query(self, prompt, **kwargs): pass def end_query(self, stats): @@ -112,38 +102,42 @@ def __init__( self._conversation = [{"role": "system", "content": instructions}] self._max_call_response_tokens = max_call_response_tokens - self.check_model() + self._check_model() + self._broadcast('begin_dialog', instructions) - def broadcast(self, method_name, *args, **kwargs): + def close(self): + self._broadcast('end_dialog') + + def _broadcast(self, method_name, *args, **kwargs): for client in self._clients: method = getattr(client, method_name, None) if callable(method): method(*args, **kwargs) - def check_model(self): + def _check_model(self): result = litellm.validate_environment(self._model) missing_keys = result["missing_keys"] if missing_keys != []: _, provider, _, _ = litellm.get_llm_provider(self._model) if provider == 'openai': - self.broadcast('fail', textwrap.dedent(f"""\ + self._broadcast('fail', textwrap.dedent(f"""\ You need an OpenAI key to use the {self._model} model. You can get a key here: https://platform.openai.com/api-keys. Set the environment variable OPENAI_API_KEY to your key value.""")) sys.exit(1) else: - self.broadcast('fail', textwrap.dedent(f"""\ + self._broadcast('fail', textwrap.dedent(f"""\ You need to set the following environment variables to use the {self._model} model: {', '.join(missing_keys)}""")) sys.exit(1) if not litellm.supports_function_calling(self._model): - self.broadcast('fail', textwrap.dedent(f"""\ + self._broadcast('fail', textwrap.dedent(f"""\ The {self._model} model does not support function calls. You must use a model that does, eg. gpt-4.""")) sys.exit(1) - def sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) -> str: + def _sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) -> str: model = self._model if max_tokens == None: return text @@ -180,7 +174,7 @@ def _make_call(self, tool_call) -> str: function = self._functions[name] call = function["schema"]["format"].format_map(args) result = function["function"](**args) - self.broadcast('function_call', call, result) + self._broadcast('function_call', call, result) except OSError as e: # function produced some error -- move this to client??? 
result = f"Error: {e}" @@ -191,12 +185,14 @@ def _make_call(self, tool_call) -> str: def query( self, - prompt: str + prompt: str, + **kwargs ): start = time.time() cost = 0 try: + self._broadcast("begin_query", prompt, kwargs) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -218,14 +214,14 @@ def query( self._conversation.append(response_message) if response_message.content: - self.broadcast('response', '(Message) ' + response_message.content) + self._broadcast('response', '(Message) ' + response_message.content) if completion.choices[0].finish_reason == 'tool_calls': tool_calls = response_message.tool_calls try: for tool_call in tool_calls: function_response = self._make_call(tool_call) - function_response = self.sandwhich_tokens( + function_response = self._sandwhich_tokens( function_response, self._max_call_response_tokens, 0.5) @@ -236,15 +232,15 @@ def query( "content": function_response, } self._conversation.append(response) - self.broadcast('response', '') + self._broadcast('response', '') except Exception as e: # Warning: potential infinite loop. - self.broadcast('warn', f"Error processing tool calls: {e}") + self._broadcast('warn', f"Error processing tool calls: {e}") else: break elapsed = time.time() - start - return { + stats = { "cost": cost, "time": elapsed, "model": self._model, @@ -252,8 +248,10 @@ def query( "prompt_tokens": completion.usage.prompt_tokens, "completion_tokens": completion.usage.completion_tokens, } + self._broadcast("end_query", stats) + return stats except openai.OpenAIError as e: - self.broadcast('fail', f"Internal Error: {e.__dict__}") + self._broadcast('fail', f"Internal Error: {e.__dict__}") sys.exit(1) if __name__ == '__main__': diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index c00eaa9..897eaa8 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -1,3 +1,4 @@ +import atexit import ast import inspect import linecache @@ -73,6 +74,8 @@ def __init__(self, *args, **kwargs): self._chat_prefix = " " self._text_width = 80 self._assistant = None + atexit.register(lambda: self._close_assistant()) + self._history = [] self._error_specific_prompt = "" @@ -103,7 +106,13 @@ def __init__(self, *args, **kwargs): # set this to True ONLY AFTER we have had access to stack frames self._show_locals = False - self._log = ChatDBGLog(chatdbg_config) + self._log = ChatDBGLog(log_filename=chatdbg_config.log, + config=chatdbg_config.to_json(), + capture_streams=True) + + def _close_assistant(self): + if self._assistant != None: + self._assistant.close() def _is_user_frame(self, frame): if not self._is_user_file(frame.f_code.co_filename): @@ -201,7 +210,8 @@ def onecmd(self, line: str) -> bool: finally: self.stdout = hist_file.getfile() output = strip_color(hist_file.getvalue()) - self._log.user_command(line, output) + if not self.was_chat: + self._log.function_call(line, output) if ( line.split(' ')[0] not in ["hist", "test_prompt", "c", "cont", "continue", "config"] and not self.was_chat @@ -563,9 +573,7 @@ def do_chat(self, arg): if self._assistant == None: self._make_assistant() - self._log.push_chat(arg, full_prompt) - stats = self._assistant.query(full_prompt) - self._log.pop_chat(stats) + stats = self._assistant.query(full_prompt, user_text=arg) self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") @@ -610,7 +618,6 @@ def info(value): """ command = f"info {value}" result = self._capture_onecmd(command) - self._log.function(command, result) return truncate_proportionally(result, top_proportion=1) 
def debug(command): @@ -636,7 +643,6 @@ def debug(command): cmd = command if command != "list" else "ll" # old_curframe = self.curframe result = self._capture_onecmd(cmd) - self._log.function(command, result) # help the LLM know where it is... # if old_curframe != self.curframe: @@ -666,11 +672,9 @@ def slice(name): """ command = f"slice {name}" result = self._capture_onecmd(command) - self._log.function(command, result) return truncate_proportionally(result, top_proportion=0.5) instruction_prompt = self._ip_instructions() - self._log.instructions(instruction_prompt) if chatdbg_config.take_the_wheel: functions = [ debug, info ] @@ -687,7 +691,8 @@ def slice(name): clients=[ ChatAssistantClient(self.stdout, self.prompt, self._chat_prefix, - self._text_width) ] + self._text_width), + self._log ] ) @@ -706,7 +711,7 @@ def __init__(self, out, debugger_prompt, chat_prefix, width): # Call backs - def begin_query(self, user_prompt): + def begin_query(self, user_text, full_prompt): pass def end_query(self, stats): diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index 4fabf04..d034b85 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -6,52 +6,9 @@ import yaml from pydantic import BaseModel from typing import List, Union, Optional +from ..assistant.assistant import AbstractAssistantClient -class Output(BaseModel): - type: str - output: Optional[str] = None -class FunctionOutput(Output): - type: str = "call" - -class TextOutput(Output): - type: str = "text" - -class ChatOutput(BaseModel): - type: str = "chat" - outputs: List[Output] = [] - -class Stats(BaseModel): - tokens: int = 0 - cost: float = 0.0 - time: float = 0.0 - - class Config: - extra = 'allow' - - -class Step(BaseModel): - input: str - output: Union[TextOutput, ChatOutput] - full_prompt: Optional[str] = None - stats: Optional[Stats] = None - -class Meta(BaseModel): - time: datetime - command_line: str - uid: str - config: str - mark: str = "?" - total_tokens: int = 0 - total_time: float = 0.0 - total_cost: float = 0.0 - -class Log(BaseModel): - meta: Meta - steps: List[Step] - instructions: str - stdout: Optional[str] - stderr: Optional[str] class CopyingTextIOWrapper: """ @@ -76,74 +33,178 @@ def __getattr__(self, attr): # Delegate attribute access to the file object return getattr(self.file, attr) - -class ChatDBGLog: +# class Output(BaseModel): +# type: str + +# class TextOutput(BaseModel): +# type: str = "text" +# output: str + +# class ChatOutput(BaseModel): +# type:str = 'chat' +# outputs: List[Output] = [] + +# class Function(Output): +# type: str = 'call' +# input: str +# output: TextOutput + +# class Stats(BaseModel): +# tokens: int = 0 +# cost: float = 0.0 +# time: float = 0.0 + +# class Config: +# extra = 'allow' + +# class Chat(BaseModel): +# input: str +# output: ChatOutput = ChatOutput() +# prompt: Optional[str] = None +# stats: Optional[Stats] = None + +# def append(self, s: Output): +# self.output.outputs.append(s) + +# class Meta(BaseModel): +# time: datetime +# command_line: str +# uid: str +# config: dict +# mark: str = "?" 
+# total_tokens: int = 0 +# total_time: float = 0.0 +# total_cost: float = 0.0 + +# class Log(BaseModel): +# meta: Meta +# steps: List[Function | Chat] = [] +# current_chat: Optional[Chat] = None +# instructions: Optional[str] +# stdout: Optional[str] +# stderr: Optional[str] + +# class Config: +# exclude = {'current_chat'} + +# def append(self, s: Function | Chat): +# self.steps.append(s) + +# def total(self, key): +# return sum( +# getattr(x.stats, key) for x in self.steps if isinstance(x.output, ChatOutput) and x.stats is not None +# ) + +# def model_dump_json(self, **kwargs): +# self.meta.total_tokens = self.total("tokens") +# self.meta.total_time = self.total("time") +# self.meta.total_cost = self.total("cost") +# return super().model_dump_json(kwargs) + +class ChatDBGLog(AbstractAssistantClient): def __init__(self, log_filename, config, capture_streams=True): - self.meta = Meta( - time=datetime.now(), - command_line=" ".join(sys.argv), - uid=str(uuid.uuid4()), - config=config - ) - self.steps = [] - self.log_filename = log_filename - self._instructions = "" + self._log_filename = log_filename + self.config = config if capture_streams: - self.stdout_wrapper = CopyingTextIOWrapper(sys.stdout) - self.stderr_wrapper = CopyingTextIOWrapper(sys.stderr) - sys.stdout = self.stdout_wrapper - sys.stderr = self.stdout_wrapper + self._stdout_wrapper = CopyingTextIOWrapper(sys.stdout) + self._stderr_wrapper = CopyingTextIOWrapper(sys.stderr) + sys.stdout = self._stdout_wrapper + sys.stderr = self._stdout_wrapper else: - self.stderr_wrapper = None - self.stderr_wrapper = None - self.chat_step = None - atexit.register(lambda: self.dump()) - - def total(self, key): - return sum( - getattr(x.stats, key) for x in self.steps if x.output.type == "chat" and x.stats is not None - ) - - def dump(self): - self.meta.total_tokens = self.total("tokens") - self.meta.total_time = self.total("time") - self.meta.total_cost = self.total("cost") - - full_log = Log( - meta=self.meta, - steps=self.steps, - instructions=self._instructions, - stdout=self.stdout_wrapper.getvalue(), - stderr=self.stderr_wrapper.getvalue(), - events=self.events - ) - - print(f"*** Write ChatDBG log to {self.log_filename}") - with open(self.log_filename, "a") as file: + self._stderr_wrapper = None + self._stderr_wrapper = None + + meta = { + 'time': datetime.now(), + 'command_line': " ".join(sys.argv), + 'uid': str(uuid.uuid4()), + 'config': self.config + } + log = { + 'steps':[], + 'meta':meta, + 'instructions':None, + 'stdout':self._stdout_wrapper.getvalue(), + 'stderr':self._stderr_wrapper.getvalue(), + } + self._current_chat = None + self._log = log + + def _dump(self): + log = self._log + + def total(key): + return sum( + x['stats'][key] for x in log['steps'] if x['output']['type'] == 'chat' and 'stats' in x['output'] + ) + + log['meta']['total_tokens'] = total("tokens") + log['meta']['total_time'] = total("time") + log['meta']['total_cost'] = total("cost") + + print(f"*** Writing ChatDBG dialog log to {self._log_filename}") + + with open(self._log_filename, "a") as file: def literal_presenter(dumper, data): - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) yaml.add_representer(str, literal_presenter) - yaml.dump(full_log, file, default_flow_style=False, indent=4) + yaml.dump([ log ], file, default_flow_style=False, indent=2) + + def 
begin_dialog(self, instructions): + log = self._log + assert log != None + log['instructions'] = instructions + + def end_dialog(self): + if self._log != None: + self._dump() + self._log = None + + def begin_query(self, prompt, kwargs): + log = self._log + assert log != None + assert self._current_chat == None + self._current_chat = { + 'input':kwargs['user_text'], + 'prompt':prompt, + 'output': { 'type': 'chat', 'outputs': []} + } + + def end_query(self, stats): + log = self._log + assert log != None + assert self._current_chat != None + log['steps'] += [ self._current_chat ] + self._current_chat = None + + def _post(self, text, kind): + log = self._log + assert log != None + if self._current_chat != None: + self._current_chat['output']['outputs'].append({ 'type': 'text', 'output': f"*** {kind}: {text}"}) + else: + log['steps'].append({ 'type': 'call', 'input': f"*** {kind}", 'output': { 'type': 'text', 'output': text } }) + + def warn(self, text): + self._post(text, "Warning") + + def fail(self, text): + self._post(text, "Failure") - def instructions(self, instructions): - self._instructions = instructions + def response(self, text): + log = self._log + assert log != None + assert self._current_chat != None + self._current_chat['output']['outputs'].append({ 'type': 'text', 'output': text}) - def user_command(self, line, output): - if self.chat_step is not None: - x = self.chat_step - self.chat_step = None + def function_call(self, call, result): + log = self._log + assert log != None + if self._current_chat != None: + self._current_chat['output']['outputs'].append({ 'type': 'call', 'input': call, 'output': { 'type': 'text', 'output': result }}) else: - x = Step(input=line, output=TextOutput(output=output)) - self.steps.append(x) - - def push_chat(self, line, full_prompt): - self.chat_step = Step( - input=line, - full_prompt=full_prompt, - output=ChatOutput() - ) - - def pop_chat(self, stats): - if self.chat_step is not None: - self.chat_step.stats = Stats(**stats) + log['steps'].append({ 'type': 'call', 'input': call, 'output': { 'type': 'text', 'output': result }}) From d03b952b0a8e09f7060f0550a9a34aef98418bec Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Tue, 26 Mar 2024 09:43:52 -0400 Subject: [PATCH 08/17] streaming wip --- src/chatdbg/assistant/assistant.py | 466 +++++++--------------------- src/chatdbg/chatdbg_pdb.py | 52 +--- src/chatdbg/ipdb_util/chatlog.py | 67 ---- src/chatdbg/ipdb_util/prompts.py | 9 - src/chatdbg/ipdb_util/streamwrap.py | 20 +- src/chatdbg/ipdb_util/text.py | 26 +- 6 files changed, 154 insertions(+), 486 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 7cc0e2d..f25f3f7 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -1,6 +1,5 @@ import json import time -from typing import Callable import sys import litellm import openai @@ -29,7 +28,11 @@ def warn(self, text): def fail(self, text): pass - def stream(self, event, text): + def begin_stream(self): + pass + def stream_delta(self, text): + pass + def end_stream(self): pass def response(self, text): @@ -39,7 +42,7 @@ def function_call(self, call, result): pass -class PrintintAssistantClient(AbstractAssistantClient): +class PrintingAssistantClient(AbstractAssistantClient): def __init__(self, out=sys.stdout): self.out = out @@ -50,9 +53,14 @@ def fail(self, text): print(textwrap.indent(text, '*** '), file=self.out) sys.exit(1) - def stream(self, event, text): - # begin / none, step / delta , complete / full - pass 
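Both the printing client and ChatDBGLog plug into the same fan-out; a minimal sketch of that dispatch (Broadcaster and EchoClient are illustrative names, not ChatDBG classes):

class EchoClient:
    def begin_query(self, prompt, **kwargs):
        print(f"query: {prompt}")

    def end_query(self, stats):
        print(f"done in {stats['time']:.2f}s")

class Broadcaster:
    def __init__(self, clients):
        self._clients = clients

    def _broadcast(self, method_name, *args, **kwargs):
        # dispatch only to clients that define the callback
        for client in self._clients:
            method = getattr(client, method_name, None)
            if callable(method):
                method(*args, **kwargs)

b = Broadcaster([EchoClient()])
b._broadcast("begin_query", "why did this fail?")
b._broadcast("end_query", {"time": 1.25})
b._broadcast("warn", "ignored: no client handles this event")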
+ def begin_stream(self): + print("<<<", file=self.out) + + def stream_delta(self, text): + print(text, end='', file=self.out) + + def end_stream(self): + print(">>>", file=self.out) def begin_query(self, prompt, **kwargs): pass @@ -80,7 +88,7 @@ def __init__( instructions, model="gpt-3.5-turbo-1106", timeout=30, - clients = [ PrintintAssistantClient() ], + clients = [ PrintingAssistantClient() ], functions=[], max_call_response_tokens=4096, debug=False, @@ -186,27 +194,32 @@ def _make_call(self, tool_call) -> str: def query( self, prompt: str, + stream = False, **kwargs ): start = time.time() cost = 0 try: - self._broadcast("begin_query", prompt, kwargs) + self._broadcast("begin_query", prompt, **kwargs) self._conversation.append({"role": "user", "content": prompt}) while True: self._conversation = litellm.utils.trim_messages(self._conversation, self._model) - completion = litellm.completion( - model=self._model, - messages=self._conversation, - tools=[ - {"type": "function", "function": f["schema"]["schema"]} - for f in self._functions.values() - ], - timeout=self._timeout, - logger_fn=self._logger - ) + + if stream: + self._broadcast('begin_stream') + completion_stream = self.completion(stream) + chunks = [] + for chunk in completion_stream: + print(chunk) + self._broadcast('stream_delta', chunk.choices[0].delta.content or "") + chunks.append(chunk) + completion = litellm.stream_chunk_builder(chunks, messages=self._conversation) + print('---', completion) + self._broadcast('end_stream') + else: + completion = self.completion(stream) cost += litellm.completion_cost(completion) @@ -217,36 +230,105 @@ def query( self._broadcast('response', '(Message) ' + response_message.content) if completion.choices[0].finish_reason == 'tool_calls': - tool_calls = response_message.tool_calls - try: - for tool_call in tool_calls: - function_response = self._make_call(tool_call) - function_response = self._sandwhich_tokens( + self._add_function_results_to_conversation(response_message) + else: + break + + elapsed = time.time() - start + stats = { + "cost": cost, + "time": elapsed, + "model": self._model, + "tokens": completion.usage.total_tokens, + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, + } + self._broadcast("end_query", stats) + return stats + except openai.OpenAIError as e: + self._broadcast('fail', f"Internal Error: {e.__dict__}") + sys.exit(1) + + def completion(self, stream): + completion = litellm.completion( + model=self._model, + messages=self._conversation, + tools=[ + {"type": "function", "function": f["schema"]["schema"]} + for f in self._functions.values() + ], + timeout=self._timeout, + logger_fn=self._logger, + stream=stream + ) + + return completion + + def _add_function_results_to_conversation(self, response_message): + tool_calls = response_message.tool_calls + try: + for tool_call in tool_calls: + function_response = self._make_call(tool_call) + function_response = self._sandwhich_tokens( function_response, self._max_call_response_tokens, 0.5) - response = { + response = { "tool_call_id": tool_call.id, "role": "tool", "name": tool_call.function.name, "content": function_response, } - self._conversation.append(response) - self._broadcast('response', '') - except Exception as e: + self._conversation.append(response) + self._broadcast('response', '') + except Exception as e: # Warning: potential infinite loop. 
- self._broadcast('warn', f"Error processing tool calls: {e}") - else: - break + self._broadcast('warn', f"Error processing tool calls: {e}") + + def query2( + self, + prompt: str, + stream: bool = False, + **kwargs + ): + start = time.time() + cost = 0 + + try: + self._broadcast("begin_query", prompt, **kwargs) + self._conversation.append({"role": "user", "content": prompt}) + + while True: + self._conversation = litellm.utils.trim_messages(self._conversation, self._model) + response = litellm.completion( + model=self._model, + messages=self._conversation, + tools=[ + {"type": "function", "function": f["schema"]["schema"]} + for f in self._functions.values() + ], + timeout=self._timeout, + logger_fn=self._logger, + stream=True + ) + + chunks=[] + for chunk in response: + print(chunk) + chunks.append(chunk) + print('---\n',litellm.stream_chunk_builder(chunks, messages=self._conversation)) + + break + elapsed = time.time() - start stats = { "cost": cost, "time": elapsed, "model": self._model, - "tokens": completion.usage.total_tokens, - "prompt_tokens": completion.usage.prompt_tokens, - "completion_tokens": completion.usage.completion_tokens, + # "tokens": completion.usage.total_tokens, + # "prompt_tokens": completion.usage.prompt_tokens, + # "completion_tokens": completion.usage.completion_tokens, } self._broadcast("end_query", stats) return stats @@ -254,6 +336,7 @@ def query( self._broadcast('fail', f"Internal Error: {e.__dict__}") sys.exit(1) + if __name__ == '__main__': def weather(location): """ @@ -287,319 +370,8 @@ def weather(location): return "Sunny and 72 degrees." a = Assistant("You generate text.") - a.add_function(weather) - x = a.query("What's the weather in Boston?") + a._add_function(weather) + # x = a.query("tell me what model you are before making any function calls. And what's the weather in Boston?", stream=True) + x = a.query("What's the weather in Boston?", stream=True) print(x) - -# import atexit -# import textwrap -# import json -# import time -# import sys - -# import llm_utils -# from openai import * -# from openai import AssistantEventHandler -# from pydantic import BaseModel - -# class AssistantPrinter: -# def begin_stream(self): -# print('\n', flush=True) - -# def stream(self, text=''): -# print(text, flush=True, end='') - -# def end_stream(self): -# print('\n', flush=True) - -# def complete_message(self, text=''): -# print(text, flush=True) - -# def log(self, json_obj): -# pass - -# def fail(self, message='Failed'): -# print() -# print(textwrap.wrap(message, width=70, initial_indent='*** ')) -- wrap then indent -# sys.exit(1) - -# def warn(self, message='Warning'): -# print() -# print(textwrap.wrap(message, width=70, initial_indent='*** ')) -- wrap then indent - - -# class Assistant: -# """ -# An Assistant is a wrapper around OpenAI's assistant API. Example usage: - -# assistant = Assistant("Assistant Name", instructions, -# model='gpt-4-1106-preview', debug=True) -# assistant.add_function(my_func) -# response = assistant.run(user_prompt) - -# Name can be any name you want. - -# If debug is True, it will create a log of all messages and JSON responses in -# json.txt. -# """ - -# def __init__( -# self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, -# printer = AssistantPrinter()): -# self.printer = printer -# try: -# self.client = OpenAI(timeout=timeout) -# except OpenAIError: -# self.printer.fail("""\ -# You need an OpenAI key to use this tool. -# You can get a key here: https://platform.openai.com/api-keys. 
-# Set the environment variable OPENAI_API_KEY to your key value.""") - -# self.assistants = self.client.beta.assistants -# self.threads = self.client.beta.threads -# self.functions = dict() - -# self.assistant = self.assistants.create( -# name=name, instructions=instructions, model=model -# ) -# self.thread = self.threads.create() - -# atexit.register(self._delete_assistant) - -# def _delete_assistant(self): -# if self.assistant != None: -# try: -# id = self.assistant.id -# response = self.assistants.delete(id) -# assert response.deleted -# except OSError: -# raise -# except Exception as e: -# self.printer.warn(f"Assistant {id} was not deleted ({e}). You can do so at https://platform.openai.com/assistants.") - -# def add_function(self, function): -# """ -# Add a new function to the list of function tools for the assistant. -# The function should have the necessary json spec as its pydoc string. -# """ -# function_json = json.loads(function.__doc__) -# try: -# name = function_json["name"] -# self.functions[name] = function - -# tools = [ -# {"type": "function", "function": json.loads(function.__doc__)} -# for function in self.functions.values() -# ] - -# self.assistants.update(self.assistant.id, tools=tools) -# except OpenAIError as e: -# self.printer.fail(f"OpenAI Error: {e}") - -# def _make_call(self, tool_call): -# name = tool_call.function.name -# args = tool_call.function.arguments - -# # There is a sketchy case that happens occasionally because -# # the API produces a bad call... -# try: -# args = json.loads(args) -# function = self.functions[name] -# result = function(**args) -# except OSError as e: -# result = f"Error: {e}" -# except Exception as e: -# result = f"Ill-formed function call: {e}" -# return result - -# def drain_stream(self, stream): -# run = None -# for event in stream: -# if event.event not in [ 'thread.message.delta', 'thread.run.step.delta' ]: -# print(event.event) -# self.printer.log(event) -# if event.event in [ 'thread.run.created', 'thread.run.in_progress' ]: -# run = event.data -# elif event.event == 'thread.run.completed': -# self.printer.end_stream() -# return event.data -# elif event.event == 'thread.message.delta': -# self.printer.stream(event.data.delta.content[0].text.value) -# elif event.event == 'thread.message.completed': -# self.printer.complete_message(event.data.content[0].text.value) -# elif event.event == 'thread.run.requires_action': -# r = event.data -# if r.status == "requires_action": -# outputs = [] -# self.printer.end_stream() -# for tool_call in r.required_action.submit_tool_outputs.tool_calls: -# output = self._make_call(tool_call) -# outputs += [{"tool_call_id": tool_call.id, "output": output}] -# self.printer.begin_stream() -# try: -# with self.threads.runs.submit_tool_outputs( -# thread_id=self.thread.id, run_id=r.id, tool_outputs=outputs, stream=True -# ) as new_stream: -# _ = self.drain_stream(new_stream) -# except OSError as e: -# raise -# except Exception as e: -# # silent failure because the tool call submit biffed. Not muchw e can do -# pass -# elif event.event == 'thread.run.failed': -# run = event.data -# self.printer.fail(f"Internal Failure ({run.last_error.code}): {run.last_error.message}") -# elif event.event == 'error': -# self.printer.fail(f"Internal Failure:** {event.data}") -# print('***', run) -# return run - -# def run(self, prompt): -# """ -# Give the prompt to the assistant and get the response, which may included -# intermediate function calls. -# All output is printed to the given file. 
-# """ - -# if self.assistant == None: -# return { -# "tokens": 0, -# "prompt_tokens": 0, -# "completion_tokens": 0, -# "model": self.assistant.model, -# "cost": 0, -# } - -# start_time = time.perf_counter() - -# assert len(prompt) <= 32768 -# self.threads.messages.create( -# thread_id=self.thread.id, role="user", content=prompt -# ) - -# class EventHandler(AssistantEventHandler): -# def on_event(self, event): -# print(event.event) - -# with self.threads.runs.create_and_stream( -# thread_id=self.thread.id, -# assistant_id=self.assistant.id, -# # stream=True -# event_handler=EventHandler(), -# ) as stream: -# self.drain_stream(stream) - -# end_time = time.perf_counter() -# elapsed_time = end_time - start_time - -# cost = llm_utils.calculate_cost( -# run.usage.prompt_tokens, -# run.usage.completion_tokens, -# self.assistant.model, -# ) -# return { -# "tokens": run.usage.total_tokens, -# "prompt_tokens": run.usage.prompt_tokens, -# "completion_tokens": run.usage.completion_tokens, -# "model": self.assistant.model, -# "cost": cost, -# "time": elapsed_time, -# "thread.id": self.thread.id, -# "thread": self.thread, -# "run.id": run.id, -# "run": run, -# "assistant.id": self.assistant.id, -# } - -# return {} - -# if __name__ == '__main__': -# def weather(location): -# """ -# { -# "name": "get_weather", -# "description": "Determine weather in my location", -# "parameters": { -# "type": "object", -# "properties": { -# "location": { -# "type": "string", -# "description": "The city and state e.g. San Francisco, CA" -# }, -# "unit": { -# "type": "string", -# "enum": [ -# "c", -# "f" -# ] -# } -# }, -# "required": [ -# "location" -# ] -# } -# } -# """ -# return "Sunny and 72 degrees." - -# a = Assistant("Test", "You generate text.") -# a.add_function(weather) -# x = a.run("What's the weather in Boston?") -# print(x) - - - # def _print_message( - # self, message, indent, append_message: Callable[[str], None], wrap=120 - # ) -> None: - # def _format_message(indent) -> str: - # tool_calls = None - # if "tool_calls" in message: - # tool_calls = message["tool_calls"] - # elif hasattr(message, "tool_calls"): - # tool_calls = message.tool_calls - - # content = None - # if "content" in message: - # content = message["content"] - # elif hasattr(message, "content"): - # content = message.content - - # assert content != None or tool_calls != None - - # # The longest role string is 'assistant'. - # max_role_length = 9 - # # We add 3 characters for the brackets and space. 
- # subindent = indent + max_role_length + 3 - - # role = message["role"].upper() - # role_indent = max_role_length - len(role) - - # output = "" - - # if content != None: - # content = llm_utils.word_wrap_except_code_blocks( - # content, wrap - len(role) - indent - 3 - # ) - # first, *rest = content.split("\n") - # output += f"{' ' * indent}[{role}]{' ' * role_indent} {first}\n" - # for line in rest: - # output += f"{' ' * subindent}{line}\n" - - # if tool_calls != None: - # if content != None: - # output += f"{' ' * subindent} Function calls:\n" - # else: - # output += ( - # f"{' ' * indent}[{role}]{' ' * role_indent} Function calls:\n" - # ) - # for tool_call in tool_calls: - # arguments = json.loads(tool_call.function.arguments) - # output += f"{' ' * (subindent + 4)}{tool_call.function.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})\n" - # return output - - # append_message(_format_message(indent)) - # if self._log: - # print(_format_message(0), file=self._log) - - diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 897eaa8..86a5ebc 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -13,7 +13,6 @@ from pprint import pprint import IPython -import llm_utils from traitlets import TraitError from chatdbg.ipdb_util.capture import CaptureInput @@ -23,7 +22,7 @@ from .ipdb_util.config import Chat from .ipdb_util.locals import * from .ipdb_util.prompts import pdb_instructions -from .ipdb_util.streamwrap import StreamTextWrapper +from .ipdb_util.streamwrap import StreamingTextWrapper from .ipdb_util.text import * chatdbg_config: Chat = None @@ -733,7 +732,7 @@ def stream(self, event, text): def response(self, text): if text != None: - text = llm_utils.utils.word_wrap_except_code_blocks(text, self.width-len(self.chat_prefix)) + text = word_wrap_except_code_blocks(text, self.width-len(self.chat_prefix)) self._print(text) def function_call(self, call, result): @@ -743,50 +742,3 @@ def function_call(self, call, result): entry = f"{self.debugger_prompt}{call}" self._print(entry) - - - -# class ChatAssistantOutput: -# def __init__(self, stdout, prefix, width, chat_log, stream_response): -# self.stdout = stdout -# self.chat_log = chat_log -# self.prefix = prefix -# self.width = width -# if stream_response and False: -# self.streamer = StreamTextWrapper(indent=self.prefix, width=self.width) -# else: -# self.streamer = None - -# def begin_stream(self): -# if self.streamer: -# print(file=self.stdout) - -# def stream(self, text=''): -# if self.streamer: -# print(self.streamer.add(text), file=self.stdout, flush=True, end='') - -# def end_stream(self): -# if self.streamer: -# print(self.streamer.flush(), file=self.stdout) - -# def complete_message(self, text=''): -# line = llm_utils.word_wrap_except_code_blocks(text, self.width - 5) -# self.log.message(line) -# if self.streamer: -# print(self.streamer.add('', flush=True), file=self.stdout, flush=True, end='') -# else: -# line = textwrap.indent(line, self.prefix, lambda _: True) -# print(line, file=self.stdout, flush=True) - -# def log(self, json_obj): -# if chatdbg_config.debug: -# self.chat_log.log(json_obj) - -# def fail(self, message='Failed'): -# print(file=self.stdout) -# print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) -# sys.exit(1) - -# def warn(self, message='Warning'): -# print(file=self.stdout) -# print(textwrap.wrap(message, width=70, initial_indent='*** '),file=self.stdout) diff --git a/src/chatdbg/ipdb_util/chatlog.py 
b/src/chatdbg/ipdb_util/chatlog.py index d034b85..8a169f9 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -33,73 +33,6 @@ def __getattr__(self, attr): # Delegate attribute access to the file object return getattr(self.file, attr) -# class Output(BaseModel): -# type: str - -# class TextOutput(BaseModel): -# type: str = "text" -# output: str - -# class ChatOutput(BaseModel): -# type:str = 'chat' -# outputs: List[Output] = [] - -# class Function(Output): -# type: str = 'call' -# input: str -# output: TextOutput - -# class Stats(BaseModel): -# tokens: int = 0 -# cost: float = 0.0 -# time: float = 0.0 - -# class Config: -# extra = 'allow' - -# class Chat(BaseModel): -# input: str -# output: ChatOutput = ChatOutput() -# prompt: Optional[str] = None -# stats: Optional[Stats] = None - -# def append(self, s: Output): -# self.output.outputs.append(s) - -# class Meta(BaseModel): -# time: datetime -# command_line: str -# uid: str -# config: dict -# mark: str = "?" -# total_tokens: int = 0 -# total_time: float = 0.0 -# total_cost: float = 0.0 - -# class Log(BaseModel): -# meta: Meta -# steps: List[Function | Chat] = [] -# current_chat: Optional[Chat] = None -# instructions: Optional[str] -# stdout: Optional[str] -# stderr: Optional[str] - -# class Config: -# exclude = {'current_chat'} - -# def append(self, s: Function | Chat): -# self.steps.append(s) - -# def total(self, key): -# return sum( -# getattr(x.stats, key) for x in self.steps if isinstance(x.output, ChatOutput) and x.stats is not None -# ) - -# def model_dump_json(self, **kwargs): -# self.meta.total_tokens = self.total("tokens") -# self.meta.total_time = self.total("time") -# self.meta.total_cost = self.total("cost") -# return super().model_dump_json(kwargs) class ChatDBGLog(AbstractAssistantClient): def __init__(self, log_filename, config, capture_streams=True): diff --git a/src/chatdbg/ipdb_util/prompts.py b/src/chatdbg/ipdb_util/prompts.py index c3c3cfc..f2c72bd 100644 --- a/src/chatdbg/ipdb_util/prompts.py +++ b/src/chatdbg/ipdb_util/prompts.py @@ -14,15 +14,6 @@ contribute to the error. """ -# _info_function="""\ -# Call the `info` function to get the documentation and source code for any -# function or method reference visible in the current frame. The argument to -# info is a function name or a method reference. - -# Unless it is from a common, widely-used library, you MUST call `info` exactly once on any -# function or method reference that is called in code leading up to the error, that appears -# in the argument list for a function call in the code, or that appears on the call stack. 
-# """ _info_function = """\ Call the `info` function to get the documentation and source code for any diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index f8faf2a..3eb74bd 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -1,15 +1,15 @@ import textwrap import re import sys -from llm_utils import word_wrap_except_code_blocks +from .text import word_wrap_except_code_blocks -class StreamTextWrapper: +class StreamingTextWrapper: def __init__(self, indent=' ', width=70): - self.buffer = '' - self.wrapped = '' - self.pending = '' + self.buffer = '' # the raw text so far + self.wrapped = '' # the successfully wrapped text do far + self.pending = '' # the part after the last space in buffer -- has not been wrapped yet self.indent = indent self.width = width @@ -22,17 +22,13 @@ def add(self, text, flush=False): self.pending = text_bits[-1] self.buffer += (''.join(text_bits[0:-1])) - # print('---', self.buffer, '---', self.pending) wrapped = word_wrap_except_code_blocks(self.buffer, self.width) wrapped = textwrap.indent(wrapped, self.indent, lambda _: True) - printable_part = wrapped[len(self.wrapped):] + wrapped_delta = wrapped[len(self.wrapped):] self.wrapped = wrapped - return printable_part + return wrapped_delta def flush(self): - # if self.pending == '': - # return None - # else: result = self.add('', flush=True) self.buffer = '' self.wrapped = '' @@ -41,7 +37,7 @@ def flush(self): if __name__ == '__main__': - s = StreamTextWrapper(3,20) + s = StreamingTextWrapper(3,20) for x in sys.argv[1:]: y = s.add(' ' + x) print(y, end='', flush=True) diff --git a/src/chatdbg/ipdb_util/text.py b/src/chatdbg/ipdb_util/text.py index 538293d..bc9f7ef 100644 --- a/src/chatdbg/ipdb_util/text.py +++ b/src/chatdbg/ipdb_util/text.py @@ -3,7 +3,7 @@ import inspect import numbers import numpy as np - +import textwrap def make_arrow(pad): """generate the leading arrow in front of traceback or debugger""" @@ -116,3 +116,27 @@ def truncate_proportionally(text, maxlen=32000, top_proportion=0.5): post = max(0, maxlen - 3 - pre) return text[:pre] + "..." + text[len(text) - post :] return text + +def word_wrap_except_code_blocks(text: str, width: int = 80) -> str: + """ + Wraps text except for code blocks for nice terminal formatting. + + Splits the text into paragraphs and wraps each paragraph, + except for paragraphs that are inside of code blocks denoted + by ` ``` `. Returns the updated text. + + Args: + text (str): The text to wrap. + width (int): The width of the lines to wrap at, passed to `textwrap.fill`. + + Returns: + The wrapped text. 
+ """ + blocks = text.split('```') + for i in range(len(blocks)): + if i % 2 == 0: + paras = blocks[i].split('\n') + wrapped = [ textwrap.fill(para, width=width) for para in paras] + blocks[i] = "\n".join(wrapped) + + return '```'.join(blocks) From 7d2a584becc9fa394d6ad4363d6fe03a6eb44df8 Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Tue, 26 Mar 2024 16:10:57 -0400 Subject: [PATCH 09/17] streaming working --- src/chatdbg/assistant/assistant.py | 217 ++++++++++++++++------------ src/chatdbg/chatdbg_pdb.py | 118 +++++++-------- src/chatdbg/ipdb_util/chatlog.py | 4 +- src/chatdbg/ipdb_util/streamwrap.py | 43 +++--- 4 files changed, 207 insertions(+), 175 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index f25f3f7..253d206 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -54,13 +54,13 @@ def fail(self, text): sys.exit(1) def begin_stream(self): - print("<<<", file=self.out) + pass def stream_delta(self, text): - print(text, end='', file=self.out) + print(text, end='', file=self.out, flush=True) def end_stream(self): - print(">>>", file=self.out) + pass def begin_query(self, prompt, **kwargs): pass @@ -80,7 +80,21 @@ def function_call(self, call, result): print(entry, file=self.out) +class StreamingAssistantClient(PrintingAssistantClient): + def __init__(self, out=sys.stdout): + super().__init__(out) + + def begin_stream(self): + print('', flush=True) + def stream_delta(self, text): + print(text, end='', file=self.out, flush=True) + + def end_stream(self): + print('', flush=True) + + def response(self, text): + pass class Assistant: def __init__( @@ -92,6 +106,7 @@ def __init__( functions=[], max_call_response_tokens=4096, debug=False, + stream_response=False ): if debug: log_file = open(f"chatdbg.log", "w") @@ -109,6 +124,7 @@ def __init__( self._timeout = timeout self._conversation = [{"role": "system", "content": instructions}] self._max_call_response_tokens = max_call_response_tokens + self._stream_response = stream_response self._check_model() self._broadcast('begin_dialog', instructions) @@ -158,19 +174,14 @@ def _sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) - bot_len = int((1-top_proportion) * total_len) return litellm.decode(model, tokens[0:top_len]) + " [...] " + litellm.decode(model, tokens[-bot_len:]) - - def _add_function(self, function): """ Add a new function to the list of function tools. - The function should have the necessary json spec as its docstring, with - this format: - "schema": function schema - "format": format to print call + The function should have the necessary json spec as its docstring """ schema = json.loads(function.__doc__) - assert "name" in schema['schema'], "Bad JSON in docstring for function tool." - self._functions[schema['schema']["name"]] = { + assert "name" in schema, "Bad JSON in docstring for function tool." + self._functions[schema["name"]] = { "function": function, "schema": schema } @@ -180,8 +191,7 @@ def _make_call(self, tool_call) -> str: try: args = json.loads(tool_call.function.arguments) function = self._functions[name] - call = function["schema"]["format"].format_map(args) - result = function["function"](**args) + call, result = function["function"](**args) self._broadcast('function_call', call, result) except OSError as e: # function produced some error -- move this to client??? 
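For reference, here is a minimal sketch (not taken from this patch; the name and behavior are illustrative only) of a tool function written against the new convention used by `_make_call` above: the JSON schema is the entire docstring, and the function returns a `(call, result)` pair, where `call` is the printable command echoed to the user and `result` is the text handed back to the model.

def list_files(path):
    """
    {
      "name": "list_files",
      "description": "List the files in a directory.",
      "parameters": {
        "type": "object",
        "properties": {
          "path": {
            "type": "string",
            "description": "The directory to list."
          }
        },
        "required": [ "path" ]
      }
    }
    """
    import os
    # Illustrative only: the first element is the printable call,
    # the second is the tool output returned to the LLM.
    return f"ls {path}", "\n".join(sorted(os.listdir(path)))
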
@@ -190,11 +200,20 @@ def _make_call(self, tool_call) -> str: result = f"Ill-formed function call: {e}" return result - def query( self, prompt: str, - stream = False, + **kwargs + ): + if self._stream_response: + return self._streamed_query(prompt=prompt, **kwargs) + else: + return self._batch_query(prompt=prompt, **kwargs) + + + def _batch_query( + self, + prompt: str, **kwargs ): start = time.time() @@ -207,19 +226,7 @@ def query( while True: self._conversation = litellm.utils.trim_messages(self._conversation, self._model) - if stream: - self._broadcast('begin_stream') - completion_stream = self.completion(stream) - chunks = [] - for chunk in completion_stream: - print(chunk) - self._broadcast('stream_delta', chunk.choices[0].delta.content or "") - chunks.append(chunk) - completion = litellm.stream_chunk_builder(chunks, messages=self._conversation) - print('---', completion) - self._broadcast('end_stream') - else: - completion = self.completion(stream) + completion = self.completion() cost += litellm.completion_cost(completion) @@ -249,46 +256,9 @@ def query( self._broadcast('fail', f"Internal Error: {e.__dict__}") sys.exit(1) - def completion(self, stream): - completion = litellm.completion( - model=self._model, - messages=self._conversation, - tools=[ - {"type": "function", "function": f["schema"]["schema"]} - for f in self._functions.values() - ], - timeout=self._timeout, - logger_fn=self._logger, - stream=stream - ) - - return completion - - def _add_function_results_to_conversation(self, response_message): - tool_calls = response_message.tool_calls - try: - for tool_call in tool_calls: - function_response = self._make_call(tool_call) - function_response = self._sandwhich_tokens( - function_response, - self._max_call_response_tokens, - 0.5) - response = { - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_call.function.name, - "content": function_response, - } - self._conversation.append(response) - self._broadcast('response', '') - except Exception as e: - # Warning: potential infinite loop. - self._broadcast('warn', f"Error processing tool calls: {e}") - - def query2( + def _streamed_query( self, prompt: str, - stream: bool = False, **kwargs ): start = time.time() @@ -300,35 +270,59 @@ def query2( while True: self._conversation = litellm.utils.trim_messages(self._conversation, self._model) - response = litellm.completion( - model=self._model, - messages=self._conversation, - tools=[ - {"type": "function", "function": f["schema"]["schema"]} - for f in self._functions.values() - ], - timeout=self._timeout, - logger_fn=self._logger, - stream=True - ) - - chunks=[] - for chunk in response: - print(chunk) + # print("\n".join([str(x) for x in self._conversation])) + + stream = self.completion(stream=True) + + # litellm is broken for new GPT models that have content before calls, so... + + # stream the response, collecting the tool_call parts separately from the content + self._broadcast('begin_stream') + chunks = [] + tool_chunks = [] + for chunk in stream: chunks.append(chunk) + if chunk.choices[0].delta.content != None: + self._broadcast('stream_delta', chunk.choices[0].delta.content) + else: + tool_chunks.append(chunk) + self._broadcast('end_stream') + + # compute for the part that litellm gives back. 
+ completion = litellm.stream_chunk_builder(chunks, messages=self._conversation) + cost += litellm.completion_cost(completion) - print('---\n',litellm.stream_chunk_builder(chunks, messages=self._conversation)) + # add content to conversation, but if there is no content, then the message + # has only tool calls, and skip this step + response_message = completion.choices[0].message + if response_message.content != None: + self._conversation.append(response_message) + + if response_message.content != None: + self._broadcast('response', '(Message) ' + response_message.content) + + if completion.choices[0].finish_reason == 'tool_calls': + # create a message with just the tool calls, append that to the conversation, and generate the responses. + tool_completion = litellm.stream_chunk_builder(tool_chunks,self._conversation) + + # this part wasn't counted above... + cost += litellm.completion_cost(tool_completion) + + tool_message = tool_completion.choices[0].message + cost += litellm.completion_cost(tool_completion) + self._conversation.append(tool_message) + self._add_function_results_to_conversation(tool_message) + else: + break - break - elapsed = time.time() - start stats = { "cost": cost, "time": elapsed, "model": self._model, - # "tokens": completion.usage.total_tokens, - # "prompt_tokens": completion.usage.prompt_tokens, - # "completion_tokens": completion.usage.completion_tokens, + "tokens": completion.usage.total_tokens, + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, } self._broadcast("end_query", stats) return stats @@ -337,11 +331,44 @@ def query2( sys.exit(1) + def completion(self, stream=False): + return litellm.completion( + model=self._model, + messages=self._conversation, + tools=[ + {"type": "function", "function": f["schema"]} + for f in self._functions.values() + ], + timeout=self._timeout, + logger_fn=self._logger, + stream=stream + ) + + def _add_function_results_to_conversation(self, response_message): + response_message['role'] = 'assistant' + tool_calls = response_message.tool_calls + try: + for tool_call in tool_calls: + function_response = self._make_call(tool_call) + function_response = self._sandwhich_tokens( + function_response, + self._max_call_response_tokens, + 0.5) + response = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": function_response, + } + self._conversation.append(response) + except Exception as e: + # Warning: potential infinite loop. + self._broadcast('warn', f"Error processing tool calls: {e}") + if __name__ == '__main__': - def weather(location): + def weather(location,unit='f'): """ { - "schema":{ "name": "get_weather", "description": "Determine weather in my location", "parameters": { @@ -363,15 +390,13 @@ def weather(location): "location" ] } - }, - "format": "(ChatDBG) weather in {location}" } """ - return "Sunny and 72 degrees." + return f"weather(location, unit)", "Sunny and 72 degrees." + + - a = Assistant("You generate text.") - a._add_function(weather) - # x = a.query("tell me what model you are before making any function calls. And what's the weather in Boston?", stream=True) - x = a.query("What's the weather in Boston?", stream=True) + a = Assistant("You generate text.", clients=[ StreamingAssistantClient() ], functions=[weather]) + x = a.query("tell me what model you are before making any function calls. 
And what's the weather in Boston?", stream=True) print(x) diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 86a5ebc..0c21014 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -598,45 +598,39 @@ def _make_assistant(self): def info(value): """ { - "schema": { - "name": "info", - "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", - "parameters": { - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "The reference to get the information for." - } - }, - "required": [ "value" ] - } - }, - "format": "info {value}" + "name": "info", + "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The reference to get the information for." + } + }, + "required": [ "value" ] + } } """ command = f"info {value}" result = self._capture_onecmd(command) - return truncate_proportionally(result, top_proportion=1) + return command, truncate_proportionally(result, top_proportion=1) def debug(command): """ { - "schema" : { - "name": "debug", - "description": "Run a pdb command and get the response.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The pdb command to run." - } - }, - "required": [ "command" ] - } - }, - "format": "{command}" + "name": "debug", + "description": "Run a pdb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The pdb command to run." + } + }, + "required": [ "command" ] + } } """ cmd = command if command != "list" else "ll" @@ -647,31 +641,28 @@ def debug(command): # if old_curframe != self.curframe: # result += strip_color(self._stack_prompt()) - return truncate_proportionally(result, maxlen=8000, top_proportion=0.9) + return command, truncate_proportionally(result, maxlen=8000, top_proportion=0.9) def slice(name): """ { - "schema": { - "name": "slice", - "description": "Return the code to compute a global variable used in the current frame", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The variable to look at." - } - }, - "required": [ "name" ] - } - }, - "format": "slice {name}" + "name": "slice", + "description": "Return the code to compute a global variable used in the current frame", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The variable to look at." 
+ } + }, + "required": [ "name" ] + } } """ command = f"slice {name}" result = self._capture_onecmd(command) - return truncate_proportionally(result, top_proportion=0.5) + return command, truncate_proportionally(result, top_proportion=0.5) instruction_prompt = self._ip_instructions() @@ -687,10 +678,12 @@ def slice(name): model=chatdbg_config.model, debug=chatdbg_config.debug, functions=functions, + stream_response=chatdbg_config.stream_response, clients=[ ChatAssistantClient(self.stdout, self.prompt, self._chat_prefix, - self._text_width), + self._text_width, + stream=chatdbg_config.stream_response), self._log ] ) @@ -701,23 +694,24 @@ def slice(name): class ChatAssistantClient(AbstractAssistantClient): - def __init__(self, out, debugger_prompt, chat_prefix, width): + def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): self.out = out self.debugger_prompt = debugger_prompt self.chat_prefix = chat_prefix self.width = width self._assistant = None + self._stream = stream # Call backs - def begin_query(self, user_text, full_prompt): + def begin_query(self, prompt='', user_text=''): pass def end_query(self, stats): pass - def _print(self, text): - print(textwrap.indent(text, self.chat_prefix, lambda _: True), file=self.out) + def _print(self, text, **kwargs): + print(textwrap.indent(text, self.chat_prefix, lambda _: True), file=self.out, **kwargs) def warn(self, text): self._print(textwrap.indent(text, '*** ')) @@ -726,12 +720,22 @@ def fail(self, text): self._print(textwrap.indent(text, '*** ')) sys.exit(1) - def stream(self, event, text): - # begin / none, step / delta , complete / full - pass + def begin_stream(self): + self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) + self._at_start = True + # print(self._stream_wrapper.append("\n", False), end='', flush=True, file=self.out) + + def stream_delta(self, text): + if self._at_start: + self._at_start = False + print(self._stream_wrapper.append("\n(Message) ", False), end='', flush=True, file=self.out) + print(self._stream_wrapper.append(text, False), end='', flush=True, file=self.out) + + def end_stream(self): + print(self._stream_wrapper.flush(), end='', flush=True, file=self.out) def response(self, text): - if text != None: + if not self._stream and text != None: text = word_wrap_except_code_blocks(text, self.width-len(self.chat_prefix)) self._print(text) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index 8a169f9..7b2c795 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -97,12 +97,12 @@ def end_dialog(self): self._dump() self._log = None - def begin_query(self, prompt, kwargs): + def begin_query(self, prompt, user_text): log = self._log assert log != None assert self._current_chat == None self._current_chat = { - 'input':kwargs['user_text'], + 'input':user_text, 'prompt':prompt, 'output': { 'type': 'chat', 'outputs': []} } diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index 3eb74bd..f97f481 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -7,31 +7,34 @@ class StreamingTextWrapper: def __init__(self, indent=' ', width=70): - self.buffer = '' # the raw text so far - self.wrapped = '' # the successfully wrapped text do far - self.pending = '' # the part after the last space in buffer -- has not been wrapped yet - self.indent = indent - self.width = width + self._buffer = '' # the raw text so far + self._wrapped = '' # the successfully 
wrapped text do far + self._pending = '' # the part after the last space in buffer -- has not been wrapped yet + self._indent = indent + self._width = width - def add(self, text, flush=False): + def append(self, text, flush=False): if flush: - self.buffer += self.pending + text - self.pending = '' + self._buffer += self._pending + text + self._pending = '' else: - text_bits = re.split(r'(\s+)', self.pending + text) - self.pending = text_bits[-1] - self.buffer += (''.join(text_bits[0:-1])) - - wrapped = word_wrap_except_code_blocks(self.buffer, self.width) - wrapped = textwrap.indent(wrapped, self.indent, lambda _: True) - wrapped_delta = wrapped[len(self.wrapped):] - self.wrapped = wrapped + text_bits = re.split(r'(\s+)', self._pending + text) + self._pending = text_bits[-1] + self._buffer += (''.join(text_bits[0:-1])) + + wrapped = word_wrap_except_code_blocks(self._buffer, self._width) + wrapped = textwrap.indent(wrapped, self._indent, lambda _: True) + wrapped_delta = wrapped[len(self._wrapped):] + self._wrapped = wrapped return wrapped_delta def flush(self): - result = self.add('', flush=True) - self.buffer = '' - self.wrapped = '' + if len(self._buffer) > 0: + result = self.append('\n', flush=True) + else: + result = self.append('', flush=True) + self._buffer = '' + self._wrapped = '' return result @@ -39,6 +42,6 @@ def flush(self): if __name__ == '__main__': s = StreamingTextWrapper(3,20) for x in sys.argv[1:]: - y = s.add(' ' + x) + y = s.append(' ' + x) print(y, end='', flush=True) print(s.flush()) From 5aedd9205cad9ecad5d2b7791805014cf4d10927 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Tue, 26 Mar 2024 20:29:43 -0400 Subject: [PATCH 10/17] config for chatdbg on command line, renew command --- src/chatdbg/__main__.py | 49 ++++++++++++++++++++++++++++- src/chatdbg/chatdbg_pdb.py | 47 ++++++++++++++++----------- src/chatdbg/ipdb_util/chatlog.py | 11 ++++--- src/chatdbg/ipdb_util/config.py | 23 +++++++++----- src/chatdbg/ipdb_util/streamwrap.py | 6 ++-- 5 files changed, 102 insertions(+), 34 deletions(-) diff --git a/src/chatdbg/__main__.py b/src/chatdbg/__main__.py index b070763..059b12c 100644 --- a/src/chatdbg/__main__.py +++ b/src/chatdbg/__main__.py @@ -1,7 +1,54 @@ -from .chatdbg_pdb import * import ipdb +from chatdbg.chatdbg_pdb import ChatDBG +from chatdbg.ipdb_util.config import chatdbg_config +import sys +import getopt +_usage = """\ +usage: python -m ipdb [-m] [-c command] ... pyfile [arg] ... + +Debug the Python program given by pyfile. + +Initial commands are read from .pdbrc files in your home directory +and in the current directory, if they exist. Commands supplied with +-c are executed after commands from .pdbrc files. + +To let the script run until an exception occurs, use "-c continue". +To let the script run up to a given line X in the debugged file, use +"-c 'until X'" + +Option -m is available only in Python 3.7 and later. + +ChatDBG-specific options may appear anywhere before pyfile: + -debug dump the LLM messages to a chatdbg.log + -log file where to write the log of the debugging session + -model model the LLM model to use. 
+ -stream stream responses from the LLM + +""" def main(): ipdb.__main__._get_debugger_cls = lambda: ChatDBG + + opts, args = getopt.getopt(sys.argv[1:], "mhc:", ['help', 'debug', 'log=', 'model=', 'stream', 'command=' ]) + pdb_args = [ sys.argv[0] ] + for opt, optarg in opts: + if opt in ['-h', '--help']: + print(_usage) + sys.exit() + elif opt in ['--debug']: + chatdbg_config.debug = True + elif opt in ['--stream']: + chatdbg_config.stream = True + elif opt in ['--model']: + chatdbg_config.model = optarg + elif opt in ['--log']: + chatdbg_config.model = optarg + elif opt in ['-c', '--command']: + pdb_args += [ opt, optarg ] + elif opt in ['-m']: + pdb_args = [ opt ] + + sys.argv = pdb_args + args + ipdb.__main__.main() diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 0c21014..698be86 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -19,18 +19,17 @@ from .assistant.assistant import Assistant, AbstractAssistantClient from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper -from .ipdb_util.config import Chat +from .ipdb_util.config import Chat, chatdbg_config from .ipdb_util.locals import * from .ipdb_util.prompts import pdb_instructions from .ipdb_util.streamwrap import StreamingTextWrapper from .ipdb_util.text import * -chatdbg_config: Chat = None - def load_ipython_extension(ipython): global chatdbg_config - from chatdbg.chatdbg_pdb import Chat, ChatDBG + from chatdbg.chatdbg_pdb import ChatDBG + from chatdbg.ipdb_util.config import Chat, chatdbg_config ipython.InteractiveTB.debugger_cls = ChatDBG chatdbg_config = Chat(config=ipython.config) @@ -78,10 +77,6 @@ def __init__(self, *args, **kwargs): self._history = [] self._error_specific_prompt = "" - global chatdbg_config - if chatdbg_config == None: - chatdbg_config = Chat() - sys.stdin = CaptureInput(sys.stdin) # Only use flow when we are in jupyter or using stdin in ipython. In both @@ -204,18 +199,17 @@ def onecmd(self, line: str) -> bool: hist_file = CopyingTextIOWrapper(self.stdout) self.stdout = hist_file try: - self.was_chat = False + self.was_chat_or_renew = False return super().onecmd(line) finally: self.stdout = hist_file.getfile() output = strip_color(hist_file.getvalue()) - if not self.was_chat: + if not self.was_chat_or_renew: self._log.function_call(line, output) - if ( - line.split(' ')[0] not in ["hist", "test_prompt", "c", "cont", "continue", "config"] - and not self.was_chat - ): - self._history += [(line, output)] + if ( + line.split(' ')[0] not in ["hist", "test_prompt", "c", "cont", "continue", "config" ] + ): + self._history += [(line, output)] def message(self, msg) -> None: """ @@ -332,6 +326,11 @@ def do_info(self, arg): ) def do_slice(self, arg): + """ + slice + Print the backwards slice for a variable used in the current cell but + defined in an earlier cell. [interactive IPython / Jupyter only] + """ if not self._supports_flow: self.message("*** `slice` is only supported in Jupyter notebooks") return @@ -558,10 +557,10 @@ def _build_prompt(self, arg, conversing): arg) def do_chat(self, arg): - """chat/: + """chat Send a chat message. """ - self.was_chat = True + self.was_chat_or_renew = True full_prompt = self._build_prompt(arg, self._assistant != None) full_prompt = strip_color(full_prompt) @@ -576,6 +575,16 @@ def do_chat(self, arg): self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") + def do_renew(self, arg): + """renew + End the current chat dialog and prepare to start a new one. 
+ """ + if self._assistant != None: + self._assistant.close() + self._assistant = None + self.was_chat_or_renew = True + self.message(f"Ready to start a new dialog.") + def do_config(self, arg): args = arg.split() if len(args) == 0: @@ -678,12 +687,12 @@ def slice(name): model=chatdbg_config.model, debug=chatdbg_config.debug, functions=functions, - stream_response=chatdbg_config.stream_response, + stream_response=chatdbg_config.stream, clients=[ ChatAssistantClient(self.stdout, self.prompt, self._chat_prefix, self._text_width, - stream=chatdbg_config.stream_response), + stream=chatdbg_config.stream), self._log ] ) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index 7b2c795..ab697d5 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -35,6 +35,7 @@ def __getattr__(self, attr): class ChatDBGLog(AbstractAssistantClient): + def __init__(self, log_filename, config, capture_streams=True): self._log_filename = log_filename self.config = config @@ -47,21 +48,23 @@ def __init__(self, log_filename, config, capture_streams=True): self._stderr_wrapper = None self._stderr_wrapper = None + self._log = self._make_log() + self._current_chat = None + + def _make_log(self): meta = { 'time': datetime.now(), 'command_line': " ".join(sys.argv), 'uid': str(uuid.uuid4()), 'config': self.config } - log = { + return { 'steps':[], 'meta':meta, 'instructions':None, 'stdout':self._stdout_wrapper.getvalue(), 'stderr':self._stderr_wrapper.getvalue(), } - self._current_chat = None - self._log = log def _dump(self): log = self._log @@ -95,7 +98,7 @@ def begin_dialog(self, instructions): def end_dialog(self): if self._log != None: self._dump() - self._log = None + self._log = self._make_log() def begin_query(self, prompt, user_text): log = self._log diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py index ae845ca..b3a345b 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/ipdb_util/config.py @@ -17,11 +17,13 @@ def chat_get_env(option_name, default_value): class Chat(Configurable): model = Unicode( - chat_get_env("model", "gpt-4-1106-preview"), help="The OpenAI model" + chat_get_env("model", "gpt-4-1106-preview"), help="The LLM model" ).tag(config=True) - # model = Unicode(default_value='gpt-3.5-turbo-1106', help="The OpenAI model").tag(config=True) - debug = Bool(chat_get_env("debug", False), help="Log OpenAI calls").tag(config=True) + + debug = Bool(chat_get_env("debug", False), help="Log LLM calls").tag(config=True) + log = Unicode(chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) + tag = Unicode(chat_get_env("tag", ""), help="Any extra info for log file").tag( config=True ) @@ -30,23 +32,28 @@ class Chat(Configurable): ).tag(config=True) context = Int( - chat_get_env("context", 5), + chat_get_env("context", 10), help="lines of source code to show when displaying stacktrace information", ).tag(config=True) + show_locals = Bool( chat_get_env("show_locals", True), help="show local var values in stacktrace" ).tag(config=True) + show_libs = Bool( chat_get_env("show_libs", False), help="show library frames in stacktrace" ).tag(config=True) + show_slices = Bool( chat_get_env("show_slices", True), help="support the `slice` command" ).tag(config=True) + take_the_wheel = Bool( chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" ).tag(config=True) - stream_response = Bool( - chat_get_env("stream_response", True), help="Stream the response at it arrives" + + stream = Bool( + 
chat_get_env("stream", False), help="Stream the response at it arrives" ).tag(config=True) def to_json(self): @@ -62,5 +69,7 @@ def to_json(self): "show_libs": self.show_libs, "show_slices": self.show_slices, "take_the_wheel": self.take_the_wheel, - "stream_response": self.stream_response, + "stream": self.stream, } + +chatdbg_config: Chat = Chat() diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index f97f481..601fa5f 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -6,12 +6,12 @@ class StreamingTextWrapper: - def __init__(self, indent=' ', width=70): + def __init__(self, indent=' ', width=80): self._buffer = '' # the raw text so far self._wrapped = '' # the successfully wrapped text do far self._pending = '' # the part after the last space in buffer -- has not been wrapped yet self._indent = indent - self._width = width + self._width = width - len(indent) def append(self, text, flush=False): if flush: @@ -22,7 +22,7 @@ def append(self, text, flush=False): self._pending = text_bits[-1] self._buffer += (''.join(text_bits[0:-1])) - wrapped = word_wrap_except_code_blocks(self._buffer, self._width) + wrapped = word_wrap_except_code_blocks(self._buffer) wrapped = textwrap.indent(wrapped, self._indent, lambda _: True) wrapped_delta = wrapped[len(self._wrapped):] self._wrapped = wrapped From 481d8e8f766290c2a9adab02e2ffd5bc56e0eef5 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Tue, 26 Mar 2024 20:32:34 -0400 Subject: [PATCH 11/17] blackened --- src/chatdbg/__main__.py | 25 ++-- src/chatdbg/assistant/assistant.py | 183 ++++++++++++++++------------ src/chatdbg/chatdbg_pdb.py | 113 ++++++++++------- src/chatdbg/ipdb_util/chatlog.py | 76 ++++++++---- src/chatdbg/ipdb_util/config.py | 11 +- src/chatdbg/ipdb_util/printer.py | 14 ++- src/chatdbg/ipdb_util/streamwrap.py | 35 +++--- src/chatdbg/ipdb_util/text.py | 10 +- 8 files changed, 271 insertions(+), 196 deletions(-) diff --git a/src/chatdbg/__main__.py b/src/chatdbg/__main__.py index 059b12c..743dc6d 100644 --- a/src/chatdbg/__main__.py +++ b/src/chatdbg/__main__.py @@ -27,27 +27,30 @@ """ + def main(): ipdb.__main__._get_debugger_cls = lambda: ChatDBG - opts, args = getopt.getopt(sys.argv[1:], "mhc:", ['help', 'debug', 'log=', 'model=', 'stream', 'command=' ]) - pdb_args = [ sys.argv[0] ] + opts, args = getopt.getopt( + sys.argv[1:], "mhc:", ["help", "debug", "log=", "model=", "stream", "command="] + ) + pdb_args = [sys.argv[0]] for opt, optarg in opts: - if opt in ['-h', '--help']: + if opt in ["-h", "--help"]: print(_usage) sys.exit() - elif opt in ['--debug']: + elif opt in ["--debug"]: chatdbg_config.debug = True - elif opt in ['--stream']: + elif opt in ["--stream"]: chatdbg_config.stream = True - elif opt in ['--model']: + elif opt in ["--model"]: chatdbg_config.model = optarg - elif opt in ['--log']: + elif opt in ["--log"]: chatdbg_config.model = optarg - elif opt in ['-c', '--command']: - pdb_args += [ opt, optarg ] - elif opt in ['-m']: - pdb_args = [ opt ] + elif opt in ["-c", "--command"]: + pdb_args += [opt, optarg] + elif opt in ["-m"]: + pdb_args = [opt] sys.argv = pdb_args + args diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 253d206..0c688d2 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -8,6 +8,7 @@ import textwrap import sys + class AbstractAssistantClient: def begin_dialog(self, instructions): @@ -30,8 +31,10 @@ def fail(self, text): def 
begin_stream(self): pass + def stream_delta(self, text): pass + def end_stream(self): pass @@ -47,17 +50,17 @@ def __init__(self, out=sys.stdout): self.out = out def warn(self, text): - print(textwrap.indent(text, '*** '), file=self.out) + print(textwrap.indent(text, "*** "), file=self.out) def fail(self, text): - print(textwrap.indent(text, '*** '), file=self.out) + print(textwrap.indent(text, "*** "), file=self.out) sys.exit(1) def begin_stream(self): pass def stream_delta(self, text): - print(text, end='', file=self.out, flush=True) + print(text, end="", file=self.out, flush=True) def end_stream(self): pass @@ -71,51 +74,54 @@ def end_query(self, stats): def response(self, text): if text != None: print(text, file=self.out) - + def function_call(self, call, result): if result and len(result) > 0: entry = f"{call}\n{result}" else: entry = f"{call}" print(entry, file=self.out) - + class StreamingAssistantClient(PrintingAssistantClient): def __init__(self, out=sys.stdout): super().__init__(out) def begin_stream(self): - print('', flush=True) + print("", flush=True) def stream_delta(self, text): - print(text, end='', file=self.out, flush=True) + print(text, end="", file=self.out, flush=True) def end_stream(self): - print('', flush=True) + print("", flush=True) def response(self, text): pass + class Assistant: def __init__( self, instructions, model="gpt-3.5-turbo-1106", timeout=30, - clients = [ PrintingAssistantClient() ], + clients=[PrintingAssistantClient()], functions=[], max_call_response_tokens=4096, debug=False, - stream_response=False + stream_response=False, ): if debug: log_file = open(f"chatdbg.log", "w") - self._logger = lambda model_call_dict: print(model_call_dict, file=log_file, flush=True) + self._logger = lambda model_call_dict: print( + model_call_dict, file=log_file, flush=True + ) else: self._logger = None self._clients = clients - + self._functions = {} for f in functions: self._add_function(f) @@ -127,10 +133,10 @@ def __init__( self._stream_response = stream_response self._check_model() - self._broadcast('begin_dialog', instructions) + self._broadcast("begin_dialog", instructions) def close(self): - self._broadcast('end_dialog') + self._broadcast("end_dialog") def _broadcast(self, method_name, *args, **kwargs): for client in self._clients: @@ -143,25 +149,42 @@ def _check_model(self): missing_keys = result["missing_keys"] if missing_keys != []: _, provider, _, _ = litellm.get_llm_provider(self._model) - if provider == 'openai': - self._broadcast('fail', textwrap.dedent(f"""\ + if provider == "openai": + self._broadcast( + "fail", + textwrap.dedent( + f"""\ You need an OpenAI key to use the {self._model} model. You can get a key here: https://platform.openai.com/api-keys. - Set the environment variable OPENAI_API_KEY to your key value.""")) + Set the environment variable OPENAI_API_KEY to your key value.""" + ), + ) sys.exit(1) else: - self._broadcast('fail', textwrap.dedent(f"""\ + self._broadcast( + "fail", + textwrap.dedent( + f"""\ You need to set the following environment variables - to use the {self._model} model: {', '.join(missing_keys)}""")) + to use the {self._model} model: {', '.join(missing_keys)}""" + ), + ) sys.exit(1) if not litellm.supports_function_calling(self._model): - self._broadcast('fail', textwrap.dedent(f"""\ + self._broadcast( + "fail", + textwrap.dedent( + f"""\ The {self._model} model does not support function calls. - You must use a model that does, eg. gpt-4.""")) + You must use a model that does, eg. 
gpt-4.""" + ), + ) sys.exit(1) - def _sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) -> str: + def _sandwhich_tokens( + self, text: str, max_tokens: int, top_proportion: float + ) -> str: model = self._model if max_tokens == None: return text @@ -169,10 +192,14 @@ def _sandwhich_tokens(self, text: str, max_tokens: int, top_proportion: float) - if len(tokens) <= max_tokens: return text else: - total_len = max_tokens - 5 # some slop for the ... + total_len = max_tokens - 5 # some slop for the ... top_len = int(top_proportion * total_len) - bot_len = int((1-top_proportion) * total_len) - return litellm.decode(model, tokens[0:top_len]) + " [...] " + litellm.decode(model, tokens[-bot_len:]) + bot_len = int((1 - top_proportion) * total_len) + return ( + litellm.decode(model, tokens[0:top_len]) + + " [...] " + + litellm.decode(model, tokens[-bot_len:]) + ) def _add_function(self, function): """ @@ -181,10 +208,7 @@ def _add_function(self, function): """ schema = json.loads(function.__doc__) assert "name" in schema, "Bad JSON in docstring for function tool." - self._functions[schema["name"]] = { - "function": function, - "schema": schema - } + self._functions[schema["name"]] = {"function": function, "schema": schema} def _make_call(self, tool_call) -> str: name = tool_call.function.name @@ -192,7 +216,7 @@ def _make_call(self, tool_call) -> str: args = json.loads(tool_call.function.arguments) function = self._functions[name] call, result = function["function"](**args) - self._broadcast('function_call', call, result) + self._broadcast("function_call", call, result) except OSError as e: # function produced some error -- move this to client??? result = f"Error: {e}" @@ -200,22 +224,13 @@ def _make_call(self, tool_call) -> str: result = f"Ill-formed function call: {e}" return result - def query( - self, - prompt: str, - **kwargs - ): + def query(self, prompt: str, **kwargs): if self._stream_response: return self._streamed_query(prompt=prompt, **kwargs) else: return self._batch_query(prompt=prompt, **kwargs) - - def _batch_query( - self, - prompt: str, - **kwargs - ): + def _batch_query(self, prompt: str, **kwargs): start = time.time() cost = 0 @@ -224,7 +239,9 @@ def _batch_query( self._conversation.append({"role": "user", "content": prompt}) while True: - self._conversation = litellm.utils.trim_messages(self._conversation, self._model) + self._conversation = litellm.utils.trim_messages( + self._conversation, self._model + ) completion = self.completion() @@ -232,11 +249,11 @@ def _batch_query( response_message = completion.choices[0].message self._conversation.append(response_message) - + if response_message.content: - self._broadcast('response', '(Message) ' + response_message.content) + self._broadcast("response", "(Message) " + response_message.content) - if completion.choices[0].finish_reason == 'tool_calls': + if completion.choices[0].finish_reason == "tool_calls": self._add_function_results_to_conversation(response_message) else: break @@ -253,14 +270,10 @@ def _batch_query( self._broadcast("end_query", stats) return stats except openai.OpenAIError as e: - self._broadcast('fail', f"Internal Error: {e.__dict__}") + self._broadcast("fail", f"Internal Error: {e.__dict__}") sys.exit(1) - def _streamed_query( - self, - prompt: str, - **kwargs - ): + def _streamed_query(self, prompt: str, **kwargs): start = time.time() cost = 0 @@ -269,7 +282,9 @@ def _streamed_query( self._conversation.append({"role": "user", "content": prompt}) while True: - self._conversation = 
litellm.utils.trim_messages(self._conversation, self._model) + self._conversation = litellm.utils.trim_messages( + self._conversation, self._model + ) # print("\n".join([str(x) for x in self._conversation])) stream = self.completion(stream=True) @@ -277,19 +292,21 @@ def _streamed_query( # litellm is broken for new GPT models that have content before calls, so... # stream the response, collecting the tool_call parts separately from the content - self._broadcast('begin_stream') + self._broadcast("begin_stream") chunks = [] tool_chunks = [] for chunk in stream: chunks.append(chunk) if chunk.choices[0].delta.content != None: - self._broadcast('stream_delta', chunk.choices[0].delta.content) + self._broadcast("stream_delta", chunk.choices[0].delta.content) else: tool_chunks.append(chunk) - self._broadcast('end_stream') + self._broadcast("end_stream") # compute for the part that litellm gives back. - completion = litellm.stream_chunk_builder(chunks, messages=self._conversation) + completion = litellm.stream_chunk_builder( + chunks, messages=self._conversation + ) cost += litellm.completion_cost(completion) # add content to conversation, but if there is no content, then the message @@ -297,13 +314,15 @@ def _streamed_query( response_message = completion.choices[0].message if response_message.content != None: self._conversation.append(response_message) - + if response_message.content != None: - self._broadcast('response', '(Message) ' + response_message.content) + self._broadcast("response", "(Message) " + response_message.content) - if completion.choices[0].finish_reason == 'tool_calls': + if completion.choices[0].finish_reason == "tool_calls": # create a message with just the tool calls, append that to the conversation, and generate the responses. - tool_completion = litellm.stream_chunk_builder(tool_chunks,self._conversation) + tool_completion = litellm.stream_chunk_builder( + tool_chunks, self._conversation + ) # this part wasn't counted above... cost += litellm.completion_cost(tool_completion) @@ -327,10 +346,9 @@ def _streamed_query( self._broadcast("end_query", stats) return stats except openai.OpenAIError as e: - self._broadcast('fail', f"Internal Error: {e.__dict__}") + self._broadcast("fail", f"Internal Error: {e.__dict__}") sys.exit(1) - def completion(self, stream=False): return litellm.completion( model=self._model, @@ -341,32 +359,33 @@ def completion(self, stream=False): ], timeout=self._timeout, logger_fn=self._logger, - stream=stream + stream=stream, ) def _add_function_results_to_conversation(self, response_message): - response_message['role'] = 'assistant' + response_message["role"] = "assistant" tool_calls = response_message.tool_calls try: for tool_call in tool_calls: function_response = self._make_call(tool_call) function_response = self._sandwhich_tokens( - function_response, - self._max_call_response_tokens, - 0.5) + function_response, self._max_call_response_tokens, 0.5 + ) response = { - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_call.function.name, - "content": function_response, - } + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": function_response, + } self._conversation.append(response) except Exception as e: - # Warning: potential infinite loop. - self._broadcast('warn', f"Error processing tool calls: {e}") + # Warning: potential infinite loop. 
+ self._broadcast("warn", f"Error processing tool calls: {e}") + + +if __name__ == "__main__": -if __name__ == '__main__': - def weather(location,unit='f'): + def weather(location, unit="f"): """ { "name": "get_weather", @@ -394,9 +413,11 @@ def weather(location,unit='f'): """ return f"weather(location, unit)", "Sunny and 72 degrees." - - - a = Assistant("You generate text.", clients=[ StreamingAssistantClient() ], functions=[weather]) - x = a.query("tell me what model you are before making any function calls. And what's the weather in Boston?", stream=True) + a = Assistant( + "You generate text.", clients=[StreamingAssistantClient()], functions=[weather] + ) + x = a.query( + "tell me what model you are before making any function calls. And what's the weather in Boston?", + stream=True, + ) print(x) - diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 698be86..f8be2d0 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -100,9 +100,11 @@ def __init__(self, *args, **kwargs): # set this to True ONLY AFTER we have had access to stack frames self._show_locals = False - self._log = ChatDBGLog(log_filename=chatdbg_config.log, - config=chatdbg_config.to_json(), - capture_streams=True) + self._log = ChatDBGLog( + log_filename=chatdbg_config.log, + config=chatdbg_config.to_json(), + capture_streams=True, + ) def _close_assistant(self): if self._assistant != None: @@ -206,9 +208,14 @@ def onecmd(self, line: str) -> bool: output = strip_color(hist_file.getvalue()) if not self.was_chat_or_renew: self._log.function_call(line, output) - if ( - line.split(' ')[0] not in ["hist", "test_prompt", "c", "cont", "continue", "config" ] - ): + if line.split(" ")[0] not in [ + "hist", + "test_prompt", + "c", + "cont", + "continue", + "config", + ]: self._history += [(line, output)] def message(self, msg) -> None: @@ -216,7 +223,7 @@ def message(self, msg) -> None: Override to remove tabs for messages so we can indent them. """ return super().message(str(msg).expandtabs()) - + def error(self, msg) -> None: """ Override to remove tabs for messages so we can indent them. 
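As a side note, a minimal, self-contained sketch (assumed names, not part of this patch) of the output-mirroring pattern that `onecmd` uses above; the patch wraps the debugger's own `self.stdout` with `CopyingTextIOWrapper`, shown here with `sys.stdout` for brevity. Output passes through to the real stream while a copy accumulates for the log and command history.

import io
import sys

class CopyingWriter:
    # Illustrative stand-in for CopyingTextIOWrapper: writes pass through
    # to the wrapped stream while a copy accumulates in a buffer.
    def __init__(self, file):
        self.file = file
        self.copy = io.StringIO()

    def write(self, data):
        self.copy.write(data)
        return self.file.write(data)

    def flush(self):
        self.file.flush()

    def getvalue(self):
        return self.copy.getvalue()

    def getfile(self):
        return self.file

wrapper = CopyingWriter(sys.stdout)
sys.stdout = wrapper
try:
    print("output of a debugger command")
finally:
    sys.stdout = wrapper.getfile()
captured = wrapper.getvalue()  # e.g. logged and appended to the history
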
@@ -514,17 +521,16 @@ def _stack_prompt(self): self.stdout = stdout def _ip_instructions(self): - return pdb_instructions( - self._supports_flow, chatdbg_config.take_the_wheel - ) + return pdb_instructions(self._supports_flow, chatdbg_config.take_the_wheel) + def _ip_enchriched_stack_trace(self): return f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n" - + def _ip_error(self): return self._error_specific_prompt def _ip_inputs(self): - inputs = '' + inputs = "" if len(sys.argv) > 1: inputs += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n" input = sys.stdin.get_captured_input() @@ -538,7 +544,7 @@ def _ip_history(self): hist = f"\nThis is the history of some pdb commands I ran and the results.\n```\n{hist}\n```\n" return hist else: - return '' + return "" def concat_prompt(self, *args): args = [a for a in args if len(a) > 0] @@ -546,15 +552,15 @@ def concat_prompt(self, *args): def _build_prompt(self, arg, conversing): if not conversing: - return self.concat_prompt( self._ip_enchriched_stack_trace(), - self._ip_inputs(), - self._ip_error(), - self._ip_history(), - arg) + return self.concat_prompt( + self._ip_enchriched_stack_trace(), + self._ip_inputs(), + self._ip_error(), + self._ip_history(), + arg, + ) else: - return self.concat_prompt(self._ip_history(), - self._stack_prompt(), - arg) + return self.concat_prompt(self._ip_history(), self._stack_prompt(), arg) def do_chat(self, arg): """chat @@ -650,7 +656,9 @@ def debug(command): # if old_curframe != self.curframe: # result += strip_color(self._stack_prompt()) - return command, truncate_proportionally(result, maxlen=8000, top_proportion=0.9) + return command, truncate_proportionally( + result, maxlen=8000, top_proportion=0.9 + ) def slice(name): """ @@ -676,9 +684,9 @@ def slice(name): instruction_prompt = self._ip_instructions() if chatdbg_config.take_the_wheel: - functions = [ debug, info ] + functions = [debug, info] if self._supports_flow: - functions += [ slice ] + functions += [slice] else: functions = [] @@ -688,20 +696,21 @@ def slice(name): debug=chatdbg_config.debug, functions=functions, stream_response=chatdbg_config.stream, - clients=[ ChatAssistantClient(self.stdout, - self.prompt, - self._chat_prefix, - self._text_width, - stream=chatdbg_config.stream), - self._log ] + clients=[ + ChatAssistantClient( + self.stdout, + self.prompt, + self._chat_prefix, + self._text_width, + stream=chatdbg_config.stream, + ), + self._log, + ], ) - - ############################################################### - class ChatAssistantClient(AbstractAssistantClient): def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): self.out = out @@ -710,48 +719,60 @@ def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): self.width = width self._assistant = None self._stream = stream - + # Call backs - def begin_query(self, prompt='', user_text=''): + def begin_query(self, prompt="", user_text=""): pass def end_query(self, stats): pass def _print(self, text, **kwargs): - print(textwrap.indent(text, self.chat_prefix, lambda _: True), file=self.out, **kwargs) + print( + textwrap.indent(text, self.chat_prefix, lambda _: True), + file=self.out, + **kwargs, + ) def warn(self, text): - self._print(textwrap.indent(text, '*** ')) + self._print(textwrap.indent(text, "*** ")) def fail(self, text): - self._print(textwrap.indent(text, '*** ')) + self._print(textwrap.indent(text, "*** ")) sys.exit(1) def begin_stream(self): self._stream_wrapper = 
StreamingTextWrapper(self.chat_prefix, width=80) self._at_start = True # print(self._stream_wrapper.append("\n", False), end='', flush=True, file=self.out) - + def stream_delta(self, text): if self._at_start: self._at_start = False - print(self._stream_wrapper.append("\n(Message) ", False), end='', flush=True, file=self.out) - print(self._stream_wrapper.append(text, False), end='', flush=True, file=self.out) - + print( + self._stream_wrapper.append("\n(Message) ", False), + end="", + flush=True, + file=self.out, + ) + print( + self._stream_wrapper.append(text, False), end="", flush=True, file=self.out + ) + def end_stream(self): - print(self._stream_wrapper.flush(), end='', flush=True, file=self.out) + print(self._stream_wrapper.flush(), end="", flush=True, file=self.out) def response(self, text): if not self._stream and text != None: - text = word_wrap_except_code_blocks(text, self.width-len(self.chat_prefix)) + text = word_wrap_except_code_blocks( + text, self.width - len(self.chat_prefix) + ) self._print(text) - + def function_call(self, call, result): if result and len(result) > 0: entry = f"{self.debugger_prompt}{call}\n{result}" else: entry = f"{self.debugger_prompt}{call}" self._print(entry) - diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index ab697d5..c73b733 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -9,7 +9,6 @@ from ..assistant.assistant import AbstractAssistantClient - class CopyingTextIOWrapper: """ File wrapper that will stash a copy of everything written. @@ -53,17 +52,17 @@ def __init__(self, log_filename, config, capture_streams=True): def _make_log(self): meta = { - 'time': datetime.now(), - 'command_line': " ".join(sys.argv), - 'uid': str(uuid.uuid4()), - 'config': self.config + "time": datetime.now(), + "command_line": " ".join(sys.argv), + "uid": str(uuid.uuid4()), + "config": self.config, } return { - 'steps':[], - 'meta':meta, - 'instructions':None, - 'stdout':self._stdout_wrapper.getvalue(), - 'stderr':self._stderr_wrapper.getvalue(), + "steps": [], + "meta": meta, + "instructions": None, + "stdout": self._stdout_wrapper.getvalue(), + "stderr": self._stderr_wrapper.getvalue(), } def _dump(self): @@ -71,29 +70,34 @@ def _dump(self): def total(key): return sum( - x['stats'][key] for x in log['steps'] if x['output']['type'] == 'chat' and 'stats' in x['output'] + x["stats"][key] + for x in log["steps"] + if x["output"]["type"] == "chat" and "stats" in x["output"] ) - log['meta']['total_tokens'] = total("tokens") - log['meta']['total_time'] = total("time") - log['meta']['total_cost'] = total("cost") + log["meta"]["total_tokens"] = total("tokens") + log["meta"]["total_time"] = total("time") + log["meta"]["total_cost"] = total("cost") print(f"*** Writing ChatDBG dialog log to {self._log_filename}") with open(self._log_filename, "a") as file: + def literal_presenter(dumper, data): if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar( + "tag:yaml.org,2002:str", data, style="|" + ) else: return dumper.represent_scalar("tag:yaml.org,2002:str", data) yaml.add_representer(str, literal_presenter) - yaml.dump([ log ], file, default_flow_style=False, indent=2) + yaml.dump([log], file, default_flow_style=False, indent=2) def begin_dialog(self, instructions): log = self._log assert log != None - log['instructions'] = instructions + log["instructions"] = instructions def end_dialog(self): if self._log != None: @@ -105,25 +109,33 @@ 
def begin_query(self, prompt, user_text): assert log != None assert self._current_chat == None self._current_chat = { - 'input':user_text, - 'prompt':prompt, - 'output': { 'type': 'chat', 'outputs': []} + "input": user_text, + "prompt": prompt, + "output": {"type": "chat", "outputs": []}, } def end_query(self, stats): log = self._log assert log != None assert self._current_chat != None - log['steps'] += [ self._current_chat ] + log["steps"] += [self._current_chat] self._current_chat = None def _post(self, text, kind): log = self._log assert log != None if self._current_chat != None: - self._current_chat['output']['outputs'].append({ 'type': 'text', 'output': f"*** {kind}: {text}"}) + self._current_chat["output"]["outputs"].append( + {"type": "text", "output": f"*** {kind}: {text}"} + ) else: - log['steps'].append({ 'type': 'call', 'input': f"*** {kind}", 'output': { 'type': 'text', 'output': text } }) + log["steps"].append( + { + "type": "call", + "input": f"*** {kind}", + "output": {"type": "text", "output": text}, + } + ) def warn(self, text): self._post(text, "Warning") @@ -135,12 +147,24 @@ def response(self, text): log = self._log assert log != None assert self._current_chat != None - self._current_chat['output']['outputs'].append({ 'type': 'text', 'output': text}) + self._current_chat["output"]["outputs"].append({"type": "text", "output": text}) def function_call(self, call, result): log = self._log assert log != None if self._current_chat != None: - self._current_chat['output']['outputs'].append({ 'type': 'call', 'input': call, 'output': { 'type': 'text', 'output': result }}) + self._current_chat["output"]["outputs"].append( + { + "type": "call", + "input": call, + "output": {"type": "text", "output": result}, + } + ) else: - log['steps'].append({ 'type': 'call', 'input': call, 'output': { 'type': 'text', 'output': result }}) + log["steps"].append( + { + "type": "call", + "input": call, + "output": {"type": "text", "output": result}, + } + ) diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py index b3a345b..181b754 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/ipdb_util/config.py @@ -23,7 +23,7 @@ class Chat(Configurable): debug = Bool(chat_get_env("debug", False), help="Log LLM calls").tag(config=True) log = Unicode(chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) - + tag = Unicode(chat_get_env("tag", ""), help="Any extra info for log file").tag( config=True ) @@ -39,19 +39,19 @@ class Chat(Configurable): show_locals = Bool( chat_get_env("show_locals", True), help="show local var values in stacktrace" ).tag(config=True) - + show_libs = Bool( chat_get_env("show_libs", False), help="show library frames in stacktrace" ).tag(config=True) - + show_slices = Bool( chat_get_env("show_slices", True), help="support the `slice` command" ).tag(config=True) - + take_the_wheel = Bool( chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" ).tag(config=True) - + stream = Bool( chat_get_env("stream", False), help="Stream the response at it arrives" ).tag(config=True) @@ -72,4 +72,5 @@ def to_json(self): "stream": self.stream, } + chatdbg_config: Chat = Chat() diff --git a/src/chatdbg/ipdb_util/printer.py b/src/chatdbg/ipdb_util/printer.py index f0d2f95..2f0d8e6 100644 --- a/src/chatdbg/ipdb_util/printer.py +++ b/src/chatdbg/ipdb_util/printer.py @@ -3,8 +3,9 @@ from ..chatdbg_pdb import ChatDBG import sys + class Printer(AssistantPrinter): - + def __init__(self, message, error, log): self._message = message self._error = 
error @@ -21,15 +22,16 @@ def log(self, json_obj): def fail(self, message): print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) + print(textwrap.wrap(message, width=70, initial_indent="*** ")) sys.exit(1) def warn(self, message): print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) + print(textwrap.wrap(message, width=70, initial_indent="*** ")) + class StreamingPrinter(AssistantPrinter): - + def __init__(self, message, error): self.message = message self.error = error @@ -45,9 +47,9 @@ def log(self, json_obj): def fail(self, message): print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) + print(textwrap.wrap(message, width=70, initial_indent="*** ")) sys.exit(1) def warn(self, message): print() - print(textwrap.wrap(message, width=70, initial_indent='*** ')) + print(textwrap.wrap(message, width=70, initial_indent="*** ")) diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py index 601fa5f..0f95e08 100644 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ b/src/chatdbg/ipdb_util/streamwrap.py @@ -6,42 +6,43 @@ class StreamingTextWrapper: - def __init__(self, indent=' ', width=80): - self._buffer = '' # the raw text so far - self._wrapped = '' # the successfully wrapped text do far - self._pending = '' # the part after the last space in buffer -- has not been wrapped yet + def __init__(self, indent=" ", width=80): + self._buffer = "" # the raw text so far + self._wrapped = "" # the successfully wrapped text do far + self._pending = ( + "" # the part after the last space in buffer -- has not been wrapped yet + ) self._indent = indent self._width = width - len(indent) def append(self, text, flush=False): if flush: self._buffer += self._pending + text - self._pending = '' + self._pending = "" else: - text_bits = re.split(r'(\s+)', self._pending + text) + text_bits = re.split(r"(\s+)", self._pending + text) self._pending = text_bits[-1] - self._buffer += (''.join(text_bits[0:-1])) + self._buffer += "".join(text_bits[0:-1]) wrapped = word_wrap_except_code_blocks(self._buffer) wrapped = textwrap.indent(wrapped, self._indent, lambda _: True) - wrapped_delta = wrapped[len(self._wrapped):] + wrapped_delta = wrapped[len(self._wrapped) :] self._wrapped = wrapped return wrapped_delta def flush(self): if len(self._buffer) > 0: - result = self.append('\n', flush=True) + result = self.append("\n", flush=True) else: - result = self.append('', flush=True) - self._buffer = '' - self._wrapped = '' + result = self.append("", flush=True) + self._buffer = "" + self._wrapped = "" return result - -if __name__ == '__main__': - s = StreamingTextWrapper(3,20) +if __name__ == "__main__": + s = StreamingTextWrapper(3, 20) for x in sys.argv[1:]: - y = s.append(' ' + x) - print(y, end='', flush=True) + y = s.append(" " + x) + print(y, end="", flush=True) print(s.flush()) diff --git a/src/chatdbg/ipdb_util/text.py b/src/chatdbg/ipdb_util/text.py index bc9f7ef..9d31e0d 100644 --- a/src/chatdbg/ipdb_util/text.py +++ b/src/chatdbg/ipdb_util/text.py @@ -5,6 +5,7 @@ import numpy as np import textwrap + def make_arrow(pad): """generate the leading arrow in front of traceback or debugger""" if pad >= 2: @@ -117,6 +118,7 @@ def truncate_proportionally(text, maxlen=32000, top_proportion=0.5): return text[:pre] + "..." + text[len(text) - post :] return text + def word_wrap_except_code_blocks(text: str, width: int = 80) -> str: """ Wraps text except for code blocks for nice terminal formatting. 
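(Editorial sketch, not part of the patch.) StreamingTextWrapper, diffed above, is meant to be driven one chunk at a time: each append returns only the newly wrapped delta so a streaming LLM response can be printed as it arrives, and flush pushes out whatever is still pending, mirroring the usage in its __main__ block. A small usage sketch; the chunk strings and the import path are assumptions, not taken from the patch.

# Illustrative usage of StreamingTextWrapper; the chunks below are invented,
# and chatdbg.ipdb_util.streamwrap is the assumed module path.
from chatdbg.ipdb_util.streamwrap import StreamingTextWrapper

wrapper = StreamingTextWrapper(indent="   ", width=40)
for chunk in ["Streaming responses ", "are re-wrapped ", "as they arrive ", "from the model."]:
    print(wrapper.append(chunk), end="", flush=True)   # print only the newly wrapped text
print(wrapper.flush(), end="", flush=True)             # emit whatever is still buffered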
@@ -132,11 +134,11 @@ def word_wrap_except_code_blocks(text: str, width: int = 80) -> str: Returns: The wrapped text. """ - blocks = text.split('```') + blocks = text.split("```") for i in range(len(blocks)): if i % 2 == 0: - paras = blocks[i].split('\n') - wrapped = [ textwrap.fill(para, width=width) for para in paras] + paras = blocks[i].split("\n") + wrapped = [textwrap.fill(para, width=width) for para in paras] blocks[i] = "\n".join(wrapped) - return '```'.join(blocks) + return "```".join(blocks) From 70342b5fbd833350604b35c52df435b592df1ef2 Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Wed, 27 Mar 2024 11:20:27 -0400 Subject: [PATCH 12/17] cleanup --- src/chatdbg/assistant/assistant-old.py | 243 ------------------------- src/chatdbg/assistant/assistant.py | 114 ++---------- src/chatdbg/chatdbg_pdb.py | 7 +- src/chatdbg/ipdb_util/chatlog.py | 9 +- src/chatdbg/ipdb_util/printer.py | 55 ------ 5 files changed, 20 insertions(+), 408 deletions(-) delete mode 100644 src/chatdbg/assistant/assistant-old.py delete mode 100644 src/chatdbg/ipdb_util/printer.py diff --git a/src/chatdbg/assistant/assistant-old.py b/src/chatdbg/assistant/assistant-old.py deleted file mode 100644 index 4366aa9..0000000 --- a/src/chatdbg/assistant/assistant-old.py +++ /dev/null @@ -1,243 +0,0 @@ -import atexit -import inspect -import json -import time -import sys - -import llm_utils -from openai import * -from pydantic import BaseModel - - -class Assistant: - """ - An Assistant is a wrapper around OpenAI's assistant API. Example usage: - - assistant = Assistant("Assistant Name", instructions, - model='gpt-4-1106-preview', debug=True) - assistant.add_function(my_func) - response = assistant.run(user_prompt) - - Name can be any name you want. - - If debug is True, it will create a log of all messages and JSON responses in - json.txt. - """ - - def __init__( - self, name, instructions, model="gpt-3.5-turbo-1106", timeout=30, debug=True - ): - if debug: - self.json = open(f"json.txt", "a") - else: - self.json = None - - try: - self.client = OpenAI(timeout=timeout) - except OpenAIError: - print("*** You need an OpenAI key to use this tool.") - print("*** You can get a key here: https://platform.openai.com/api-keys") - print("*** Set the environment variable OPENAI_API_KEY to your key value.") - sys.exit(-1) - - self.assistants = self.client.beta.assistants - self.threads = self.client.beta.threads - self.functions = dict() - - self.assistant = self.assistants.create( - name=name, instructions=instructions, model=model - ) - self.thread = self.threads.create() - - atexit.register(self._delete_assistant) - - def _delete_assistant(self): - if self.assistant != None: - try: - id = self.assistant.id - response = self.assistants.delete(id) - assert response.deleted - except OSError: - raise - except Exception as e: - print(f"*** Assistant {id} was not deleted ({e}).") - print("*** You can do so at https://platform.openai.com/assistants.") - - def add_function(self, function): - """ - Add a new function to the list of function tools for the assistant. - The function should have the necessary json spec as its pydoc string. 
- """ - function_json = json.loads(function.__doc__) - try: - name = function_json["name"] - self.functions[name] = function - - tools = [ - {"type": "function", "function": json.loads(function.__doc__)} - for function in self.functions.values() - ] - - self.assistants.update(self.assistant.id, tools=tools) - except OpenAIError as e: - print(f"*** OpenAI Error: {e}") - sys.exit(-1) - - def _make_call(self, tool_call): - name = tool_call.function.name - args = tool_call.function.arguments - - # There is a sketchy case that happens occasionally because - # the API produces a bad call... - try: - args = json.loads(args) - function = self.functions[name] - result = function(**args) - except OSError as e: - result = f"Error: {e}" - except Exception as e: - result = f"Ill-formed function call: {e}" - - return result - - def _print_messages(self, messages, client_print): - client_print() - for i, m in enumerate(messages): - message_text = m.content[0].text.value - if i == 0: - message_text = "(Message) " + message_text - client_print(message_text) - - def _wait_on_run(self, run, thread, client_print): - try: - while run.status == "queued" or run.status == "in_progress": - run = self.threads.runs.retrieve( - thread_id=thread.id, - run_id=run.id, - ) - time.sleep(0.5) - return run - finally: - if run.status == "in_progress": - client_print("Cancelling message that's in progress.") - self.threads.runs.cancel(thread_id=thread.id, run_id=run.id) - - def run(self, prompt, client_print=print): - """ - Give the prompt to the assistant and get the response, which may included - intermediate function calls. - All output is printed to the given file. - """ - start_time = time.perf_counter() - - try: - if self.assistant == None: - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, - "cost": cost, - } - - assert len(prompt) <= 32768 - - message = self.threads.messages.create( - thread_id=self.thread.id, role="user", content=prompt - ) - self._log(message) - - last_printed_message_id = message.id - - run = self.threads.runs.create( - thread_id=self.thread.id, assistant_id=self.assistant.id - ) - self._log(run) - - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) - - while run.status == "requires_action": - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" - ) - - mlist = list(messages) - if len(mlist) > 0: - self._print_messages(mlist, client_print) - last_printed_message_id = mlist[-1].id - client_print() - - outputs = [] - for tool_call in run.required_action.submit_tool_outputs.tool_calls: - output = self._make_call(tool_call) - self._log(output) - outputs += [{"tool_call_id": tool_call.id, "output": output}] - - try: - run = self.threads.runs.submit_tool_outputs( - thread_id=self.thread.id, run_id=run.id, tool_outputs=outputs - ) - self._log(run) - except OSError as e: - self._log(run, f"FAILED to submit tool call results: {e}") - raise - except Exception as e: - self._log(run, f"FAILED to submit tool call results: {e}") - - run = self._wait_on_run(run, self.thread, client_print) - self._log(run) - - if run.status == "failed": - message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}" - client_print(message) - self._log(run) - sys.exit(-1) - - messages = self.threads.messages.list( - thread_id=self.thread.id, after=last_printed_message_id, order="asc" - ) - 
self._print_messages(messages, client_print) - - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - cost = llm_utils.calculate_cost( - run.usage.prompt_tokens, - run.usage.completion_tokens, - self.assistant.model, - ) - client_print() - client_print(f"[Cost: ~${cost:.2f} USD]") - - return { - "tokens": run.usage.total_tokens, - "prompt_tokens": run.usage.prompt_tokens, - "completion_tokens": run.usage.completion_tokens, - "model": self.assistant.model, - "cost": cost, - "time": elapsed_time, - "thread.id": self.thread.id, - "run.id": run.id, - "assistant.id": self.assistant.id, - } - except OpenAIError as e: - client_print(f"*** OpenAI Error: {e}") - sys.exit(-1) - - def _log(self, obj, title=""): - if self.json != None: - stack = inspect.stack() - caller_frame_record = stack[1] - lineno, function = caller_frame_record[2:4] - loc = f"{function}:{lineno}" - - print("-" * 70, file=self.json) - print(f"{loc} {title}", file=self.json) - if isinstance(obj, BaseModel): - json_obj = json.loads(obj.model_dump_json()) - else: - json_obj = obj - print(f"\n{json.dumps(json_obj, indent=2)}\n", file=self.json) - self.json.flush() - return obj diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 0c688d2..e879bbe 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -8,97 +8,7 @@ import textwrap import sys - -class AbstractAssistantClient: - - def begin_dialog(self, instructions): - pass - - def end_dialog(self): - pass - - def begin_query(self, prompt, **kwargs): - pass - - def end_query(self, stats): - pass - - def warn(self, text): - pass - - def fail(self, text): - pass - - def begin_stream(self): - pass - - def stream_delta(self, text): - pass - - def end_stream(self): - pass - - def response(self, text): - pass - - def function_call(self, call, result): - pass - - -class PrintingAssistantClient(AbstractAssistantClient): - def __init__(self, out=sys.stdout): - self.out = out - - def warn(self, text): - print(textwrap.indent(text, "*** "), file=self.out) - - def fail(self, text): - print(textwrap.indent(text, "*** "), file=self.out) - sys.exit(1) - - def begin_stream(self): - pass - - def stream_delta(self, text): - print(text, end="", file=self.out, flush=True) - - def end_stream(self): - pass - - def begin_query(self, prompt, **kwargs): - pass - - def end_query(self, stats): - pass - - def response(self, text): - if text != None: - print(text, file=self.out) - - def function_call(self, call, result): - if result and len(result) > 0: - entry = f"{call}\n{result}" - else: - entry = f"{call}" - print(entry, file=self.out) - - -class StreamingAssistantClient(PrintingAssistantClient): - def __init__(self, out=sys.stdout): - super().__init__(out) - - def begin_stream(self): - print("", flush=True) - - def stream_delta(self, text): - print(text, end="", file=self.out, flush=True) - - def end_stream(self): - print("", flush=True) - - def response(self, text): - pass - +from .listeners import Printer, StreamingPrinter class Assistant: def __init__( @@ -106,7 +16,7 @@ def __init__( instructions, model="gpt-3.5-turbo-1106", timeout=30, - clients=[PrintingAssistantClient()], + clients=[Printer()], functions=[], max_call_response_tokens=4096, debug=False, @@ -138,11 +48,11 @@ def __init__( def close(self): self._broadcast("end_dialog") - def _broadcast(self, method_name, *args, **kwargs): + def _broadcast(self, method_name, *args): for client in self._clients: method = getattr(client, method_name, None) if 
callable(method): - method(*args, **kwargs) + method(*args) def _check_model(self): result = litellm.validate_environment(self._model) @@ -224,18 +134,18 @@ def _make_call(self, tool_call) -> str: result = f"Ill-formed function call: {e}" return result - def query(self, prompt: str, **kwargs): + def query(self, prompt: str, extra = None): if self._stream_response: - return self._streamed_query(prompt=prompt, **kwargs) + return self._streamed_query(prompt=prompt, extra=extra) else: - return self._batch_query(prompt=prompt, **kwargs) + return self._batch_query(prompt=prompt, extra=extra) - def _batch_query(self, prompt: str, **kwargs): + def _batch_query(self, prompt: str, extra): start = time.time() cost = 0 try: - self._broadcast("begin_query", prompt, **kwargs) + self._broadcast("begin_query", prompt, extra) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -273,12 +183,12 @@ def _batch_query(self, prompt: str, **kwargs): self._broadcast("fail", f"Internal Error: {e.__dict__}") sys.exit(1) - def _streamed_query(self, prompt: str, **kwargs): + def _streamed_query(self, prompt: str, extra = None): start = time.time() cost = 0 try: - self._broadcast("begin_query", prompt, **kwargs) + self._broadcast("begin_query", prompt, extra=extra) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -414,7 +324,7 @@ def weather(location, unit="f"): return f"weather(location, unit)", "Sunny and 72 degrees." a = Assistant( - "You generate text.", clients=[StreamingAssistantClient()], functions=[weather] + "You generate text.", clients=[StreamingPrinter()], functions=[weather] ) x = a.query( "tell me what model you are before making any function calls. And what's the weather in Boston?", diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index f8be2d0..23495c6 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -17,7 +17,7 @@ from chatdbg.ipdb_util.capture import CaptureInput -from .assistant.assistant import Assistant, AbstractAssistantClient +from .assistant.assistant import Assistant, AbsAssistantListener from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper from .ipdb_util.config import Chat, chatdbg_config from .ipdb_util.locals import * @@ -711,7 +711,7 @@ def slice(name): ############################################################### -class ChatAssistantClient(AbstractAssistantClient): +class ChatAssistantClient(AbsAssistantListener): def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): self.out = out self.debugger_prompt = debugger_prompt @@ -722,7 +722,7 @@ def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): # Call backs - def begin_query(self, prompt="", user_text=""): + def begin_query(self, prompt, extra): pass def end_query(self, stats): @@ -745,7 +745,6 @@ def fail(self, text): def begin_stream(self): self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) self._at_start = True - # print(self._stream_wrapper.append("\n", False), end='', flush=True, file=self.out) def stream_delta(self, text): if self._at_start: diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index c73b733..0611534 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -6,7 +6,7 @@ import yaml from pydantic import BaseModel from typing import List, Union, Optional -from ..assistant.assistant import AbstractAssistantClient +from ..assistant.assistant import AbsAssistantListener class 
CopyingTextIOWrapper: @@ -33,7 +33,7 @@ def __getattr__(self, attr): return getattr(self.file, attr) -class ChatDBGLog(AbstractAssistantClient): +class ChatDBGLog(AbsAssistantListener): def __init__(self, log_filename, config, capture_streams=True): self._log_filename = log_filename @@ -104,12 +104,12 @@ def end_dialog(self): self._dump() self._log = self._make_log() - def begin_query(self, prompt, user_text): + def begin_query(self, prompt, extra): log = self._log assert log != None assert self._current_chat == None self._current_chat = { - "input": user_text, + "input": extra, "prompt": prompt, "output": {"type": "chat", "outputs": []}, } @@ -119,6 +119,7 @@ def end_query(self, stats): assert log != None assert self._current_chat != None log["steps"] += [self._current_chat] + log['stats'] = stats self._current_chat = None def _post(self, text, kind): diff --git a/src/chatdbg/ipdb_util/printer.py b/src/chatdbg/ipdb_util/printer.py deleted file mode 100644 index 2f0d8e6..0000000 --- a/src/chatdbg/ipdb_util/printer.py +++ /dev/null @@ -1,55 +0,0 @@ -import textwrap -from ..assistant.assistant import AssistantPrinter -from ..chatdbg_pdb import ChatDBG -import sys - - -class Printer(AssistantPrinter): - - def __init__(self, message, error, log): - self._message = message - self._error = error - self._log = log - - def stream(self, text): - print(text, flush=True, end=None) - - def message(self, text): - print(text, flush=True) - - def log(self, json_obj): - pass - - def fail(self, message): - print() - print(textwrap.wrap(message, width=70, initial_indent="*** ")) - sys.exit(1) - - def warn(self, message): - print() - print(textwrap.wrap(message, width=70, initial_indent="*** ")) - - -class StreamingPrinter(AssistantPrinter): - - def __init__(self, message, error): - self.message = message - self.error = error - - def stream(self, text): - print(text, flush=True, end=None) - - def message(self, text): - print(text, flush=True) - - def log(self, json_obj): - pass - - def fail(self, message): - print() - print(textwrap.wrap(message, width=70, initial_indent="*** ")) - sys.exit(1) - - def warn(self, message): - print() - print(textwrap.wrap(message, width=70, initial_indent="*** ")) From dcd7833ab930088f4387f48df5411871385d61c2 Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Wed, 27 Mar 2024 12:36:52 -0400 Subject: [PATCH 13/17] files --- src/chatdbg/assistant/listeners.py | 0 src/chatdbg/assistant/test.py | 41 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 src/chatdbg/assistant/listeners.py create mode 100644 src/chatdbg/assistant/test.py diff --git a/src/chatdbg/assistant/listeners.py b/src/chatdbg/assistant/listeners.py new file mode 100644 index 0000000..e69de29 diff --git a/src/chatdbg/assistant/test.py b/src/chatdbg/assistant/test.py new file mode 100644 index 0000000..480849f --- /dev/null +++ b/src/chatdbg/assistant/test.py @@ -0,0 +1,41 @@ +from .assistant import Assistant +from .listeners import StreamingPrinter, Printer + +if __name__ == "__main__": + + def weather(location, unit="f"): + """ + { + "name": "get_weather", + "description": "Determine weather in my location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + } + """ + return f"weather({location}, {unit})", "Sunny and 72 degrees." 
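(Editorial sketch, not part of the patch.) The weather function above shows the tool convention these patches use throughout: the callable's docstring is its OpenAI-style JSON schema, and the callable returns a (call-description, result) pair so listeners can echo what was executed before the result is handed back to the model. A second, purely hypothetical tool written to the same convention:

# Hypothetical example tool -- "get_time" is not part of ChatDBG; it only illustrates the convention.
import datetime

def get_time(timezone="UTC"):
    """
    {
        "name": "get_time",
        "description": "Report the current wall-clock time",
        "parameters": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "IANA timezone name, e.g. America/New_York"
                }
            },
            "required": []
        }
    }
    """
    # First element: the human-readable call echoed to the user.
    # Second element: the text handed back to the LLM as the tool result.
    # (Timezone handling is omitted for brevity in this sketch.)
    return f"get_time({timezone})", datetime.datetime.now().isoformat()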
+ + a = Assistant( + "You generate text.", clients=[StreamingPrinter()], functions=[weather] + ) + x = a.query( + "tell me what model you are before making any function calls. And what's the weather in Boston?", + stream=True, + ) + print(x) From e7221481d39bc99c9eb7c7f450df220bae586606 Mon Sep 17 00:00:00 2001 From: Stephen Freund Date: Wed, 27 Mar 2024 13:53:07 -0400 Subject: [PATCH 14/17] biffs while moving around definitions --- src/chatdbg/assistant/assistant.py | 41 +------------ src/chatdbg/assistant/listeners.py | 94 ++++++++++++++++++++++++++++++ src/chatdbg/chatdbg_pdb.py | 5 +- src/chatdbg/ipdb_util/chatlog.py | 2 +- src/chatdbg/ipdb_util/config.py | 24 ++++---- 5 files changed, 111 insertions(+), 55 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index e879bbe..c7064b9 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -188,7 +188,7 @@ def _streamed_query(self, prompt: str, extra = None): cost = 0 try: - self._broadcast("begin_query", prompt, extra=extra) + self._broadcast("begin_query", prompt, extra) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -292,42 +292,3 @@ def _add_function_results_to_conversation(self, response_message): # Warning: potential infinite loop. self._broadcast("warn", f"Error processing tool calls: {e}") - -if __name__ == "__main__": - - def weather(location, unit="f"): - """ - { - "name": "get_weather", - "description": "Determine weather in my location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } - }, - "required": [ - "location" - ] - } - } - """ - return f"weather(location, unit)", "Sunny and 72 degrees." - - a = Assistant( - "You generate text.", clients=[StreamingPrinter()], functions=[weather] - ) - x = a.query( - "tell me what model you are before making any function calls. 
And what's the weather in Boston?", - stream=True, - ) - print(x) diff --git a/src/chatdbg/assistant/listeners.py b/src/chatdbg/assistant/listeners.py index e69de29..57ef9f8 100644 --- a/src/chatdbg/assistant/listeners.py +++ b/src/chatdbg/assistant/listeners.py @@ -0,0 +1,94 @@ + +import sys +import textwrap + + +class AbsAssistantListener: + + def begin_dialog(self, instructions): + pass + + def end_dialog(self): + pass + + def begin_query(self, prompt, extra): + pass + + def end_query(self, stats): + pass + + def warn(self, text): + pass + + def fail(self, text): + pass + + def begin_stream(self): + pass + + def stream_delta(self, text): + pass + + def end_stream(self): + pass + + def response(self, text): + pass + + def function_call(self, call, result): + pass + + +class Printer(AbsAssistantListener): + def __init__(self, out=sys.stdout): + self.out = out + + def warn(self, text): + print(textwrap.indent(text, "*** "), file=self.out) + + def fail(self, text): + print(textwrap.indent(text, "*** "), file=self.out) + sys.exit(1) + + def begin_stream(self): + pass + + def stream_delta(self, text): + print(text, end="", file=self.out, flush=True) + + def end_stream(self): + pass + + def begin_query(self, prompt, extra): + pass + + def end_query(self, stats): + pass + + def response(self, text): + if text != None: + print(text, file=self.out) + + def function_call(self, call, result): + if result and len(result) > 0: + entry = f"{call}\n{result}" + else: + entry = f"{call}" + print(entry, file=self.out) + + +class StreamingPrinter(Printer): + def __init__(self, out=sys.stdout): + super().__init__(out) + + def begin_stream(self): + print("", flush=True) + + def stream_delta(self, text): + print(text, end="", file=self.out, flush=True) + + def end_stream(self): + print("", flush=True) + + def response(self, text): + pass diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 23495c6..392b67d 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -17,7 +17,8 @@ from chatdbg.ipdb_util.capture import CaptureInput -from .assistant.assistant import Assistant, AbsAssistantListener +from .assistant.assistant import Assistant +from .assistant.listeners import AbsAssistantListener from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper from .ipdb_util.config import Chat, chatdbg_config from .ipdb_util.locals import * @@ -577,7 +578,7 @@ def do_chat(self, arg): if self._assistant == None: self._make_assistant() - stats = self._assistant.query(full_prompt, user_text=arg) + stats = self._assistant.query(full_prompt, extra=arg) self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index 0611534..afc92c0 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -6,7 +6,7 @@ import yaml from pydantic import BaseModel from typing import List, Union, Optional -from ..assistant.assistant import AbsAssistantListener +from ..assistant.listeners import AbsAssistantListener class CopyingTextIOWrapper: diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py index 181b754..5f731a6 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/ipdb_util/config.py @@ -4,7 +4,7 @@ from traitlets.config import Configurable -def chat_get_env(option_name, default_value): +def _chat_get_env(option_name, default_value): env_name = "CHATDBG_" + option_name.upper() v = os.getenv(env_name, str(default_value)) if type(default_value) == int: @@ 
-17,43 +17,43 @@ def chat_get_env(option_name, default_value): class Chat(Configurable): model = Unicode( - chat_get_env("model", "gpt-4-1106-preview"), help="The LLM model" + _chat_get_env("model", "gpt-4-1106-preview"), help="The LLM model" ).tag(config=True) - debug = Bool(chat_get_env("debug", False), help="Log LLM calls").tag(config=True) + debug = Bool(_chat_get_env("debug", False), help="Log LLM calls").tag(config=True) - log = Unicode(chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) + log = Unicode(_chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) - tag = Unicode(chat_get_env("tag", ""), help="Any extra info for log file").tag( + tag = Unicode(_chat_get_env("tag", ""), help="Any extra info for log file").tag( config=True ) rc_lines = Unicode( - chat_get_env("rc_lines", "[]"), help="lines to run at startup" + _chat_get_env("rc_lines", "[]"), help="lines to run at startup" ).tag(config=True) context = Int( - chat_get_env("context", 10), + _chat_get_env("context", 10), help="lines of source code to show when displaying stacktrace information", ).tag(config=True) show_locals = Bool( - chat_get_env("show_locals", True), help="show local var values in stacktrace" + _chat_get_env("show_locals", True), help="show local var values in stacktrace" ).tag(config=True) show_libs = Bool( - chat_get_env("show_libs", False), help="show library frames in stacktrace" + _chat_get_env("show_libs", False), help="show library frames in stacktrace" ).tag(config=True) show_slices = Bool( - chat_get_env("show_slices", True), help="support the `slice` command" + _chat_get_env("show_slices", True), help="support the `slice` command" ).tag(config=True) take_the_wheel = Bool( - chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" + _chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" ).tag(config=True) stream = Bool( - chat_get_env("stream", False), help="Stream the response at it arrives" + _chat_get_env("stream", False), help="Stream the response at it arrives" ).tag(config=True) def to_json(self): From dcce56c56944771bb349e8f7b46460b64d034aa6 Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Wed, 27 Mar 2024 19:05:49 -0400 Subject: [PATCH 15/17] clean up --- src/chatdbg/assistant/assistant.py | 59 ++++---- src/chatdbg/assistant/listeners.py | 60 ++++---- src/chatdbg/assistant/test.py | 31 ++-- src/chatdbg/chatdbg_pdb.py | 222 +++++++++++++++-------------- src/chatdbg/ipdb_util/capture.py | 24 ++++ src/chatdbg/ipdb_util/chatlog.py | 46 ++---- src/chatdbg/ipdb_util/config.py | 33 +++-- 7 files changed, 259 insertions(+), 216 deletions(-) diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index c7064b9..6d390f7 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -10,6 +10,7 @@ from .listeners import Printer, StreamingPrinter + class Assistant: def __init__( self, @@ -20,7 +21,7 @@ def __init__( functions=[], max_call_response_tokens=4096, debug=False, - stream_response=False, + stream=False, ): if debug: log_file = open(f"chatdbg.log", "w") @@ -40,7 +41,7 @@ def __init__( self._timeout = timeout self._conversation = [{"role": "system", "content": instructions}] self._max_call_response_tokens = max_call_response_tokens - self._stream_response = stream_response + self._stream = stream self._check_model() self._broadcast("begin_dialog", instructions) @@ -48,6 +49,17 @@ def __init__( def close(self): self._broadcast("end_dialog") + def query(self, prompt: str, 
user_text): + """ + Send a query to the LLM. + - prompt is the prompt to send. + - user_text is what the user typed (which may or not be the same as prompt) + """ + if self._stream: + return self._streamed_query(prompt, user_text) + else: + return self._batch_query(prompt, user_text) + def _broadcast(self, method_name, *args): for client in self._clients: method = getattr(client, method_name, None) @@ -126,7 +138,7 @@ def _make_call(self, tool_call) -> str: args = json.loads(tool_call.function.arguments) function = self._functions[name] call, result = function["function"](**args) - self._broadcast("function_call", call, result) + self._broadcast("on_function_call", call, result) except OSError as e: # function produced some error -- move this to client??? result = f"Error: {e}" @@ -134,18 +146,12 @@ def _make_call(self, tool_call) -> str: result = f"Ill-formed function call: {e}" return result - def query(self, prompt: str, extra = None): - if self._stream_response: - return self._streamed_query(prompt=prompt, extra=extra) - else: - return self._batch_query(prompt=prompt, extra=extra) - - def _batch_query(self, prompt: str, extra): + def _batch_query(self, prompt: str, user_text): start = time.time() cost = 0 try: - self._broadcast("begin_query", prompt, extra) + self._broadcast("on_begin_query", prompt, user_text) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -153,7 +159,7 @@ def _batch_query(self, prompt: str, extra): self._conversation, self._model ) - completion = self.completion() + completion = self._completion() cost += litellm.completion_cost(completion) @@ -161,7 +167,9 @@ def _batch_query(self, prompt: str, extra): self._conversation.append(response_message) if response_message.content: - self._broadcast("response", "(Message) " + response_message.content) + self._broadcast( + "on_response", "(Message) " + response_message.content + ) if completion.choices[0].finish_reason == "tool_calls": self._add_function_results_to_conversation(response_message) @@ -177,18 +185,18 @@ def _batch_query(self, prompt: str, extra): "prompt_tokens": completion.usage.prompt_tokens, "completion_tokens": completion.usage.completion_tokens, } - self._broadcast("end_query", stats) + self._broadcast("on_end_query", stats) return stats except openai.OpenAIError as e: self._broadcast("fail", f"Internal Error: {e.__dict__}") sys.exit(1) - def _streamed_query(self, prompt: str, extra = None): + def _streamed_query(self, prompt: str, user_text): start = time.time() cost = 0 try: - self._broadcast("begin_query", prompt, extra) + self._broadcast("on_begin_query", prompt, user_text) self._conversation.append({"role": "user", "content": prompt}) while True: @@ -197,21 +205,23 @@ def _streamed_query(self, prompt: str, extra = None): ) # print("\n".join([str(x) for x in self._conversation])) - stream = self.completion(stream=True) + stream = self._completion(stream=True) # litellm is broken for new GPT models that have content before calls, so... 
# stream the response, collecting the tool_call parts separately from the content - self._broadcast("begin_stream") + self._broadcast("on_begin_stream") chunks = [] tool_chunks = [] for chunk in stream: chunks.append(chunk) if chunk.choices[0].delta.content != None: - self._broadcast("stream_delta", chunk.choices[0].delta.content) + self._broadcast( + "on_stream_delta", chunk.choices[0].delta.content + ) else: tool_chunks.append(chunk) - self._broadcast("end_stream") + self._broadcast("on_end_stream") # compute for the part that litellm gives back. completion = litellm.stream_chunk_builder( @@ -226,7 +236,9 @@ def _streamed_query(self, prompt: str, extra = None): self._conversation.append(response_message) if response_message.content != None: - self._broadcast("response", "(Message) " + response_message.content) + self._broadcast( + "on_response", "(Message) " + response_message.content + ) if completion.choices[0].finish_reason == "tool_calls": # create a message with just the tool calls, append that to the conversation, and generate the responses. @@ -253,13 +265,13 @@ def _streamed_query(self, prompt: str, extra = None): "prompt_tokens": completion.usage.prompt_tokens, "completion_tokens": completion.usage.completion_tokens, } - self._broadcast("end_query", stats) + self._broadcast("on_end_query", stats) return stats except openai.OpenAIError as e: self._broadcast("fail", f"Internal Error: {e.__dict__}") sys.exit(1) - def completion(self, stream=False): + def _completion(self, stream=False): return litellm.completion( model=self._model, messages=self._conversation, @@ -291,4 +303,3 @@ def _add_function_results_to_conversation(self, response_message): except Exception as e: # Warning: potential infinite loop. self._broadcast("warn", f"Error processing tool calls: {e}") - diff --git a/src/chatdbg/assistant/listeners.py b/src/chatdbg/assistant/listeners.py index 57ef9f8..95ae435 100644 --- a/src/chatdbg/assistant/listeners.py +++ b/src/chatdbg/assistant/listeners.py @@ -1,45 +1,55 @@ - import sys import textwrap -class AbsAssistantListener: +class BaseAssistantListener: + """ + Events that the Assistant generates. Override these for the client. + """ + + # Dialogs capture 1 or more queries. 
- def begin_dialog(self, instructions): + def on_begin_dialog(self, instructions): pass - def end_dialog(self): + def on_end_dialog(self): pass - def begin_query(self, prompt, extra): + # Events for a single query + + def on_begin_query(self, prompt, user_text): pass - def end_query(self, stats): + def on_response(self, text): pass - def warn(self, text): + def on_function_call(self, call, result): pass - def fail(self, text): + def on_end_query(self, stats): pass - def begin_stream(self): + # For clients wishing to stream responses + + def on_begin_stream(self): pass - def stream_delta(self, text): + def on_stream_delta(self, text): pass - def end_stream(self): + def on_end_stream(self): pass - def response(self, text): + # Notifications of non-fatal / fatal problems + + def warn(self, text): pass - def function_call(self, call, result): + def fail(self, text): pass -class Printer(AbsAssistantListener): +class Printer(BaseAssistantListener): def __init__(self, out=sys.stdout): self.out = out @@ -50,26 +60,26 @@ def fail(self, text): print(textwrap.indent(text, "*** "), file=self.out) sys.exit(1) - def begin_stream(self): + def on_begin_stream(self): pass - def stream_delta(self, text): + def on_stream_delta(self, text): print(text, end="", file=self.out, flush=True) - def end_stream(self): + def on_end_stream(self): pass - def begin_query(self, prompt, extra): + def on_begin_query(self, prompt, user_text): pass - def end_query(self, stats): + def on_end_query(self, stats): pass - def response(self, text): + def on_response(self, text): if text != None: print(text, file=self.out) - def function_call(self, call, result): + def on_function_call(self, call, result): if result and len(result) > 0: entry = f"{call}\n{result}" else: @@ -81,14 +91,14 @@ class StreamingPrinter(Printer): def __init__(self, out=sys.stdout): super().__init__(out) - def begin_stream(self): + def on_begin_stream(self): print("", flush=True) - def stream_delta(self, text): + def on_stream_delta(self, text): print(text, end="", file=self.out, flush=True) - def end_stream(self): + def on_end_stream(self): print("", flush=True) - def response(self, text): + def on_response(self, text): pass diff --git a/src/chatdbg/assistant/test.py b/src/chatdbg/assistant/test.py index 480849f..984c2ec 100644 --- a/src/chatdbg/assistant/test.py +++ b/src/chatdbg/assistant/test.py @@ -1,9 +1,25 @@ from .assistant import Assistant from .listeners import StreamingPrinter, Printer -if __name__ == "__main__": - def weather(location, unit="f"): +class AssistantTest: + + def __init__(self): + self.a = Assistant( + "You generate text.", + clients=[StreamingPrinter()], + functions=[self.weather], + stream=True, + ) + + def run(self): + x = self.a.query( + "tell me what model you are before making any function calls. And what's the weather in Boston?", + None, + ) + print(x) + + def weather(self, location, unit="f"): """ { "name": "get_weather", @@ -31,11 +47,6 @@ def weather(location, unit="f"): """ return f"weather({location}, {unit})", "Sunny and 72 degrees." - a = Assistant( - "You generate text.", clients=[StreamingPrinter()], functions=[weather] - ) - x = a.query( - "tell me what model you are before making any function calls. 
And what's the weather in Boston?", - stream=True, - ) - print(x) + +t = AssistantTest() +t.run() diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 392b67d..f218b8a 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -15,12 +15,12 @@ import IPython from traitlets import TraitError -from chatdbg.ipdb_util.capture import CaptureInput +from chatdbg.ipdb_util.capture import CaptureInput, CaptureOutput from .assistant.assistant import Assistant -from .assistant.listeners import AbsAssistantListener -from .ipdb_util.chatlog import ChatDBGLog, CopyingTextIOWrapper -from .ipdb_util.config import Chat, chatdbg_config +from .assistant.listeners import BaseAssistantListener +from .ipdb_util.chatlog import ChatDBGLog +from .ipdb_util.config import chatdbg_config from .ipdb_util.locals import * from .ipdb_util.prompts import pdb_instructions from .ipdb_util.streamwrap import StreamingTextWrapper @@ -30,10 +30,10 @@ def load_ipython_extension(ipython): global chatdbg_config from chatdbg.chatdbg_pdb import ChatDBG - from chatdbg.ipdb_util.config import Chat, chatdbg_config + from chatdbg.ipdb_util.config import ChatDBGConfig, chatdbg_config ipython.InteractiveTB.debugger_cls = ChatDBG - chatdbg_config = Chat(config=ipython.config) + chatdbg_config = ChatDBGConfig(config=ipython.config) print("*** Loaded ChatDBG ***") @@ -80,20 +80,7 @@ def __init__(self, *args, **kwargs): sys.stdin = CaptureInput(sys.stdin) - # Only use flow when we are in jupyter or using stdin in ipython. In both - # cases, there will be no python file at the start of argv after the - # ipython commands. - self._supports_flow = chatdbg_config.show_slices - if self._supports_flow: - if ChatDBGSuper is not IPython.core.debugger.InterruptiblePdb: - for arg in sys.argv: - if arg.endswith("ipython") or arg.endswith("ipython3"): - continue - if arg.startswith("-"): - continue - if Path(arg).suffix in [".py", ".ipy"]: - self._supports_flow = False - break + self._supports_flow = self.can_support_flow() self.do_context(chatdbg_config.context) self.rcLines += ast.literal_eval(chatdbg_config.rc_lines) @@ -111,6 +98,23 @@ def _close_assistant(self): if self._assistant != None: self._assistant.close() + def can_support_flow(self): + # Only use flow when we are in jupyter or using stdin in ipython. In both + # cases, there will be no python file at the start of argv after the + # ipython commands. 
+ if chatdbg_config.show_slices: + if ChatDBGSuper is not IPython.core.debugger.InterruptiblePdb: + for arg in sys.argv: + if arg.endswith("ipython") or arg.endswith("ipython3"): + continue + if arg.startswith("-"): + continue + if Path(arg).suffix in [".py", ".ipy"]: + return False + return True + else: + return False + def _is_user_frame(self, frame): if not self._is_user_file(frame.f_code.co_filename): return False @@ -120,11 +124,13 @@ def _is_user_frame(self, frame): def _is_user_file(self, file_name): if file_name.endswith(".pyx"): return False - if file_name == "": + elif file_name == "": return False + for prefix in _user_file_prefixes: if file_name.startswith(prefix): return True + return False def format_stack_trace(self, context=None): @@ -199,7 +205,7 @@ def onecmd(self, line: str) -> bool: # blank -- let super call back to into onecmd return super().onecmd(line) else: - hist_file = CopyingTextIOWrapper(self.stdout) + hist_file = CaptureOutput(self.stdout) self.stdout = hist_file try: self.was_chat_or_renew = False @@ -208,7 +214,7 @@ def onecmd(self, line: str) -> bool: self.stdout = hist_file.getfile() output = strip_color(hist_file.getvalue()) if not self.was_chat_or_renew: - self._log.function_call(line, output) + self._log.on_function_call(line, output) if line.split(" ")[0] not in [ "hist", "test_prompt", @@ -313,9 +319,7 @@ def do_info(self, arg): # didn't find anything if obj == None: self.message(f"No name `{arg}` is visible in the current frame.") - return - - if self._is_user_file(inspect.getfile(obj)): + elif self._is_user_file(inspect.getfile(obj)): self.do_source(x) else: self.do_pydoc(x) @@ -578,7 +582,7 @@ def do_chat(self, arg): if self._assistant == None: self._make_assistant() - stats = self._assistant.query(full_prompt, extra=arg) + stats = self._assistant.query(full_prompt, user_text=arg) self.message(f"\n[Cost: ~${stats['cost']:.2f} USD]") @@ -611,83 +615,12 @@ def do_config(self, arg): self.error(f"{e}") def _make_assistant(self): - def info(value): - """ - { - "name": "info", - "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", - "parameters": { - "type": "object", - "properties": { - "value": { - "type": "string", - "description": "The reference to get the information for." - } - }, - "required": [ "value" ] - } - } - """ - command = f"info {value}" - result = self._capture_onecmd(command) - return command, truncate_proportionally(result, top_proportion=1) - - def debug(command): - """ - { - "name": "debug", - "description": "Run a pdb command and get the response.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The pdb command to run." - } - }, - "required": [ "command" ] - } - } - """ - cmd = command if command != "list" else "ll" - # old_curframe = self.curframe - result = self._capture_onecmd(cmd) - - # help the LLM know where it is... 
- # if old_curframe != self.curframe: - # result += strip_color(self._stack_prompt()) - - return command, truncate_proportionally( - result, maxlen=8000, top_proportion=0.9 - ) - - def slice(name): - """ - { - "name": "slice", - "description": "Return the code to compute a global variable used in the current frame", - "parameters": { - "type": "object", - "properties": { - "name": { - "type": "string", - "description": "The variable to look at." - } - }, - "required": [ "name" ] - } - } - """ - command = f"slice {name}" - result = self._capture_onecmd(command) - return command, truncate_proportionally(result, top_proportion=0.5) - instruction_prompt = self._ip_instructions() if chatdbg_config.take_the_wheel: - functions = [debug, info] + functions = [self.debug, self.info] if self._supports_flow: - functions += [slice] + functions += [self.slice] else: functions = [] @@ -696,7 +629,7 @@ def slice(name): model=chatdbg_config.model, debug=chatdbg_config.debug, functions=functions, - stream_response=chatdbg_config.stream, + stream=chatdbg_config.stream, clients=[ ChatAssistantClient( self.stdout, @@ -709,10 +642,81 @@ def slice(name): ], ) + ### Callbacks for LLM + + def info(self, value): + """ + { + "name": "info", + "description": "Get the documentation and source code for a reference, which may be a variable, function, method reference, field reference, or dotted reference visible in the current frame. Examples include n, e.n where e is an expression, and t.n where t is a type.", + "parameters": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The reference to get the information for." + } + }, + "required": [ "value" ] + } + } + """ + command = f"info {value}" + result = self._capture_onecmd(command) + return command, truncate_proportionally(result, top_proportion=1) + + def debug(self, command): + """ + { + "name": "debug", + "description": "Run a pdb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The pdb command to run." + } + }, + "required": [ "command" ] + } + } + """ + cmd = command if command != "list" else "ll" + # old_curframe = self.curframe + result = self._capture_onecmd(cmd) + + # help the LLM know where it is... + # if old_curframe != self.curframe: + # result += strip_color(self._stack_prompt()) + + return command, truncate_proportionally(result, maxlen=8000, top_proportion=0.9) + + def slice(self, name): + """ + { + "name": "slice", + "description": "Return the code to compute a global variable used in the current frame", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The variable to look at." 
+ } + }, + "required": [ "name" ] + } + } + """ + command = f"slice {name}" + result = self._capture_onecmd(command) + return command, truncate_proportionally(result, top_proportion=0.5) + ############################################################### -class ChatAssistantClient(AbsAssistantListener): +class ChatAssistantClient(BaseAssistantListener): def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): self.out = out self.debugger_prompt = debugger_prompt @@ -723,10 +727,10 @@ def __init__(self, out, debugger_prompt, chat_prefix, width, stream=False): # Call backs - def begin_query(self, prompt, extra): + def on_begin_query(self, prompt, user_text): pass - def end_query(self, stats): + def on_end_query(self, stats): pass def _print(self, text, **kwargs): @@ -743,11 +747,11 @@ def fail(self, text): self._print(textwrap.indent(text, "*** ")) sys.exit(1) - def begin_stream(self): + def on_begin_stream(self): self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) self._at_start = True - def stream_delta(self, text): + def on_stream_delta(self, text): if self._at_start: self._at_start = False print( @@ -760,17 +764,17 @@ def stream_delta(self, text): self._stream_wrapper.append(text, False), end="", flush=True, file=self.out ) - def end_stream(self): + def on_end_stream(self): print(self._stream_wrapper.flush(), end="", flush=True, file=self.out) - def response(self, text): + def on_response(self, text): if not self._stream and text != None: text = word_wrap_except_code_blocks( text, self.width - len(self.chat_prefix) ) self._print(text) - def function_call(self, call, result): + def on_function_call(self, call, result): if result and len(result) > 0: entry = f"{self.debugger_prompt}{call}\n{result}" else: diff --git a/src/chatdbg/ipdb_util/capture.py b/src/chatdbg/ipdb_util/capture.py index 21c75ba..c4fcf53 100644 --- a/src/chatdbg/ipdb_util/capture.py +++ b/src/chatdbg/ipdb_util/capture.py @@ -30,3 +30,27 @@ def read(self, *args, **kwargs): def get_captured_input(self): return self.capture_buffer.getvalue() + + +class CaptureOutput: + """ + File wrapper that will stash a copy of everything written. + """ + + def __init__(self, file): + self.file = file + self.buffer = StringIO() + + def write(self, data): + self.buffer.write(data) + return self.file.write(data) + + def getvalue(self): + return self.buffer.getvalue() + + def getfile(self): + return self.file + + def __getattr__(self, attr): + # Delegate attribute access to the file object + return getattr(self.file, attr) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py index afc92c0..209cc59 100644 --- a/src/chatdbg/ipdb_util/chatlog.py +++ b/src/chatdbg/ipdb_util/chatlog.py @@ -6,41 +6,19 @@ import yaml from pydantic import BaseModel from typing import List, Union, Optional -from ..assistant.listeners import AbsAssistantListener +from chatdbg.ipdb_util.capture import CaptureOutput +from ..assistant.listeners import BaseAssistantListener -class CopyingTextIOWrapper: - """ - File wrapper that will stash a copy of everything written. 
- """ - def __init__(self, file): - self.file = file - self.buffer = StringIO() - - def write(self, data): - self.buffer.write(data) - return self.file.write(data) - - def getvalue(self): - return self.buffer.getvalue() - - def getfile(self): - return self.file - - def __getattr__(self, attr): - # Delegate attribute access to the file object - return getattr(self.file, attr) - - -class ChatDBGLog(AbsAssistantListener): +class ChatDBGLog(BaseAssistantListener): def __init__(self, log_filename, config, capture_streams=True): self._log_filename = log_filename self.config = config if capture_streams: - self._stdout_wrapper = CopyingTextIOWrapper(sys.stdout) - self._stderr_wrapper = CopyingTextIOWrapper(sys.stderr) + self._stdout_wrapper = CaptureOutput(sys.stdout) + self._stderr_wrapper = CaptureOutput(sys.stderr) sys.stdout = self._stdout_wrapper sys.stderr = self._stdout_wrapper else: @@ -94,17 +72,17 @@ def literal_presenter(dumper, data): yaml.add_representer(str, literal_presenter) yaml.dump([log], file, default_flow_style=False, indent=2) - def begin_dialog(self, instructions): + def on_begin_dialog(self, instructions): log = self._log assert log != None log["instructions"] = instructions - def end_dialog(self): + def on_end_dialog(self): if self._log != None: self._dump() self._log = self._make_log() - def begin_query(self, prompt, extra): + def on_begin_query(self, prompt, extra): log = self._log assert log != None assert self._current_chat == None @@ -114,12 +92,12 @@ def begin_query(self, prompt, extra): "output": {"type": "chat", "outputs": []}, } - def end_query(self, stats): + def on_end_query(self, stats): log = self._log assert log != None assert self._current_chat != None log["steps"] += [self._current_chat] - log['stats'] = stats + log["stats"] = stats self._current_chat = None def _post(self, text, kind): @@ -144,13 +122,13 @@ def warn(self, text): def fail(self, text): self._post(text, "Failure") - def response(self, text): + def on_response(self, text): log = self._log assert log != None assert self._current_chat != None self._current_chat["output"]["outputs"].append({"type": "text", "output": text}) - def function_call(self, call, result): + def on_function_call(self, call, result): log = self._log assert log != None if self._current_chat != None: diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py index 5f731a6..df08178 100644 --- a/src/chatdbg/ipdb_util/config.py +++ b/src/chatdbg/ipdb_util/config.py @@ -4,7 +4,7 @@ from traitlets.config import Configurable -def _chat_get_env(option_name, default_value): +def _chatdbg_get_env(option_name, default_value): env_name = "CHATDBG_" + option_name.upper() v = os.getenv(env_name, str(default_value)) if type(default_value) == int: @@ -15,45 +15,50 @@ def _chat_get_env(option_name, default_value): return v -class Chat(Configurable): +class ChatDBGConfig(Configurable): model = Unicode( - _chat_get_env("model", "gpt-4-1106-preview"), help="The LLM model" + _chatdbg_get_env("model", "gpt-4-1106-preview"), help="The LLM model" ).tag(config=True) - debug = Bool(_chat_get_env("debug", False), help="Log LLM calls").tag(config=True) + debug = Bool(_chatdbg_get_env("debug", False), help="Log LLM calls").tag( + config=True + ) - log = Unicode(_chat_get_env("log", "log.yaml"), help="The log file").tag(config=True) + log = Unicode(_chatdbg_get_env("log", "log.yaml"), help="The log file").tag( + config=True + ) - tag = Unicode(_chat_get_env("tag", ""), help="Any extra info for log file").tag( + tag = 
Unicode(_chatdbg_get_env("tag", ""), help="Any extra info for log file").tag( config=True ) rc_lines = Unicode( - _chat_get_env("rc_lines", "[]"), help="lines to run at startup" + _chatdbg_get_env("rc_lines", "[]"), help="lines to run at startup" ).tag(config=True) context = Int( - _chat_get_env("context", 10), + _chatdbg_get_env("context", 10), help="lines of source code to show when displaying stacktrace information", ).tag(config=True) show_locals = Bool( - _chat_get_env("show_locals", True), help="show local var values in stacktrace" + _chatdbg_get_env("show_locals", True), + help="show local var values in stacktrace", ).tag(config=True) show_libs = Bool( - _chat_get_env("show_libs", False), help="show library frames in stacktrace" + _chatdbg_get_env("show_libs", False), help="show library frames in stacktrace" ).tag(config=True) show_slices = Bool( - _chat_get_env("show_slices", True), help="support the `slice` command" + _chatdbg_get_env("show_slices", True), help="support the `slice` command" ).tag(config=True) take_the_wheel = Bool( - _chat_get_env("take_the_wheel", True), help="Let LLM take the wheel" + _chatdbg_get_env("take_the_wheel", True), help="Let LLM take the wheel" ).tag(config=True) stream = Bool( - _chat_get_env("stream", False), help="Stream the response at it arrives" + _chatdbg_get_env("stream", False), help="Stream the response at it arrives" ).tag(config=True) def to_json(self): @@ -73,4 +78,4 @@ def to_json(self): } -chatdbg_config: Chat = Chat() +chatdbg_config: ChatDBGConfig = ChatDBGConfig() From 71788db7b32e39e856a5fa486e6102939f0fd1aa Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Wed, 27 Mar 2024 19:08:17 -0400 Subject: [PATCH 16/17] clean up --- src/chatdbg/chatdbg_pdb.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index f218b8a..901211c 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -1,5 +1,5 @@ -import atexit import ast +import atexit import inspect import linecache import os @@ -21,10 +21,15 @@ from .assistant.listeners import BaseAssistantListener from .ipdb_util.chatlog import ChatDBGLog from .ipdb_util.config import chatdbg_config -from .ipdb_util.locals import * +from .ipdb_util.locals import extract_locals from .ipdb_util.prompts import pdb_instructions from .ipdb_util.streamwrap import StreamingTextWrapper -from .ipdb_util.text import * +from .ipdb_util.text import ( + format_limited, + strip_color, + truncate_proportionally, + word_wrap_except_code_blocks, +) def load_ipython_extension(ipython): From 8d1eb3bd6609a6a0371ad747bc1e314c55a0608e Mon Sep 17 00:00:00 2001 From: stephenfreund Date: Wed, 27 Mar 2024 20:10:55 -0400 Subject: [PATCH 17/17] cleanup --- src/chatdbg/__main__.py | 10 +- src/chatdbg/assistant/assistant.py | 33 +++--- src/chatdbg/assistant/listeners.py | 8 +- src/chatdbg/assistant/test.py | 2 +- src/chatdbg/chatdbg_pdb.py | 23 ++--- src/chatdbg/ipdb_util/capture.py | 56 ----------- src/chatdbg/ipdb_util/chatlog.py | 149 ---------------------------- src/chatdbg/ipdb_util/config.py | 81 --------------- src/chatdbg/ipdb_util/locals.py | 55 ---------- src/chatdbg/ipdb_util/prompts.py | 83 ---------------- src/chatdbg/ipdb_util/streamwrap.py | 48 --------- src/chatdbg/ipdb_util/text.py | 144 --------------------------- 12 files changed, 39 insertions(+), 653 deletions(-) delete mode 100644 src/chatdbg/ipdb_util/capture.py delete mode 100644 src/chatdbg/ipdb_util/chatlog.py delete mode 100644 
src/chatdbg/ipdb_util/config.py delete mode 100644 src/chatdbg/ipdb_util/locals.py delete mode 100644 src/chatdbg/ipdb_util/prompts.py delete mode 100644 src/chatdbg/ipdb_util/streamwrap.py delete mode 100644 src/chatdbg/ipdb_util/text.py diff --git a/src/chatdbg/__main__.py b/src/chatdbg/__main__.py index 743dc6d..f876884 100644 --- a/src/chatdbg/__main__.py +++ b/src/chatdbg/__main__.py @@ -1,6 +1,6 @@ import ipdb from chatdbg.chatdbg_pdb import ChatDBG -from chatdbg.ipdb_util.config import chatdbg_config +from chatdbg.util.config import chatdbg_config import sys import getopt @@ -20,10 +20,10 @@ Option -m is available only in Python 3.7 and later. ChatDBG-specific options may appear anywhere before pyfile: - -debug dump the LLM messages to a chatdbg.log - -log file where to write the log of the debugging session - -model model the LLM model to use. - -stream stream responses from the LLM + --debug dump the LLM messages to a chatdbg.log + --log file where to write the log of the debugging session + --model model the LLM model to use. + --stream stream responses from the LLM """ diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 6d390f7..ac91e5e 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -8,7 +8,7 @@ import textwrap import sys -from .listeners import Printer, StreamingPrinter +from .listeners import Printer class Assistant: @@ -17,7 +17,7 @@ def __init__( instructions, model="gpt-3.5-turbo-1106", timeout=30, - clients=[Printer()], + listeners=[Printer()], functions=[], max_call_response_tokens=4096, debug=False, @@ -31,7 +31,7 @@ def __init__( else: self._logger = None - self._clients = clients + self._clients = listeners self._functions = {} for f in functions: @@ -44,10 +44,10 @@ def __init__( self._stream = stream self._check_model() - self._broadcast("begin_dialog", instructions) + self._broadcast("on_begin_dialog", instructions) def close(self): - self._broadcast("end_dialog") + self._broadcast("on_end_dialog") def query(self, prompt: str, user_text): """ @@ -73,7 +73,7 @@ def _check_model(self): _, provider, _, _ = litellm.get_llm_provider(self._model) if provider == "openai": self._broadcast( - "fail", + "on_fail", textwrap.dedent( f"""\ You need an OpenAI key to use the {self._model} model. @@ -84,7 +84,7 @@ def _check_model(self): sys.exit(1) else: self._broadcast( - "fail", + "on_fail", textwrap.dedent( f"""\ You need to set the following environment variables @@ -95,7 +95,7 @@ def _check_model(self): if not litellm.supports_function_calling(self._model): self._broadcast( - "fail", + "on_fail", textwrap.dedent( f"""\ The {self._model} model does not support function calls. @@ -188,7 +188,7 @@ def _batch_query(self, prompt: str, user_text): self._broadcast("on_end_query", stats) return stats except openai.OpenAIError as e: - self._broadcast("fail", f"Internal Error: {e.__dict__}") + self._broadcast("on_fail", f"Internal Error: {e.__dict__}") sys.exit(1) def _streamed_query(self, prompt: str, user_text): @@ -207,9 +207,11 @@ def _streamed_query(self, prompt: str, user_text): stream = self._completion(stream=True) - # litellm is broken for new GPT models that have content before calls, so... + # litellm.stream_chunk_builder is broken for new GPT models + # that have content before calls, so... 
- # stream the response, collecting the tool_call parts separately from the content + # stream the response, collecting the tool_call parts separately + # from the content self._broadcast("on_begin_stream") chunks = [] tool_chunks = [] @@ -223,7 +225,7 @@ def _streamed_query(self, prompt: str, user_text): tool_chunks.append(chunk) self._broadcast("on_end_stream") - # compute for the part that litellm gives back. + # then compute for the part that litellm gives back. completion = litellm.stream_chunk_builder( chunks, messages=self._conversation ) @@ -268,7 +270,7 @@ def _streamed_query(self, prompt: str, user_text): self._broadcast("on_end_query", stats) return stats except openai.OpenAIError as e: - self._broadcast("fail", f"Internal Error: {e.__dict__}") + self._broadcast("on_fail", f"Internal Error: {e.__dict__}") sys.exit(1) def _completion(self, stream=False): @@ -301,5 +303,6 @@ def _add_function_results_to_conversation(self, response_message): } self._conversation.append(response) except Exception as e: - # Warning: potential infinite loop. - self._broadcast("warn", f"Error processing tool calls: {e}") + # Warning: potential infinite loop if the LLM keeps sending + # the same bad call. + self._broadcast("on_warn", f"Error processing tool calls: {e}") diff --git a/src/chatdbg/assistant/listeners.py b/src/chatdbg/assistant/listeners.py index 95ae435..4f70754 100644 --- a/src/chatdbg/assistant/listeners.py +++ b/src/chatdbg/assistant/listeners.py @@ -42,10 +42,10 @@ def on_end_stream(self): # Notifications of non-fatal / fatal problems - def warn(self, text): + def on_warn(self, text): pass - def fail(self, text): + def on_fail(self, text): pass @@ -53,10 +53,10 @@ class Printer(BaseAssistantListener): def __init__(self, out=sys.stdout): self.out = out - def warn(self, text): + def on_warn(self, text): print(textwrap.indent(text, "*** "), file=self.out) - def fail(self, text): + def on_fail(self, text): print(textwrap.indent(text, "*** "), file=self.out) sys.exit(1) diff --git a/src/chatdbg/assistant/test.py b/src/chatdbg/assistant/test.py index 984c2ec..db79c13 100644 --- a/src/chatdbg/assistant/test.py +++ b/src/chatdbg/assistant/test.py @@ -7,7 +7,7 @@ class AssistantTest: def __init__(self): self.a = Assistant( "You generate text.", - clients=[StreamingPrinter()], + listeners=[StreamingPrinter()], functions=[self.weather], stream=True, ) diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 901211c..e80e536 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -15,16 +15,16 @@ import IPython from traitlets import TraitError -from chatdbg.ipdb_util.capture import CaptureInput, CaptureOutput +from chatdbg.util.capture import CaptureInput, CaptureOutput from .assistant.assistant import Assistant from .assistant.listeners import BaseAssistantListener -from .ipdb_util.chatlog import ChatDBGLog -from .ipdb_util.config import chatdbg_config -from .ipdb_util.locals import extract_locals -from .ipdb_util.prompts import pdb_instructions -from .ipdb_util.streamwrap import StreamingTextWrapper -from .ipdb_util.text import ( +from .util.chatlog import ChatDBGLog +from .util.config import chatdbg_config +from .util.locals import extract_locals +from .util.prompts import pdb_instructions +from .util.streamwrap import StreamingTextWrapper +from .util.text import ( format_limited, strip_color, truncate_proportionally, @@ -35,7 +35,7 @@ def load_ipython_extension(ipython): global chatdbg_config from chatdbg.chatdbg_pdb import ChatDBG - from 
chatdbg.ipdb_util.config import ChatDBGConfig, chatdbg_config + from chatdbg.util.config import ChatDBGConfig, chatdbg_config ipython.InteractiveTB.debugger_cls = ChatDBG chatdbg_config = ChatDBGConfig(config=ipython.config) @@ -635,7 +635,7 @@ def _make_assistant(self): debug=chatdbg_config.debug, functions=functions, stream=chatdbg_config.stream, - clients=[ + listeners=[ ChatAssistantClient( self.stdout, self.prompt, @@ -745,12 +745,11 @@ def _print(self, text, **kwargs): **kwargs, ) - def warn(self, text): + def on_warn(self, text): self._print(textwrap.indent(text, "*** ")) - def fail(self, text): + def on_fail(self, text): self._print(textwrap.indent(text, "*** ")) - sys.exit(1) def on_begin_stream(self): self._stream_wrapper = StreamingTextWrapper(self.chat_prefix, width=80) diff --git a/src/chatdbg/ipdb_util/capture.py b/src/chatdbg/ipdb_util/capture.py deleted file mode 100644 index c4fcf53..0000000 --- a/src/chatdbg/ipdb_util/capture.py +++ /dev/null @@ -1,56 +0,0 @@ -from io import StringIO, TextIOWrapper - - -class CaptureInput: - def __init__(self, input_stream): - input_stream = TextIOWrapper(input_stream.buffer, encoding="utf-8", newline="") - - self.original_input = input_stream - self.capture_buffer = StringIO() - self.original_readline = input_stream.buffer.raw.readline - - def custom_readline(*args, **kwargs): - input_data = self.original_readline(*args, **kwargs) - self.capture_buffer.write(input_data.decode()) - return input_data - - input_stream.buffer.raw.readline = custom_readline - - def readline(self, *args, **kwargs): - input_data = self.original_input.readline(*args, **kwargs) - self.capture_buffer.write(input_data) - self.capture_buffer.flush() - return input_data - - def read(self, *args, **kwargs): - input_data = self.original_input.read(*args, **kwargs) - self.capture_buffer.write(input_data) - self.capture_buffer.flush() - return input_data - - def get_captured_input(self): - return self.capture_buffer.getvalue() - - -class CaptureOutput: - """ - File wrapper that will stash a copy of everything written. 
- """ - - def __init__(self, file): - self.file = file - self.buffer = StringIO() - - def write(self, data): - self.buffer.write(data) - return self.file.write(data) - - def getvalue(self): - return self.buffer.getvalue() - - def getfile(self): - return self.file - - def __getattr__(self, attr): - # Delegate attribute access to the file object - return getattr(self.file, attr) diff --git a/src/chatdbg/ipdb_util/chatlog.py b/src/chatdbg/ipdb_util/chatlog.py deleted file mode 100644 index 209cc59..0000000 --- a/src/chatdbg/ipdb_util/chatlog.py +++ /dev/null @@ -1,149 +0,0 @@ -import atexit -from io import StringIO -from datetime import datetime -import uuid -import sys -import yaml -from pydantic import BaseModel -from typing import List, Union, Optional - -from chatdbg.ipdb_util.capture import CaptureOutput -from ..assistant.listeners import BaseAssistantListener - - -class ChatDBGLog(BaseAssistantListener): - - def __init__(self, log_filename, config, capture_streams=True): - self._log_filename = log_filename - self.config = config - if capture_streams: - self._stdout_wrapper = CaptureOutput(sys.stdout) - self._stderr_wrapper = CaptureOutput(sys.stderr) - sys.stdout = self._stdout_wrapper - sys.stderr = self._stdout_wrapper - else: - self._stderr_wrapper = None - self._stderr_wrapper = None - - self._log = self._make_log() - self._current_chat = None - - def _make_log(self): - meta = { - "time": datetime.now(), - "command_line": " ".join(sys.argv), - "uid": str(uuid.uuid4()), - "config": self.config, - } - return { - "steps": [], - "meta": meta, - "instructions": None, - "stdout": self._stdout_wrapper.getvalue(), - "stderr": self._stderr_wrapper.getvalue(), - } - - def _dump(self): - log = self._log - - def total(key): - return sum( - x["stats"][key] - for x in log["steps"] - if x["output"]["type"] == "chat" and "stats" in x["output"] - ) - - log["meta"]["total_tokens"] = total("tokens") - log["meta"]["total_time"] = total("time") - log["meta"]["total_cost"] = total("cost") - - print(f"*** Writing ChatDBG dialog log to {self._log_filename}") - - with open(self._log_filename, "a") as file: - - def literal_presenter(dumper, data): - if "\n" in data: - return dumper.represent_scalar( - "tag:yaml.org,2002:str", data, style="|" - ) - else: - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - yaml.add_representer(str, literal_presenter) - yaml.dump([log], file, default_flow_style=False, indent=2) - - def on_begin_dialog(self, instructions): - log = self._log - assert log != None - log["instructions"] = instructions - - def on_end_dialog(self): - if self._log != None: - self._dump() - self._log = self._make_log() - - def on_begin_query(self, prompt, extra): - log = self._log - assert log != None - assert self._current_chat == None - self._current_chat = { - "input": extra, - "prompt": prompt, - "output": {"type": "chat", "outputs": []}, - } - - def on_end_query(self, stats): - log = self._log - assert log != None - assert self._current_chat != None - log["steps"] += [self._current_chat] - log["stats"] = stats - self._current_chat = None - - def _post(self, text, kind): - log = self._log - assert log != None - if self._current_chat != None: - self._current_chat["output"]["outputs"].append( - {"type": "text", "output": f"*** {kind}: {text}"} - ) - else: - log["steps"].append( - { - "type": "call", - "input": f"*** {kind}", - "output": {"type": "text", "output": text}, - } - ) - - def warn(self, text): - self._post(text, "Warning") - - def fail(self, text): - self._post(text, 
"Failure") - - def on_response(self, text): - log = self._log - assert log != None - assert self._current_chat != None - self._current_chat["output"]["outputs"].append({"type": "text", "output": text}) - - def on_function_call(self, call, result): - log = self._log - assert log != None - if self._current_chat != None: - self._current_chat["output"]["outputs"].append( - { - "type": "call", - "input": call, - "output": {"type": "text", "output": result}, - } - ) - else: - log["steps"].append( - { - "type": "call", - "input": call, - "output": {"type": "text", "output": result}, - } - ) diff --git a/src/chatdbg/ipdb_util/config.py b/src/chatdbg/ipdb_util/config.py deleted file mode 100644 index df08178..0000000 --- a/src/chatdbg/ipdb_util/config.py +++ /dev/null @@ -1,81 +0,0 @@ -import os - -from traitlets import Bool, Int, TraitError, Unicode -from traitlets.config import Configurable - - -def _chatdbg_get_env(option_name, default_value): - env_name = "CHATDBG_" + option_name.upper() - v = os.getenv(env_name, str(default_value)) - if type(default_value) == int: - return int(v) - elif type(default_value) == bool: - return v.lower() == "true" - else: - return v - - -class ChatDBGConfig(Configurable): - model = Unicode( - _chatdbg_get_env("model", "gpt-4-1106-preview"), help="The LLM model" - ).tag(config=True) - - debug = Bool(_chatdbg_get_env("debug", False), help="Log LLM calls").tag( - config=True - ) - - log = Unicode(_chatdbg_get_env("log", "log.yaml"), help="The log file").tag( - config=True - ) - - tag = Unicode(_chatdbg_get_env("tag", ""), help="Any extra info for log file").tag( - config=True - ) - rc_lines = Unicode( - _chatdbg_get_env("rc_lines", "[]"), help="lines to run at startup" - ).tag(config=True) - - context = Int( - _chatdbg_get_env("context", 10), - help="lines of source code to show when displaying stacktrace information", - ).tag(config=True) - - show_locals = Bool( - _chatdbg_get_env("show_locals", True), - help="show local var values in stacktrace", - ).tag(config=True) - - show_libs = Bool( - _chatdbg_get_env("show_libs", False), help="show library frames in stacktrace" - ).tag(config=True) - - show_slices = Bool( - _chatdbg_get_env("show_slices", True), help="support the `slice` command" - ).tag(config=True) - - take_the_wheel = Bool( - _chatdbg_get_env("take_the_wheel", True), help="Let LLM take the wheel" - ).tag(config=True) - - stream = Bool( - _chatdbg_get_env("stream", False), help="Stream the response at it arrives" - ).tag(config=True) - - def to_json(self): - """Serialize the object to a JSON string.""" - return { - "model": self.model, - "debug": self.debug, - "log": self.log, - "tag": self.tag, - "rc_lines": self.rc_lines, - "context": self.context, - "show_locals": self.show_locals, - "show_libs": self.show_libs, - "show_slices": self.show_slices, - "take_the_wheel": self.take_the_wheel, - "stream": self.stream, - } - - -chatdbg_config: ChatDBGConfig = ChatDBGConfig() diff --git a/src/chatdbg/ipdb_util/locals.py b/src/chatdbg/ipdb_util/locals.py deleted file mode 100644 index c2a5aeb..0000000 --- a/src/chatdbg/ipdb_util/locals.py +++ /dev/null @@ -1,55 +0,0 @@ -import ast -import inspect -import textwrap - - -class SymbolFinder(ast.NodeVisitor): - def __init__(self): - self.defined_symbols = set() - - def visit_Assign(self, node): - for target in node.targets: - if isinstance(target, ast.Name): - self.defined_symbols.add(target.id) - self.generic_visit(node) - - def visit_For(self, node): - if isinstance(node.target, ast.Name): - 
self.defined_symbols.add(node.target.id) - self.generic_visit(node) - - def visit_comprehension(self, node): - if isinstance(node.target, ast.Name): - self.defined_symbols.add(node.target.id) - self.generic_visit(node) - - -def extract_locals(frame): - try: - source = textwrap.dedent(inspect.getsource(frame)) - tree = ast.parse(source) - - finder = SymbolFinder() - finder.visit(tree) - - args, varargs, keywords, locals = inspect.getargvalues(frame) - parameter_symbols = set(args + [varargs, keywords]) - parameter_symbols.discard(None) - - return (finder.defined_symbols | parameter_symbols) & locals.keys() - except: - # ipes - return set() - - -def extract_nb_globals(globals): - result = set() - for source in globals["In"]: - try: - tree = ast.parse(source) - finder = SymbolFinder() - finder.visit(tree) - result = result | (finder.defined_symbols & globals.keys()) - except Exception as e: - pass - return result diff --git a/src/chatdbg/ipdb_util/prompts.py b/src/chatdbg/ipdb_util/prompts.py deleted file mode 100644 index f2c72bd..0000000 --- a/src/chatdbg/ipdb_util/prompts.py +++ /dev/null @@ -1,83 +0,0 @@ -import os - -_intro = f"""\ -You are a debugging assistant. You will be given a Python stack trace for an -error and answer questions related to the root cause of the error. -""" - -_pdb_function = f"""\ -Call the `pdb` function to run Pdb debugger commands on the stopped program. You -may call the `pdb` function to run the following commands: `bt`, `up`, `down`, -`p expression`, `list`. - -Call `pdb` to print any variable value or expression that you believe may -contribute to the error. -""" - - -_info_function = """\ -Call the `info` function to get the documentation and source code for any -variable, function, package, class, method reference, field reference, or -dotted reference visible in the current frame. Examples include: n, e.n -where e is an expression, and t.n where t is a type. - -Unless it is from a common, widely-used library, you MUST call `info` exactly once on any -symbol that is referenced in code leading up to the error. -""" - - -_slice_function = """\ -Call the `slice` function to get the code used to produce -the value currently stored a variable. You MUST call `slice` exactly once on any -variable used but not defined in the current frame's code. -""" - -_take_the_wheel_instructions = """\ -Call the provided functions as many times as you would like. -""" - -_general_instructions = f"""\ -The root cause of any error is likely due to a problem in the source code from the user. - -Explain why each variable contributing to the error has been set -to the value that it has. - -Continue with your explanations until you reach the root cause of the error. Your answer may be as long as necessary. 
- -End your answer with a section titled "##### Recommendation\\n" that contains one of: -* a fix if you have identified the root cause -* a numbered list of 1-3 suggestions for how to continue debugging if you have not -""" - - -_wheel_and_slice = f"""\ -{_intro} -{_pdb_function} -{_info_function} -{_slice_function} -{_take_the_wheel_instructions} -{_general_instructions} -""" - -_wheel_no_slice = f"""\ -{_intro} -{_pdb_function} -{_info_function} -{_take_the_wheel_instructions} -{_general_instructions} -""" - -_no_wheel = f"""\ -{_intro} -{_general_instructions} -""" - - -def pdb_instructions(supports_flow, take_the_wheel): - if take_the_wheel: - if supports_flow: - return _wheel_and_slice - else: - return _wheel_no_slice - else: - return _no_wheel diff --git a/src/chatdbg/ipdb_util/streamwrap.py b/src/chatdbg/ipdb_util/streamwrap.py deleted file mode 100644 index 0f95e08..0000000 --- a/src/chatdbg/ipdb_util/streamwrap.py +++ /dev/null @@ -1,48 +0,0 @@ -import textwrap -import re -import sys -from .text import word_wrap_except_code_blocks - - -class StreamingTextWrapper: - - def __init__(self, indent=" ", width=80): - self._buffer = "" # the raw text so far - self._wrapped = "" # the successfully wrapped text do far - self._pending = ( - "" # the part after the last space in buffer -- has not been wrapped yet - ) - self._indent = indent - self._width = width - len(indent) - - def append(self, text, flush=False): - if flush: - self._buffer += self._pending + text - self._pending = "" - else: - text_bits = re.split(r"(\s+)", self._pending + text) - self._pending = text_bits[-1] - self._buffer += "".join(text_bits[0:-1]) - - wrapped = word_wrap_except_code_blocks(self._buffer) - wrapped = textwrap.indent(wrapped, self._indent, lambda _: True) - wrapped_delta = wrapped[len(self._wrapped) :] - self._wrapped = wrapped - return wrapped_delta - - def flush(self): - if len(self._buffer) > 0: - result = self.append("\n", flush=True) - else: - result = self.append("", flush=True) - self._buffer = "" - self._wrapped = "" - return result - - -if __name__ == "__main__": - s = StreamingTextWrapper(3, 20) - for x in sys.argv[1:]: - y = s.append(" " + x) - print(y, end="", flush=True) - print(s.flush()) diff --git a/src/chatdbg/ipdb_util/text.py b/src/chatdbg/ipdb_util/text.py deleted file mode 100644 index 9d31e0d..0000000 --- a/src/chatdbg/ipdb_util/text.py +++ /dev/null @@ -1,144 +0,0 @@ -import re -import itertools -import inspect -import numbers -import numpy as np -import textwrap - - -def make_arrow(pad): - """generate the leading arrow in front of traceback or debugger""" - if pad >= 2: - return "-" * (pad - 2) + "> " - elif pad == 1: - return ">" - return "" - - -def strip_color(s): - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", s) - - -def _is_iterable(obj): - try: - iter(obj) - return True - except TypeError: - return False - - -def _repr_if_defined(obj): - if obj.__class__ in [np.ndarray, dict, list, tuple]: - # handle these at iterables to truncate reasonably - return False - result = ( - "__repr__" in dir(obj.__class__) - and obj.__class__.__repr__ is not object.__repr__ - ) - return result - - -def format_limited(value, limit=10, depth=3): - def format_tuple(t, depth): - return tuple([helper(x, depth) for x in t]) - - def format_list(list, depth): - return [helper(x, depth) for x in list] - - def format_dict(items, depth): - return {k: helper(v, depth) for k, v in items} - - def format_object(obj, depth): - attributes = dir(obj) - fields = 
{ - attr: getattr(obj, attr, None) - for attr in attributes - if not callable(getattr(obj, attr, None)) and not attr.startswith("__") - } - return format( - f"{type(obj).__name__} object with fields {format_dict(fields.items(), depth)}" - ) - - def helper(value, depth): - if depth == 0: - return ... - if value is Ellipsis: - return ... - if isinstance(value, dict): - if len(value) > limit: - return format_dict( - list(value.items())[: limit - 1] + [(..., ...)], depth - 1 - ) - else: - return format_dict(value.items(), depth - 1) - elif isinstance(value, (str, bytes)): - if len(value) > 254: - value = str(value)[0:253] + "..." - return value - elif isinstance(value, tuple): - if len(value) > limit: - return format_tuple(value[0 : limit - 1] + (...,), depth - 1) - else: - return format_tuple(value, depth - 1) - elif value is None or isinstance( - value, (int, float, bool, type, numbers.Number) - ): - return value - elif isinstance(value, np.ndarray): - with np.printoptions(threshold=limit): - return np.array_repr(value) - elif inspect.isclass(type(value)) and _repr_if_defined(value): - return repr(value) - elif _is_iterable(value): - value = list(itertools.islice(value, 0, limit + 1)) - if len(value) > limit: - return format_list(value[: limit - 1] + [...], depth - 1) - else: - return format_list(value, depth - 1) - elif inspect.isclass(type(value)): - return format_object(value, depth - 1) - else: - return value - - result = str(helper(value, depth=3)).replace("Ellipsis", "...") - if len(result) > 1024 * 2: - result = result[: 1024 * 2 - 3] + "..." - if type(value) == str: - return "'" + result + "'" - else: - return result - - -def truncate_proportionally(text, maxlen=32000, top_proportion=0.5): - """Omit part of a string if needed to make it fit in a maximum length.""" - if len(text) > maxlen: - pre = max(0, int((maxlen - 3) * top_proportion)) - post = max(0, maxlen - 3 - pre) - return text[:pre] + "..." + text[len(text) - post :] - return text - - -def word_wrap_except_code_blocks(text: str, width: int = 80) -> str: - """ - Wraps text except for code blocks for nice terminal formatting. - - Splits the text into paragraphs and wraps each paragraph, - except for paragraphs that are inside of code blocks denoted - by ` ``` `. Returns the updated text. - - Args: - text (str): The text to wrap. - width (int): The width of the lines to wrap at, passed to `textwrap.fill`. - - Returns: - The wrapped text. - """ - blocks = text.split("```") - for i in range(len(blocks)): - if i % 2 == 0: - paras = blocks[i].split("\n") - wrapped = [textwrap.fill(para, width=width) for para in paras] - blocks[i] = "\n".join(wrapped) - - return "```".join(blocks)
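
Reviewer note, not part of the patch: the helpers removed above (`truncate_proportionally`, `word_wrap_except_code_blocks`) are the ones the new `debug`/`info`/`slice` callbacks earlier in this series rely on to keep tool output within budget and to wrap streamed responses. Below is a minimal usage sketch, assuming the helpers are re-homed at `chatdbg.util.text` (as the `ipdb_util` → `util` renames in this patch suggest) and keep the signatures shown in the deleted `ipdb_util/text.py`; the module path and the sample strings are illustrative, not taken from the patch.

```python
# Sketch: how the relocated text helpers are expected to be used by the
# LLM tool callbacks. Assumes chatdbg is installed and that the helpers
# now live at chatdbg.util.text with the signatures shown above.
from chatdbg.util.text import truncate_proportionally, word_wrap_except_code_blocks

# A long pdb transcript: keep most of the budget from the top and elide
# the middle, mirroring the debug() callback's maxlen=8000 / 0.9 split.
transcript = "\n".join(f"line {i}" for i in range(10_000))
clipped = truncate_proportionally(transcript, maxlen=8000, top_proportion=0.9)
assert len(clipped) <= 8000 and "..." in clipped

# Responses are wrapped to the terminal width, but fenced code blocks
# (delimited by ```) are passed through unwrapped.
reply = (
    "Here is the fix:\n"
    "```\n"
    "a_very_long_identifier_that_should_not_be_rewrapped_by_the_formatter()\n"
    "```\n"
    "And a longer explanation " + "word " * 40
)
print(word_wrap_except_code_blocks(reply, width=80))
```

Under these assumptions, the head-heavy truncation keeps the frames and commands the model most recently asked about while still preserving the tail of the output, which is why the `debug` callback passes `top_proportion=0.9` and the `slice` callback a more balanced `0.5`.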