diff --git a/coml/core.py b/coml/core.py
index 2bc2331..9049884 100644
--- a/coml/core.py
+++ b/coml/core.py
@@ -31,8 +31,12 @@
)
from .vis_utils import VisVerifier
+_debug_mode: bool = False
+
def debug_messages(*messages: BaseMessage) -> None:
+ if not _debug_mode:
+ return
for message in messages:
if isinstance(message, SystemMessage):
print(colorama.Fore.BLUE + message.content + colorama.Fore.RESET + "\n")
@@ -113,10 +117,10 @@ def generate_code(
)
question, _ = render_generate_context(context)
messages.append(HumanMessage(content=question))
- # debug_messages(*messages)
+ debug_messages(*messages)
response = self.llm(messages)
- # debug_messages(response)
+ debug_messages(response)
code = parse_code(response.content)
return {**context, "answer": code}
diff --git a/coml/magics.py b/coml/magics.py
index 4ec4287..8c175c8 100644
--- a/coml/magics.py
+++ b/coml/magics.py
@@ -25,6 +25,7 @@
run_code_in_next_cell,
update_running_cell_metadata,
)
+from .linter import lint
from .prompt_utils import (
FixContext,
GenerateContext,
@@ -33,6 +34,38 @@
filter_variables,
)
+VERIFY_STYLE = """
+
+"""
+
@magics_class
class CoMLMagics(Magics):
@@ -116,6 +149,22 @@ def coml(self, line, cell=None):
)
return self._post_generation(generate_context["answer"], generate_context)
+ @no_var_expand
+ @line_magic
+ def comlset(self, line):
+ key, value = line.lower().strip().split()
+ if key == "debug":
+ from . import core
+
+ if value == "on":
+ core._debug_mode = True
+ elif value == "off":
+ core._debug_mode = False
+ else:
+ raise ValueError("Debug mode must be either on or off.")
+ else:
+ raise ValueError("Unknown setting.")
+
@no_var_expand
@line_magic
def comlinspire(self, line):
@@ -202,47 +251,11 @@ def comlverify(self, line):
generated_vis = False
if context.get("action") == "run":
error, output = parse_cell_outputs(target_cell["outputs"])
- generated_vis = (
- output
- and "" in output
- and "request" in context
- and "answer" in context
- and "codes" in context
- and "variables" in context
- )
- style = """"""
+ generated_vis = output and "" in output
def display_statuses(statuses):
clear_output(wait=True)
- html = style + "\n"
+ html = VERIFY_STYLE + "\n"
display_names = {
"lint": "PyLint",
"rubberduck": "Rubberduck",
@@ -261,21 +274,33 @@ def display_statuses(statuses):
False: "❌",
}
loading = ""
+ message_template = "{}: {}
\n{} "
for name in display_names:
detail_message = "Still loading..."
if name in statuses:
detail_message = markdown.markdown(
statuses[name]["details"], extensions=["nl2br"]
)
- html += f"""
-{display_names[name]}: {loading if name not in statuses else status_icon[statuses[name]["result"]]}
-{detail_message}
- \n"""
+ html += message_template.format(
+ display_names[name],
+ loading
+ if name not in statuses
+ else status_icon[statuses[name]["result"]],
+ detail_message,
+ )
+
display(HTML(html))
result = {}
display_statuses(result)
+ lint_result, lint_details = lint("\n".join(self._get_code_context()), code)
+ result["lint"] = {
+ "result": lint_result,
+ "details": lint_details,
+ }
+ display_statuses(result)
+
rubberduck_result, rubberduck_details = self.agent.static_check(code, context)
result["rubberduck"] = {
"result": rubberduck_result,
@@ -287,18 +312,20 @@ def display_statuses(statuses):
# verify generated vis
svg_string = output.replace("", "")
request = context["request"]
- new_code = context["answer"]
- previous_code = "\n".join(context["codes"])
variables = context["variables"]
# Roughly judge the source of the visualization
- if "plt.show()" in new_code:
- source = "matplotlib"
+ if "plt.show()" in code:
+ vis_framework = "matplotlib"
(
visualization_check_result,
visualization_check_details,
) = self.agent.visualization_check(
- request, previous_code, svg_string, variables, source
+ context["request"],
+ "\n".join(self._get_code_context()),
+ code,
+ context["variables"],
+ vis_framework,
)
details = ""
for detail in visualization_check_details:
@@ -308,6 +335,7 @@ def display_statuses(statuses):
"details": details,
}
display_statuses(result)
+
elif error or output:
sanity_result, sanity_details = self.agent.output_sanity_check(
code, context, error, output
diff --git a/package.json b/package.json
index 28eb15d..b7601ec 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "coml",
- "version": "0.0.2",
+ "version": "0.0.3",
"description": "JupyterLab extension for CoML.",
"keywords": [
"jupyter",
diff --git a/pyproject.toml b/pyproject.toml
index 26cdd23..c5d2ab1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "mlcopilot"
-version = "0.0.2"
+version = "0.0.3"
dependencies = [
"click",
"colorama",