From 2febf194cd0c5332fb40f0a51f457caa4c17cb1d Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Thu, 2 Nov 2023 14:20:07 +0800 Subject: [PATCH] Bump to 0.0.3 (#11) --- coml/core.py | 8 +++- coml/magics.py | 122 ++++++++++++++++++++++++++++++------------------- package.json | 2 +- pyproject.toml | 2 +- 4 files changed, 83 insertions(+), 51 deletions(-) diff --git a/coml/core.py b/coml/core.py index 2bc2331..9049884 100644 --- a/coml/core.py +++ b/coml/core.py @@ -31,8 +31,12 @@ ) from .vis_utils import VisVerifier +_debug_mode: bool = False + def debug_messages(*messages: BaseMessage) -> None: + if not _debug_mode: + return for message in messages: if isinstance(message, SystemMessage): print(colorama.Fore.BLUE + message.content + colorama.Fore.RESET + "\n") @@ -113,10 +117,10 @@ def generate_code( ) question, _ = render_generate_context(context) messages.append(HumanMessage(content=question)) - # debug_messages(*messages) + debug_messages(*messages) response = self.llm(messages) - # debug_messages(response) + debug_messages(response) code = parse_code(response.content) return {**context, "answer": code} diff --git a/coml/magics.py b/coml/magics.py index 4ec4287..8c175c8 100644 --- a/coml/magics.py +++ b/coml/magics.py @@ -25,6 +25,7 @@ run_code_in_next_cell, update_running_cell_metadata, ) +from .linter import lint from .prompt_utils import ( FixContext, GenerateContext, @@ -33,6 +34,38 @@ filter_variables, ) +VERIFY_STYLE = """ + +""" + @magics_class class CoMLMagics(Magics): @@ -116,6 +149,22 @@ def coml(self, line, cell=None): ) return self._post_generation(generate_context["answer"], generate_context) + @no_var_expand + @line_magic + def comlset(self, line): + key, value = line.lower().strip().split() + if key == "debug": + from . import core + + if value == "on": + core._debug_mode = True + elif value == "off": + core._debug_mode = False + else: + raise ValueError("Debug mode must be either on or off.") + else: + raise ValueError("Unknown setting.") + @no_var_expand @line_magic def comlinspire(self, line): @@ -202,47 +251,11 @@ def comlverify(self, line): generated_vis = False if context.get("action") == "run": error, output = parse_cell_outputs(target_cell["outputs"]) - generated_vis = ( - output - and "" in output - and "request" in context - and "answer" in context - and "codes" in context - and "variables" in context - ) - style = """""" + generated_vis = output and "" in output def display_statuses(statuses): clear_output(wait=True) - html = style + "\n" + html = VERIFY_STYLE + "\n" display_names = { "lint": "PyLint", "rubberduck": "Rubberduck", @@ -261,21 +274,33 @@ def display_statuses(statuses): False: "❌", } loading = "" + message_template = "
{}: {}\n{}
" for name in display_names: detail_message = "Still loading..." if name in statuses: detail_message = markdown.markdown( statuses[name]["details"], extensions=["nl2br"] ) - html += f"""
-{display_names[name]}: {loading if name not in statuses else status_icon[statuses[name]["result"]]} -{detail_message} -
\n""" + html += message_template.format( + display_names[name], + loading + if name not in statuses + else status_icon[statuses[name]["result"]], + detail_message, + ) + display(HTML(html)) result = {} display_statuses(result) + lint_result, lint_details = lint("\n".join(self._get_code_context()), code) + result["lint"] = { + "result": lint_result, + "details": lint_details, + } + display_statuses(result) + rubberduck_result, rubberduck_details = self.agent.static_check(code, context) result["rubberduck"] = { "result": rubberduck_result, @@ -287,18 +312,20 @@ def display_statuses(statuses): # verify generated vis svg_string = output.replace("", "") request = context["request"] - new_code = context["answer"] - previous_code = "\n".join(context["codes"]) variables = context["variables"] # Roughly judge the source of the visualization - if "plt.show()" in new_code: - source = "matplotlib" + if "plt.show()" in code: + vis_framework = "matplotlib" ( visualization_check_result, visualization_check_details, ) = self.agent.visualization_check( - request, previous_code, svg_string, variables, source + context["request"], + "\n".join(self._get_code_context()), + code, + context["variables"], + vis_framework, ) details = "" for detail in visualization_check_details: @@ -308,6 +335,7 @@ def display_statuses(statuses): "details": details, } display_statuses(result) + elif error or output: sanity_result, sanity_details = self.agent.output_sanity_check( code, context, error, output diff --git a/package.json b/package.json index 28eb15d..b7601ec 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "coml", - "version": "0.0.2", + "version": "0.0.3", "description": "JupyterLab extension for CoML.", "keywords": [ "jupyter", diff --git a/pyproject.toml b/pyproject.toml index 26cdd23..c5d2ab1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mlcopilot" -version = "0.0.2" +version = "0.0.3" dependencies = [ "click", "colorama",