From e4db1572c5bb8f87b6fc89280de0e99e2c6630f5 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Fri, 1 Nov 2024 18:00:53 -0700 Subject: [PATCH 1/4] feat: add optimizer in the UI --- README.md | 14 +- docetl/builder.py | 34 +- docetl/console.py | 46 +- docetl/operations/base.py | 2 + .../map_optimizer/config_generators.py | 5 +- docetl/optimizers/map_optimizer/evaluator.py | 29 +- docetl/optimizers/map_optimizer/optimizer.py | 12 + .../map_optimizer/plan_generators.py | 3 - .../map_optimizer/prompt_generators.py | 30 +- docetl/optimizers/reduce_optimizer.py | 12 +- docetl/runner.py | 11 +- docetl/utils.py | 18 +- server/app/routes/pipeline.py | 71 +- website/package-lock.json | 27 +- website/package.json | 1 + website/src/app/api/utils.ts | 7 + website/src/app/localStorageKeys.ts | 1 + website/src/app/types.ts | 3 +- website/src/components/AnsiRenderer.tsx | 57 +- website/src/components/OperationCard.tsx | 123 ++- website/src/components/Output.tsx | 67 +- website/src/components/PipelineGui.tsx | 701 ++++++++++-------- website/src/components/operations/args.tsx | 148 +++- .../src/components/operations/components.tsx | 38 +- website/src/components/ui/progress.tsx | 28 + website/src/contexts/PipelineContext.tsx | 37 + 26 files changed, 1047 insertions(+), 478 deletions(-) create mode 100644 website/src/components/ui/progress.tsx diff --git a/README.md b/README.md index ae3edd43..34dd3cbb 100644 --- a/README.md +++ b/README.md @@ -33,17 +33,25 @@ DocETL is the ideal choice when you're looking to maximize correctness and outpu ## Installation -See the documentation for installing from PyPI. +You can install DocETL using either PyPI or from source. We recommend installing from source for the latest features and bug fixes. ### Prerequisites Before installing DocETL, ensure you have Python 3.10 or later installed on your system. 
You can check your Python version by running: +```bash python --version +``` + +### Install from PyPI + +```bash +pip install docetl +``` -### Installation Steps (from Source) +### Install from Source -1. Clone the DocETL repository: +1. Clone the DocETL repository (or your fork): ```bash git clone https://github.com/ucbepic/docetl.git diff --git a/docetl/builder.py b/docetl/builder.py index 4d3c8d30..58eb3d00 100644 --- a/docetl/builder.py +++ b/docetl/builder.py @@ -83,7 +83,6 @@ class Optimizer: def __init__( self, runner: "DSLRunner", - max_threads: Optional[int] = None, model: str = "gpt-4o", resume: bool = False, timeout: int = 60, @@ -980,6 +979,9 @@ def _get_sample_data( return self._get_reduce_sample( data, op_config.get("reduce_key"), sample_size ) + + if not self.config.get("optimizer_config", {}).get("random_sample", False): + return data[:sample_size] # Take the random 500 examples or all if less than 500 initial_data = random.sample(data, min(500, len(data))) @@ -1038,7 +1040,13 @@ def _get_reduce_sample( group_sample_size = int(sample_size * group_proportion) # Sample from the group - group_sample = random.sample(items, min(group_sample_size, len(items))) + if not self.config.get("optimizer_config", {}).get("random_sample", False): + group_sample = items[:group_sample_size] + else: + group_sample = random.sample( + items, min(group_sample_size, len(items)) + ) + sample.extend(group_sample) # If we haven't reached the desired sample size, add more items randomly @@ -1051,22 +1059,10 @@ def _get_reduce_sample( ] additional_sample = random.sample( remaining_items, - min(sample_size - len(sample), len(remaining_items)), - ) - sample.extend(additional_sample) - - # Add items randomly from non-top groups to meet the sample size - if len(sample) < sample_size: - remaining_items = [ - item - for _, items in grouped_data.items() - for item in items - if item not in sample - ] - additional_sample = random.sample( - remaining_items, - min(sample_size - 
len(sample), len(remaining_items)), - ) + min( + sample_size - len(sample), len(remaining_items) + ), + ) if self.config.get("optimizer_config", {}).get("random_sample", False) else remaining_items[:sample_size - len(sample)] sample.extend(additional_sample) # Create a histogram of group sizes @@ -1201,7 +1197,7 @@ def _optimize_equijoin( if map_operation["optimize"]: dataset_to_transform_sample = random.sample( dataset_to_transform, self.sample_size_map.get("map") - ) + ) if self.config.get("optimizer_config", {}).get("random_sample", False) else dataset_to_transform[:self.sample_size_map.get("map")] optimized_map_operations = self._optimize_map( map_operation, dataset_to_transform_sample ) diff --git a/docetl/console.py b/docetl/console.py index 4a07f35d..da389f44 100644 --- a/docetl/console.py +++ b/docetl/console.py @@ -1,10 +1,11 @@ import os -from typing import Any, Optional +import time +from typing import Any, Optional, Tuple from rich.console import Console from io import StringIO import threading import queue - +from docetl.utils import StageType, get_stage_description class ThreadSafeConsole(Console): def __init__(self, *args, **kwargs): @@ -13,6 +14,47 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.input_event = threading.Event() self.input_value = None + self.optimizer_statuses = [] + self.optimizer_rationale = None + + def status( + self, + status: "RenderableType", + *, + spinner: str = "dots", + spinner_style: "StyleType" = "status.spinner", + speed: float = 1.0, + refresh_per_second: float = 12.5, + ) -> "Status": + from rich.status import Status + + status_renderable = Status( + status, + console=None, + spinner=spinner, + spinner_style=spinner_style, + speed=speed, + refresh_per_second=refresh_per_second, + ) + return status_renderable + + def post_optimizer_rationale(self, should_optimize: bool, rationale: str, validator_prompt: str): + self.optimizer_rationale = (should_optimize, rationale, validator_prompt) + + 
def post_optimizer_status(self, stage: StageType): + self.optimizer_statuses.append((stage, time.time())) + + def get_optimizer_progress(self) -> Tuple[str, float]: + if len(self.optimizer_statuses) == 0: + return ("Optimization starting...", 0) + + if len(self.optimizer_statuses) > 0 and self.optimizer_statuses[-1][0] == StageType.END: + return (get_stage_description(StageType.END), 1) + + num_stages = len(StageType) - 1 + num_completed = len([s for s in self.optimizer_statuses if s[1]]) - 1 + current_stage = self.optimizer_statuses[-1][0] + return (get_stage_description(current_stage), num_completed / num_stages) def print(self, *args, **kwargs): super().print(*args, **kwargs) diff --git a/docetl/operations/base.py b/docetl/operations/base.py index 88377077..4f85b402 100644 --- a/docetl/operations/base.py +++ b/docetl/operations/base.py @@ -38,6 +38,7 @@ def __init__( max_threads: int, console: Optional[Console] = None, status: Optional[Status] = None, + is_build: bool = False, **kwargs, ): """ @@ -62,6 +63,7 @@ def __init__( self.num_retries_on_validate_failure = self.config.get( "num_retries_on_validate_failure", 0 ) + self.is_build = is_build self.syntax_check() # This must be overridden in a subclass diff --git a/docetl/optimizers/map_optimizer/config_generators.py b/docetl/optimizers/map_optimizer/config_generators.py index f9188fa0..5d7008e0 100644 --- a/docetl/optimizers/map_optimizer/config_generators.py +++ b/docetl/optimizers/map_optimizer/config_generators.py @@ -252,9 +252,8 @@ def _check_metadata_necessity( Determine if metadata is needed to perform the subtask. Consider: - 1. Does the subtask require information that might be present in metadata? - 2. Is the sample chunk or full input missing any crucial information that could be in metadata? - 3. Would having metadata significantly improve the performance or accuracy of the subtask? + 1. Does the input sample have any structural metadata that might be relevant to the subtask? + 2. 
Is the sample chunk or full input missing any crucial information that could be in this metadata? Provide your response in the following format: """ diff --git a/docetl/optimizers/map_optimizer/evaluator.py b/docetl/optimizers/map_optimizer/evaluator.py index fc963676..fb22825d 100644 --- a/docetl/optimizers/map_optimizer/evaluator.py +++ b/docetl/optimizers/map_optimizer/evaluator.py @@ -267,7 +267,7 @@ def _assess_operation( # Extract input variables from the prompt variables_in_prompt = extract_jinja_variables(op_config["prompt"]) variables_in_prompt = [v.replace("input.", "") for v in variables_in_prompt] - input_sample = input_data[:2] + input_sample = input_data[:3] output_sample = [ next( ( @@ -291,7 +291,7 @@ def _assess_operation( ) available_tokens = ( model_input_context_length - prompt_tokens - 100 - ) // 4 # 100 token buffer, divide by 4 for each sample + ) // 6 # 100 token buffer, divide by 6 for each sample # Prepare and truncate sample data input_1 = truncate_sample_data( @@ -336,22 +336,43 @@ def _assess_operation( {json.dumps({"input": input_2, "output": output_2}, indent=2)} """ + if len(input_sample) > 2: + input_3 = truncate_sample_data( + {key: input_sample[2].get(key, "N/A") for key in variables_in_prompt}, + available_tokens, + [variables_in_prompt], + self.llm_client.model, + ) + output_3 = truncate_sample_data( + {key: output_sample[2].get(key, "N/A") for key in output_schema.keys()}, + available_tokens, + [list(output_schema.keys())], + self.llm_client.model, + ) + prompt += f""" + ---Pair 3--- + {json.dumps({"input": input_3, "output": output_3}, indent=2)} + """ + prompt += f""" Custom Validator Prompt: {validator_prompt} - Based on the above information, please assess the operation's performance. Provide your assessment in the following format: + Based on the above information, please assess the operation's performance. + If it needs improvement, provide specific examples in your assessment. 
+ Be very detailed in your reasons for improvements, if any. + Provide your assessment in the following format: """ parameters = { "type": "object", "properties": { - "needs_improvement": {"type": "boolean"}, "reasons": {"type": "array", "items": {"type": "string"}}, "improvements": { "type": "array", "items": {"type": "string"}, }, + "needs_improvement": {"type": "boolean"}, }, "required": ["needs_improvement", "reasons", "improvements"], } diff --git a/docetl/optimizers/map_optimizer/optimizer.py b/docetl/optimizers/map_optimizer/optimizer.py index 034f74ac..d27872e9 100644 --- a/docetl/optimizers/map_optimizer/optimizer.py +++ b/docetl/optimizers/map_optimizer/optimizer.py @@ -133,6 +133,7 @@ def optimize( The cost is the cost of the optimizer (from possibly synthesizing resolves). """ + self.console.post_optimizer_status(StageType.SAMPLE_RUN) input_data = copy.deepcopy(input_data) # Add id to each input_data for i in range(len(input_data)): @@ -184,7 +185,9 @@ def optimize( }, ) + # Generate custom validator prompt + self.console.post_optimizer_status(StageType.SHOULD_OPTIMIZE) validator_prompt = self.prompt_generator._generate_validator_prompt( op_config, input_data, output_data ) @@ -218,6 +221,11 @@ def optimize( "improvements": assessment.get("improvements", []), }, ) + self.console.post_optimizer_rationale( + assessment.get("needs_improvement", True), + "\n".join(assessment.get("reasons", [])), + validator_prompt + ) # Check if improvement is needed based on the assessment if not data_exceeds_limit and not assessment.get("needs_improvement", True): @@ -237,6 +245,7 @@ def optimize( candidate_plans["no_change"] = [op_config] # Generate chunk size plans + self.console.post_optimizer_status(StageType.CANDIDATE_PLANS) self.console.log("[bold magenta]Generating chunking plans...[/bold magenta]") chunk_size_plans = self.plan_generator._generate_chunk_size_plans( op_config, input_data, validator_prompt, model_input_context_length @@ -290,6 +299,7 @@ def optimize( 
output=candidate_plans, ) + self.console.post_optimizer_status(StageType.EVALUATION_RESULTS) self.console.log( f"[bold magenta]Evaluating {len(plans_list)} plans...[/bold magenta]" ) @@ -349,6 +359,7 @@ def optimize( # Check if there are no top plans if len(top_plans) == 0: + self.console.post_optimizer_status(StageType.END) raise ValueError( "Agent did not generate any plans. Unable to proceed with optimization. Try again." ) @@ -422,6 +433,7 @@ def optimize( }, ) + self.console.post_optimizer_status(StageType.END) return ( candidate_plans[best_plan_name], best_output, diff --git a/docetl/optimizers/map_optimizer/plan_generators.py b/docetl/optimizers/map_optimizer/plan_generators.py index 9f8a5f40..a3c1f9bd 100644 --- a/docetl/optimizers/map_optimizer/plan_generators.py +++ b/docetl/optimizers/map_optimizer/plan_generators.py @@ -781,9 +781,6 @@ def _generate_chain_plans( """ output_schema = op_config["output"]["schema"] - if len(output_schema) <= 1: - return {} # No need for chain decomposition if there's only one output key - variables_in_prompt = extract_jinja_variables(op_config["prompt"]) variables_in_prompt = [v.replace("input.", "") for v in variables_in_prompt] diff --git a/docetl/optimizers/map_optimizer/prompt_generators.py b/docetl/optimizers/map_optimizer/prompt_generators.py index f2bde628..91e18834 100644 --- a/docetl/optimizers/map_optimizer/prompt_generators.py +++ b/docetl/optimizers/map_optimizer/prompt_generators.py @@ -192,7 +192,7 @@ def _get_header_extraction_prompt( header_extraction_prompt = f"""Analyze the following chunk of a document and extract any headers you see. - {{ input.{split_key}_chunk }} + {{{{ input.{split_key}_chunk }}}} Examples of headers and their levels based on the document structure: {chr(10).join(header_examples)} @@ -331,15 +331,41 @@ def _get_combine_prompt( {sample_inputs} Modify the original prompt to be a prompt that will combine these chunk results to accomplish the original task. 
+ This prompt will be submitted to an LLM, so it must be a valid Jinja2 template, with natural language instructions. Guidelines for your prompt template: - The only variable you are allowed to use is the `inputs` variable, which contains all chunk results. Each value is a dictionary with the keys {', '.join(schema_keys)} - - Avoid using filters or complex logic, even though Jinja technically supports it + - Avoid using filters or complex logic like `do` statements, even though Jinja technically supports it - The prompt template must be a valid Jinja2 template - You must use the {{{{ inputs }}}} variable somehow, in a for loop. You must access specific keys in each item in the loop. + - The prompt template must also contain natural language instructions so the LLM knows what to do with the data Provide your prompt template as a single string. """ + # Add example for combining themes + base_prompt += """ + Example of a good combine prompt for combining themes: + ``` + You are tasked with combining themes extracted from different chunks of text. + + Here are the themes extracted from each chunk: + {% for item in inputs %} + Themes for chunk {{ loop.index }}: + {{ item.themes }} + {% endfor %} + + Analyze all the themes above and create a consolidated list that: + 1. Combines similar or related themes + 2. Preserves unique themes that appear in only one chunk + 3. Prioritizes themes that appear multiple times across chunks + 4. Maintains the original wording where possible + + Provide the final consolidated list of themes, ensuring each theme is distinct and meaningful. + ``` + + Now generate a combine prompt for the current task. 
+ """ + parameters = { "type": "object", "properties": {"combine_prompt": {"type": "string"}}, diff --git a/docetl/optimizers/reduce_optimizer.py b/docetl/optimizers/reduce_optimizer.py index feee3f8b..dcbc39fc 100644 --- a/docetl/optimizers/reduce_optimizer.py +++ b/docetl/optimizers/reduce_optimizer.py @@ -15,7 +15,7 @@ from docetl.operations.utils import truncate_messages from docetl.optimizers.join_optimizer import JoinOptimizer from docetl.optimizers.utils import LLMClient -from docetl.utils import count_tokens, extract_jinja_variables +from docetl.utils import count_tokens, extract_jinja_variables, StageType class ReduceOptimizer: @@ -149,9 +149,11 @@ def optimize( # # Return unoptimized map and reduce operations # return [map_prompt, op_config], input_data, 0.0 + self.console.post_optimizer_status(StageType.SAMPLE_RUN) original_output = self._run_operation(op_config, input_data) # Step 1: Synthesize a validator prompt + self.console.post_optimizer_status(StageType.SHOULD_OPTIMIZE) validator_prompt = self._generate_validator_prompt( op_config, input_data, original_output ) @@ -172,6 +174,11 @@ def optimize( # Print the validation results self.console.log("[bold]Validation Results on Initial Sample:[/bold]") if validation_results["needs_improvement"]: + self.console.post_optimizer_rationale( + should_optimize=True, + rationale="\n".join(validation_results["issues"]), + validator_prompt=validator_prompt, + ) self.console.log( "\n".join( [ @@ -302,6 +309,7 @@ def _optimize_single_reduce( is_associative = self._is_associative(op_config, input_data) # Step 3: Create and evaluate multiple reduce plans + self.console.post_optimizer_status(StageType.CANDIDATE_PLANS) self.console.log("[bold magenta]Generating batched plans...[/bold magenta]") reduce_plans = self._create_reduce_plans(op_config, input_data, is_associative) @@ -310,12 +318,14 @@ def _optimize_single_reduce( gleaning_plans = self._generate_gleaning_plans(reduce_plans, validator_prompt) 
self.console.log("[bold magenta]Evaluating plans...[/bold magenta]") + self.console.post_optimizer_status(StageType.EVALUATION_RESULTS) best_plan = self._evaluate_reduce_plans( op_config, reduce_plans + gleaning_plans, input_data, validator_prompt ) # Step 4: Run the best reduce plan optimized_output = self._run_operation(best_plan, input_data) + self.console.post_optimizer_status(StageType.END) return [best_plan], optimized_output, 0.0 diff --git a/docetl/runner.py b/docetl/runner.py index 5acf84cf..18c73eb2 100644 --- a/docetl/runner.py +++ b/docetl/runner.py @@ -487,15 +487,14 @@ def optimize( builder = Optimizer( self, - max_threads=self.max_threads, **kwargs, ) cost = builder.optimize() - - # Dump via pickle - import pickle - with open(f"{self.base_name}_optimizer_output.pkl", "wb") as f: - pickle.dump(builder.captured_output, f) + + # Dump via json + # import json + # with open(f"{self.base_name}_optimizer_output.json", "wb") as f: + # json.dump(builder.captured_output.optimizer_output, f) if save: diff --git a/docetl/utils.py b/docetl/utils.py index 18bcb454..d190ecb5 100644 --- a/docetl/utils.py +++ b/docetl/utils.py @@ -1,18 +1,32 @@ import json import re from typing import Any, Dict, List - +from enum import Enum import tiktoken import yaml from jinja2 import Environment, meta from litellm import completion_cost as lcc -class StageType: +class StageType(Enum): SAMPLE_RUN = "sample_run" SHOULD_OPTIMIZE = "should_optimize" CANDIDATE_PLANS = "candidate_plans" EVALUATION_RESULTS = "evaluation_results" + END = "end" + +def get_stage_description(stage_type: StageType) -> str: + if stage_type == StageType.SAMPLE_RUN: + return "Running samples..." + elif stage_type == StageType.SHOULD_OPTIMIZE: + return "Checking if optimization is needed..." + elif stage_type == StageType.CANDIDATE_PLANS: + return "Generating candidate plans..." + elif stage_type == StageType.EVALUATION_RESULTS: + return "Evaluating candidate plans..." 
+ elif stage_type == StageType.END: + return "Optimization complete!" + raise ValueError(f"Unknown stage type: {stage_type}") class CapturedOutput: def __init__(self): diff --git a/server/app/routes/pipeline.py b/server/app/routes/pipeline.py index 666b43f0..38ec4e88 100644 --- a/server/app/routes/pipeline.py +++ b/server/app/routes/pipeline.py @@ -1,9 +1,16 @@ +import os +import signal from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect from server.app.models import PipelineRequest from docetl.runner import DSLRunner import asyncio -import queue +from rich.logging import RichHandler +import logging +FORMAT = "%(message)s" +logging.basicConfig( + level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] +) router = APIRouter() @@ -29,9 +36,10 @@ async def websocket_run_pipeline(websocket: WebSocket): runner.clear_intermediate() if config.get("optimize", False): - + logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}") + async def run_pipeline(): - return await asyncio.to_thread(runner.optimize, return_pipeline=False) + return await asyncio.to_thread(runner.optimize, return_pipeline=False, model=config.get("optimizer_model", "gpt-4o")) else: @@ -44,11 +52,44 @@ async def run_pipeline(): console_output = runner.console.file.getvalue() await websocket.send_json({"type": "output", "data": console_output}) + if config.get("optimize", False): + optimizer_progress = runner.console.get_optimizer_progress() + rationale = runner.console.optimizer_rationale + await websocket.send_json({ + "type": "optimizer_progress", + "status": optimizer_progress[0], + "progress": optimizer_progress[1], + "rationale": rationale[1] if rationale is not None else "", + "should_optimize": rationale[0] if rationale is not None else False, + "validator_prompt": rationale[2] if rationale is not None else "" + }) + # Check for incoming messages from the user try: user_message = await asyncio.wait_for( websocket.receive_json(), 
timeout=0.1 ) + + if user_message == "kill": + runner.console.print("Killing process...") + await websocket.send_json({ + "type": "error", + "message": "Killing process. Service will restart automatically." + }) + # Close websocket cleanly + await websocket.close() + + # Get current process ID + pid = os.getpid() + + # Schedule the process to kill itself + async def delayed_kill(): + await asyncio.sleep(0.5) # Give time for websocket to close + os.kill(pid, signal.SIGTERM) + + asyncio.create_task(delayed_kill()) + return + # Process the user message and send it to the runner runner.console.post_input(user_message) except asyncio.TimeoutError: @@ -69,17 +110,15 @@ async def run_pipeline(): # If optimize is true, send back the optimized operations if config.get("optimize", False): optimized_config, cost = result - # find the operation that has optimize = true - optimized_op = None - for op in optimized_config["operations"]: - if op.get("optimize", False): - optimized_op = op - break - - if not optimized_op: - raise HTTPException( - status_code=500, detail="No optimized operation found" - ) + + # Send the operations back in order + new_pipeline_steps = optimized_config["pipeline"]["steps"] + new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]} + new_ops_in_order = [] + for new_step in new_pipeline_steps: + for op in new_step.get("operations", []): + if op not in new_ops_in_order: + new_ops_in_order.append(new_pipeline_op_name_to_op_map[op]) await websocket.send_json( { @@ -87,7 +126,7 @@ async def run_pipeline(): "data": { "message": "Pipeline executed successfully", "cost": cost, - "optimized_op": optimized_op, + "optimized_ops": new_ops_in_order, }, } ) @@ -108,4 +147,4 @@ async def run_pipeline(): error_traceback = traceback.format_exc() print(f"Error occurred:\n{error_traceback}") - await websocket.send_json({"type": "error", "data": str(e)}) + await websocket.send_json({"type": "error", "data": str(e) + "\n" + 
error_traceback}) diff --git a/website/package-lock.json b/website/package-lock.json index 5420756c..58cc7248 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -22,7 +22,8 @@ "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-menubar": "^1.1.2", - "@radix-ui/react-popover": "^1.1.2", + "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.1.0", "@radix-ui/react-scroll-area": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.1.0", @@ -3039,6 +3040,30 @@ } } }, + "node_modules/@radix-ui/react-progress": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.0.tgz", + "integrity": "sha512-aSzvnYpP725CROcxAOEBVZZSIQVQdHgBr2QQFKySsaD14u8dNT0batuXI+AAGDdAHfXH8rbnHmjYFqVJ21KkRg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-context": "1.1.0", + "@radix-ui/react-primitive": "2.0.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-roving-focus": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.0.tgz", diff --git a/website/package.json b/website/package.json index 91ab0de2..4ca7c52c 100644 --- a/website/package.json +++ b/website/package.json @@ -24,6 +24,7 @@ "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-menubar": "^1.1.2", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-progress": "^1.1.0", "@radix-ui/react-scroll-area": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.1.0", diff --git a/website/src/app/api/utils.ts b/website/src/app/api/utils.ts index 03fa50ae..0671fa92 100644 --- 
a/website/src/app/api/utils.ts +++ b/website/src/app/api/utils.ts @@ -52,6 +52,13 @@ export function generatePipelineConfig( delete newOp.id; delete newOp.llmType; + if ( + op.gleaning && + (op.gleaning.num_rounds === 0 || !op.gleaning.validation_prompt) + ) { + delete newOp.gleaning; + } + if (!op.output || !op.output.schema) return newOp; const processSchemaItem = (item: SchemaItem): string => { diff --git a/website/src/app/localStorageKeys.ts b/website/src/app/localStorageKeys.ts index 786b6c85..5e0ea8d3 100644 --- a/website/src/app/localStorageKeys.ts +++ b/website/src/app/localStorageKeys.ts @@ -13,3 +13,4 @@ export const SAMPLE_SIZE_KEY = "docetl_sampleSize"; export const FILES_KEY = "docetl_files"; export const COST_KEY = "docetl_cost"; export const DEFAULT_MODEL_KEY = "docetl_defaultModel"; +export const OPTIMIZER_MODEL_KEY = "docetl_optimizerModel"; diff --git a/website/src/app/types.ts b/website/src/app/types.ts index 1afab46b..de10eee8 100644 --- a/website/src/app/types.ts +++ b/website/src/app/types.ts @@ -20,6 +20,7 @@ export type Operation = { prompt?: string; output?: { schema: SchemaItem[] }; validate?: string[]; + gleaning?: { num_rounds: number; validation_prompt: string }; otherKwargs?: Record; runIndex?: number; sample?: number; @@ -60,7 +61,7 @@ export interface BookmarkContextType { text: string, source: string, color: string, - notes: UserNote[], + notes: UserNote[] ) => void; removeBookmark: (id: string) => void; } diff --git a/website/src/components/AnsiRenderer.tsx b/website/src/components/AnsiRenderer.tsx index 5151760b..937d240f 100644 --- a/website/src/components/AnsiRenderer.tsx +++ b/website/src/components/AnsiRenderer.tsx @@ -33,20 +33,24 @@ const AnsiRenderer: React.FC = ({ }, [text]); const handleSendMessage = () => { - if (userInput.trim()) { - sendMessage(userInput); - setUserInput(""); - } + sendMessage(userInput.trim()); + setUserInput(""); }; + const isWebSocketClosed = readyState === WebSocket.CLOSED; + return ( -
+
 = ({
             value={userInput}
             onChange={(e) => setUserInput(e.target.value)}
             onKeyPress={(e) => e.key === "Enter" && handleSendMessage()}
-            className="flex-grow bg-gray-800 text-white px-2 py-1 rounded-l"
-            placeholder="Type a message..."
+            className={`flex-grow bg-gray-800 text-white px-2 py-1 rounded-l ${
+              isWebSocketClosed ? "cursor-not-allowed" : ""
+            }`}
+            placeholder={
+              isWebSocketClosed
+                ? "WebSocket disconnected..."
+                : "Type a message..."
+            }
+            disabled={isWebSocketClosed}
           />
           
         
-
+
WebSocket State:{" "} {readyState === WebSocket.CONNECTING ? "Connecting" : readyState === WebSocket.OPEN - ? "Open" - : readyState === WebSocket.CLOSING - ? "Closing" - : readyState === WebSocket.CLOSED - ? "Closed" - : "Unknown"} + ? "Open" + : readyState === WebSocket.CLOSING + ? "Closing" + : readyState === WebSocket.CLOSED + ? "Closed" + : "Unknown"}
diff --git a/website/src/components/OperationCard.tsx b/website/src/components/OperationCard.tsx index 0c953730..39a7b14f 100644 --- a/website/src/components/OperationCard.tsx +++ b/website/src/components/OperationCard.tsx @@ -30,13 +30,14 @@ import { Settings, ListCollapse, Wand2, + ChevronDown, } from "lucide-react"; import { Operation, SchemaItem } from "@/app/types"; import { usePipelineContext } from "@/contexts/PipelineContext"; import { useToast } from "@/hooks/use-toast"; import { Skeleton } from "@/components/ui/skeleton"; import { debounce } from "lodash"; -import { Guardrails } from "./operations/args"; +import { Guardrails, GleaningConfig } from "./operations/args"; import createOperationComponent from "./operations/components"; import { useWebSocket } from "@/contexts/WebSocketContext"; import { Badge } from "./ui/badge"; @@ -54,6 +55,7 @@ const OperationHeader: React.FC<{ llmType: string; disabled: boolean; currOp: boolean; + expanded: boolean; onEdit: (name: string) => void; onDelete: () => void; onRunOperation: () => void; @@ -61,6 +63,7 @@ const OperationHeader: React.FC<{ onShowOutput: () => void; onOptimize: () => void; onAIEdit: (instruction: string) => void; + onToggleExpand: () => void; }> = React.memo( ({ name, @@ -68,6 +71,7 @@ const OperationHeader: React.FC<{ llmType, disabled, currOp, + expanded, onEdit, onDelete, onRunOperation, @@ -75,6 +79,7 @@ const OperationHeader: React.FC<{ onShowOutput, onOptimize, onAIEdit, + onToggleExpand, }) => { const [isEditing, setIsEditing] = useState(false); const [editedName, setEditedName] = useState(name); @@ -93,24 +98,38 @@ const OperationHeader: React.FC<{
{/* Left side buttons */}
- {type} + {type} + {["resolve", "map", "reduce", "filter"].includes(type) && ( + + )} @@ -331,7 +350,13 @@ type Action = | { type: "TOGGLE_GUARDRAILS" } | { type: "TOGGLE_SETTINGS" } | { type: "SET_RUN_INDEX"; payload: number } - | { type: "UPDATE_SETTINGS"; payload: Record }; + | { type: "UPDATE_SETTINGS"; payload: Record } + | { type: "TOGGLE_EXPAND" } + | { + type: "UPDATE_GLEANINGS"; + payload: { num_rounds: number; validation_prompt: string }; + } + | { type: "TOGGLE_GLEANINGS" }; // State type type State = { @@ -340,6 +365,8 @@ type State = { isSchemaExpanded: boolean; isGuardrailsExpanded: boolean; isSettingsOpen: boolean; + isExpanded: boolean; + isGleaningsExpanded: boolean; }; // Reducer function @@ -401,6 +428,17 @@ function operationReducer(state: State, action: Action): State { operation: { ...state.operation, runIndex: action.payload }, } : state; + case "TOGGLE_EXPAND": + return { ...state, isExpanded: !state.isExpanded }; + case "UPDATE_GLEANINGS": + return state.operation + ? 
{ + ...state, + operation: { ...state.operation, gleaning: action.payload }, + } + : state; + case "TOGGLE_GLEANINGS": + return { ...state, isGleaningsExpanded: !state.isGleaningsExpanded }; default: return state; } @@ -413,6 +451,8 @@ const initialState: State = { isSchemaExpanded: false, isGuardrailsExpanded: false, isSettingsOpen: false, + isExpanded: true, + isGleaningsExpanded: false, }; // Main component @@ -424,6 +464,8 @@ export const OperationCard: React.FC<{ index: number }> = ({ index }) => { isSchemaExpanded, isGuardrailsExpanded, isSettingsOpen, + isExpanded, + isGleaningsExpanded, } = state; const { @@ -440,6 +482,7 @@ export const OperationCard: React.FC<{ index: number }> = ({ index }) => { sampleSize, setCost, defaultModel, + optimizerModel, setTerminalOutput, } = usePipelineContext(); const { toast } = useToast(); @@ -611,6 +654,7 @@ export const OperationCard: React.FC<{ index: number }> = ({ index }) => { sendMessage({ yaml_config: filePath, optimize: true, + optimizer_model: optimizerModel, }); } catch (error) { console.error("Error optimizing operation:", error); @@ -754,6 +798,7 @@ export const OperationCard: React.FC<{ index: number }> = ({ index }) => { llmType={operation.llmType} disabled={isLoadingOutputs || pipelineOutput === undefined} currOp={operation.id === pipelineOutput?.operationId} + expanded={isExpanded} onEdit={(name) => { dispatch({ type: "UPDATE_NAME", payload: name }); debouncedUpdate(); @@ -768,27 +813,49 @@ export const OperationCard: React.FC<{ index: number }> = ({ index }) => { onShowOutput={onShowOutput} onOptimize={onOptimize} onAIEdit={handleAIEdit} + onToggleExpand={() => dispatch({ type: "TOGGLE_EXPAND" })} /> - - {createOperationComponent( - operation, - handleOperationUpdate, - isSchemaExpanded, - () => dispatch({ type: "TOGGLE_SCHEMA" }) - )} - - {operation.llmType === "LLM" && ( - - dispatch({ - type: "UPDATE_GUARDRAILS", - payload: newGuardrails, - }) - } - isExpanded={isGuardrailsExpanded} - onToggle={() => 
dispatch({ type: "TOGGLE_GUARDRAILS" })} - /> + {isExpanded && ( + <> + + {createOperationComponent( + operation, + handleOperationUpdate, + isSchemaExpanded, + () => dispatch({ type: "TOGGLE_SCHEMA" }) + )} + + {operation.llmType === "LLM" && ( + <> + + dispatch({ + type: "UPDATE_GUARDRAILS", + payload: newGuardrails, + }) + } + isExpanded={isGuardrailsExpanded} + onToggle={() => dispatch({ type: "TOGGLE_GUARDRAILS" })} + /> + + )} + {(operation.type === "map" || + operation.type === "reduce" || + operation.type === "filter") && ( + + dispatch({ + type: "UPDATE_GLEANINGS", + payload: newGleanings, + }) + } + isExpanded={isGleaningsExpanded} + onToggle={() => dispatch({ type: "TOGGLE_GLEANINGS" })} + /> + )} + )} { - const { terminalOutput, setTerminalOutput } = usePipelineContext(); + const { terminalOutput, setTerminalOutput, optimizerProgress } = + usePipelineContext(); const { readyState } = useWebSocket(); return (
+ {optimizerProgress && ( +
+
+
+ {optimizerProgress.status} +
+
+ {Math.round(optimizerProgress.progress * 100)}% +
+
+ + {optimizerProgress.shouldOptimize && ( +
+
+
+ Optimizing because +
+
+ {optimizerProgress.rationale} +
+
+ + {optimizerProgress.validatorPrompt && ( +
+
+ Using this prompt to find the best plan +
+
+ {optimizerProgress.validatorPrompt} +
+
+ )} +
+ )} +
+ )} { useEffect(() => { const foundOperation = operations.find( - (op: Operation) => op.id === output?.operationId, + (op: Operation) => op.id === output?.operationId ); setOperation(foundOperation); setOpName(foundOperation?.name); setIsResolveOrReduce( - foundOperation?.type === "resolve" || foundOperation?.type === "reduce", + foundOperation?.type === "resolve" || foundOperation?.type === "reduce" ); }, [operations, output]); @@ -76,7 +113,7 @@ export const Output: React.FC = () => { try { // Fetch output data const outputResponse = await fetch( - `/api/readFile?path=${output.path}`, + `/api/readFile?path=${output.path}` ); if (!outputResponse.ok) { throw new Error("Failed to fetch output file"); @@ -121,7 +158,7 @@ export const Output: React.FC = () => { // Fetch input data if inputPath exists if (output.inputPath) { const inputResponse = await fetch( - `/api/readFile?path=${output.inputPath}`, + `/api/readFile?path=${output.inputPath}` ); if (!inputResponse.ok) { throw new Error("Failed to fetch input file"); @@ -129,7 +166,7 @@ export const Output: React.FC = () => { const inputContent = await inputResponse.text(); const parsedInputs = JSON.parse(inputContent); setInputCount( - Array.isArray(parsedInputs) ? parsedInputs.length : 1, + Array.isArray(parsedInputs) ? parsedInputs.length : 1 ); } else { setInputCount(0); @@ -205,8 +242,8 @@ export const Output: React.FC = () => { return outputs.length > 0 && reduceColumnName in outputs[0] ? { name: reduceColumnName, type: "reduce" } : outputs.length > 0 && resolveColumnName in outputs[0] - ? { name: resolveColumnName, type: "resolve" } - : null; + ? { name: resolveColumnName, type: "resolve" } + : null; }, [outputs, opName, operation]); if (!visualizationColumn || !operation) { @@ -225,7 +262,7 @@ export const Output: React.FC = () => { .sort( (a, b) => Number(b[visualizationColumn.name]) - - Number(a[visualizationColumn.name]), + Number(a[visualizationColumn.name]) ) .map((row, index) => (
@@ -252,7 +289,7 @@ export const Output: React.FC = () => { outputs.flatMap((row) => { const kvPairs = row[visualizationColumn.name]; return Object.keys(kvPairs).filter((key) => key in row); - }), + }) ); const groupedByIntersection: { [key: string]: any[] } = {}; diff --git a/website/src/components/PipelineGui.tsx b/website/src/components/PipelineGui.tsx index 4e590fe3..b8332cf9 100644 --- a/website/src/components/PipelineGui.tsx +++ b/website/src/components/PipelineGui.tsx @@ -23,6 +23,8 @@ import { Download, FileUp, Save, + Loader2, + StopCircle, } from "lucide-react"; import { usePipelineContext } from "@/contexts/PipelineContext"; import { @@ -78,9 +80,14 @@ const PipelineGUI: React.FC = () => { setTerminalOutput, saveProgress, clearPipelineState, + optimizerModel, + setOptimizerModel, + optimizerProgress, + setOptimizerProgress, } = usePipelineContext(); const [isSettingsOpen, setIsSettingsOpen] = useState(false); const [tempPipelineName, setTempPipelineName] = useState(pipelineName); + const [tempOptimizerModel, setTempOptimizerModel] = useState(defaultModel); const [tempSampleSize, setTempSampleSize] = useState( sampleSize?.toString() || "" ); @@ -91,55 +98,73 @@ const PipelineGUI: React.FC = () => { const { toast } = useToast(); const { connect, sendMessage, lastMessage, readyState, disconnect } = useWebSocket(); + const [runningButtonType, setRunningButtonType] = useState< + "run" | "clear-run" | null + >(null); useEffect(() => { if (lastMessage) { if (lastMessage.type === "output") { setTerminalOutput(lastMessage.data); + } else if (lastMessage.type === "optimizer_progress") { + setOptimizerProgress({ + status: lastMessage.status, + progress: lastMessage.progress, + shouldOptimize: lastMessage.should_optimize, + rationale: lastMessage.rationale, + validatorPrompt: lastMessage.validator_prompt, + }); } else if (lastMessage.type === "result") { const runCost = lastMessage.data.cost || 0; + setOptimizerProgress(null); // See if there was an optimized 
operation - const optimizedOp = lastMessage.data.optimized_op; - if (optimizedOp) { - const { - id, - llmType, - type, - name, - prompt, - output, - validate, - sample, - ...otherKwargs - } = optimizedOp; - const convertedOp = { - id: id || crypto.randomUUID(), - llmType: - type === "map" || - type === "reduce" || - type === "resolve" || - type === "filter" || - type === "parallel_map" - ? "LLM" - : "non-LLM", - type: type, - name: name || "Untitled Operation", - prompt: prompt, - output: output - ? { - schema: schemaDictToItemSet(output.schema), - } - : undefined, - validate: validate, - sample: sample, - otherKwargs: otherKwargs || {}, - }; - setOperations((prev) => - prev.map((op) => - op.name === optimizedOp.name ? (convertedOp as Operation) : op - ) - ); + const optimizedOps = lastMessage.data.optimized_ops; + if (optimizedOps) { + const newOperations = optimizedOps.map((optimizedOp) => { + const { + id, + llmType, + type, + name, + prompt, + output, + validate, + gleaning, + sample, + ...otherKwargs + } = optimizedOp; + + // Find matching operation in previous operations list + const existingOp = operations.find((op) => op.name === name); + + return { + id: id || crypto.randomUUID(), + llmType: + type === "map" || + type === "reduce" || + type === "resolve" || + type === "filter" || + type === "parallel_map" + ? "LLM" + : "non-LLM", + type: type, + name: name || "Untitled Operation", + prompt: prompt, + output: output + ? 
{ + schema: schemaDictToItemSet(output.schema), + } + : undefined, + validate: validate, + gleaning: gleaning, + sample: sample, + otherKwargs: otherKwargs || {}, + ...(existingOp?.runIndex && { runIndex: existingOp.runIndex }), + } as Operation; + }); + + setOperations(newOperations); } setCost((prevCost) => prevCost + runCost); @@ -186,6 +211,12 @@ const PipelineGUI: React.FC = () => { } }, [currentFile]); + useEffect(() => { + if (optimizerModel) { + setTempDefaultModel(tempOptimizerModel); + } + }, [optimizerModel]); + const handleFileUpload = async ( event: React.ChangeEvent ) => { @@ -353,7 +384,9 @@ const PipelineGUI: React.FC = () => { if (lastOpIndex < 0) return; const lastOperation = operations[lastOpIndex]; + setOptimizerProgress(null); setIsLoadingOutputs(true); + setRunningButtonType(clear_intermediate ? "clear-run" : "run"); setNumOpRun((prevNum) => { const newNum = prevNum + operations.length; const updatedOperations = operations.map((op, index) => ({ @@ -412,6 +445,7 @@ const PipelineGUI: React.FC = () => { // Close the WebSocket connection disconnect(); setIsLoadingOutputs(false); + setRunningButtonType(null); } }, [ @@ -453,6 +487,7 @@ const PipelineGUI: React.FC = () => { setCurrentFile(tempCurrentFile); setDefaultModel(tempDefaultModel); setIsSettingsOpen(false); + setOptimizerModel(tempOptimizerModel); }; const handleDragEnd = (result: DropResult) => { @@ -473,297 +508,319 @@ const PipelineGUI: React.FC = () => { } }; + const handleStop = () => { + sendMessage("kill"); + setRunningButtonType(null); + }; + return ( -
-
-
-
-
-

- {pipelineName.toUpperCase()} -

- {sampleSize && ( - - - -
- - - {sampleSize} samples - -
-
- -

- Pipeline will run on a sample of {sampleSize} random - documents. -

-
-
-
- )} -
- - - - - - -

Initialize from config file

-
-
-
- - - - - - - -

Download pipeline config file

-
-
-
- - -
-
-
- - - - - - LLM Operations - - handleAddOperation("LLM", "map", "Untitled Map") - } - > - Map - - - handleAddOperation("LLM", "reduce", "Untitled Reduce") - } - > - Reduce - - - handleAddOperation("LLM", "resolve", "Untitled Resolve") - } - > - Resolve - - - handleAddOperation("LLM", "filter", "Untitled Filter") - } - > - Filter - - - handleAddOperation( - "LLM", - "parallel_map", - "Untitled Parallel Map" - ) - } - > - Parallel Map - - - Non-LLM Operations - - handleAddOperation("non-LLM", "unnest", "Untitled Unnest") - } - > - Unnest - - - handleAddOperation("non-LLM", "split", "Untitled Split") - } - > - Split - - - handleAddOperation("non-LLM", "gather", "Untitled Gather") - } - > - Gather - - - handleAddOperation("non-LLM", "sample", "Untitled Sample") - } - > - Sample - - - -
- - - - - - -

The cache will be cleared before running

-
-
-
- - - - - - -

This will use any cached outputs if applicable

-
-
-
-
+
+
+
+
+

+ {pipelineName.toUpperCase()} +

+ {sampleSize && ( + + + +
+ + + {sampleSize} samples + +
+
+ +

+ Pipeline will run on a sample of {sampleSize} random + documents. +

+
+
+
+ )} +
+ + + + + + +

Initialize from config file

+
+
+
+ + + + + + + +

Download pipeline config file

+
+
+
+ +
-
-
- - - {(provided, snapshot) => ( -
+ + + + + + LLM Operations + + handleAddOperation("LLM", "map", "Untitled Map") + } > - {operations.map((op, index) => ( - - ))} - {provided.placeholder} -
- )} -
-
-
- - - - Pipeline Settings - -
-
- - setTempPipelineName(e.target.value)} - className="col-span-3" - /> -
-
- - setTempSampleSize(e.target.value)} - placeholder="None" - className="col-span-3" - /> -
-
- - -
-
- - setTempDefaultModel(e.target.value)} - className="col-span-3" - /> -
+ Parallel Map + + + Non-LLM Operations + + handleAddOperation("non-LLM", "unnest", "Untitled Unnest") + } + > + Unnest + + + handleAddOperation("non-LLM", "split", "Untitled Split") + } + > + Split + + + handleAddOperation("non-LLM", "gather", "Untitled Gather") + } + > + Gather + + + handleAddOperation("non-LLM", "sample", "Untitled Sample") + } + > + Sample + + + +
+ + +
- - - - -
+
+
+
+
+ + + {(provided, snapshot) => ( +
+ {operations.map((op, index) => ( + + ))} + {provided.placeholder} +
+ )} +
+
+ + + + Pipeline Settings + +
+
+ + setTempPipelineName(e.target.value)} + className="col-span-3" + /> +
+
+ + setTempSampleSize(e.target.value)} + placeholder="None" + className="col-span-3" + /> +
+
+ + +
+
+ + setTempDefaultModel(e.target.value)} + className="col-span-3" + /> +
+
+ + +
+
+ + + +
+
); }; diff --git a/website/src/components/operations/args.tsx b/website/src/components/operations/args.tsx index 76f99bd8..abb8d9dc 100644 --- a/website/src/components/operations/args.tsx +++ b/website/src/components/operations/args.tsx @@ -17,21 +17,46 @@ import { TooltipProvider, TooltipTrigger, } from "../ui/tooltip"; +import { Switch } from "../ui/switch"; +import { Label } from "../ui/label"; -export const PromptInput: React.FC<{ +interface PromptInputProps { prompt: string; onChange: (value: string) => void; -}> = React.memo(({ prompt, onChange }) => { - return ( -