diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 09144e77..cc6a8209 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -167,31 +167,33 @@ def validation_fn(response: Dict[str, Any]): return output, True return output, False - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], + response = self.runner.api.call_llm( model=self.config.get("model", self.default_model), - operation_type="cluster", - schema=self.config["summary_schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "cluster", - messages, - self.config["summary_schema"], - tools=self.config.get("tools", None), - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), + op_type="cluster", + messages=[{"role": "user", "content": prompt}], + output_schema=self.config["summary_schema"], + timeout_seconds=self.config.get("timeout", 120), + bypass_cache=self.config.get("bypass_cache", False), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": validation_fn, + } + if self.config.get("validate", None) + else None ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, + verbose=self.config.get("verbose", False), ) - total_cost += cost - - t.update(output) + total_cost += response.total_cost + if response.validated: + output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["summary_schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + t.update(output) return total_cost return 0 diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index e56a7b39..d6cf6469 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -82,9 +82,12 @@ def compare_pair( {"is_match": "bool"}, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, + bypass_cache=self.config.get("bypass_cache", False), ) - output = self.runner.api.parse_llm_response(response, {"is_match": "bool"})[0] - return output["is_match"], completion_cost(response) + output = self.runner.api.parse_llm_response( + response.response, {"is_match": "bool"} + )[0] + return output["is_match"], response.total_cost def syntax_check(self) -> None: """ diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index d67042e9..48037eef 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -5,13 +5,13 @@ from jinja2 import Template -from docetl.operations.base import BaseOperation +from docetl.operations.map import MapOperation from docetl.operations.utils import ( RichLoopBar, ) -class FilterOperation(BaseOperation): +class FilterOperation(MapOperation): def syntax_check(self) -> None: """ Checks the configuration of the FilterOperation for required keys and valid structure. 
@@ -110,77 +110,9 @@ def execute( ) ) - if self.status: - self.status.start() - - def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]: - prompt_template = Template(self.config["prompt"]) - prompt = prompt_template.render(input=item) - - def validation_fn(response: Dict[str, Any]): - output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - for key, value in item.items(): - if key not in self.config["output"]["schema"]: - output[key] = value - if self.runner.api.validate_output(self.config, output, self.console): - return output, True - return output, False - - output, cost, is_valid = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="filter", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "filter", - messages, - self.config["output"]["schema"], - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) + results, total_cost = super().execute(input_data) - if is_valid: - return output, cost - - return None, cost - - with ThreadPoolExecutor(max_workers=self.max_threads) as executor: - futures = [ - executor.submit(_process_filter_item, item) for item in input_data - ] - results = [] - total_cost = 0 - pbar = RichLoopBar( - range(len(futures)), - desc=f"Processing {self.config['name']} (filter) on all documents", - console=self.console, - ) - for i in pbar: - future = futures[i] - result, item_cost = future.result() - total_cost += item_cost - if result is not None: - if is_build: - results.append(result) - else: - if result.get(filter_key, False): - results.append(result) - pbar.update(1) - - if self.status: - self.status.start() + # Drop records with filter_key values that are False + results = [result for result in results if result[filter_key]] return results, total_cost diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 3f077713..300e8419 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -153,59 +153,42 @@ def validation_fn(response: Dict[str, Any]): return output, False self.runner.rate_limiter.try_acquire("call", weight=1) - if "gleaning" in self.config: - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", "content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="map", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm_with_gleaning( - self.config.get("model", self.default_model), - "map", - messages, - self.config["output"]["schema"], - self.config["gleaning"]["validation_prompt"], - self.config["gleaning"]["num_rounds"], - self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - verbose=self.config.get("verbose", False), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) - else: - output, cost, success = self.runner.api.call_llm_with_validation( - [{"role": "user", 
"content": prompt}], - model=self.config.get("model", self.default_model), - operation_type="map", - schema=self.config["output"]["schema"], - llm_call_fn=lambda messages: self.runner.api.call_llm( - self.config.get("model", self.default_model), - "map", - messages, - self.config["output"]["schema"], - tools=self.config.get("tools", None), - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get( - "max_retries_per_timeout", 2 - ), - ), - validation_fn=validation_fn, - val_rule=self.config.get("validate", []), - num_retries=self.num_retries_on_validate_failure, - console=self.console, - ) + llm_result = self.runner.api.call_llm( + self.config.get("model", self.default_model), + "map", + [{"role": "user", "content": prompt}], + self.config["output"]["schema"], + tools=self.config.get("tools", None), + scratchpad=None, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": validation_fn, + } + if self.config.get("validate", None) + else None + ), + gleaning_config=self.config.get("gleaning", None), + verbose=self.config.get("verbose", False), + bypass_cache=self.config.get("bypass_cache", False), + ) - if success: - return output, cost + if llm_result.validated: + # Parse the response + output = self.runner.api.parse_llm_response( + llm_result.response, + schema=self.config["output"]["schema"], + tools=self.config.get("tools", None), + manually_fix_errors=self.manually_fix_errors, + )[0] + # Augment the output with the original item + output = {**item, **output} + return output, llm_result.total_cost - return None, cost + return None, llm_result.total_cost with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor: futures = [executor.submit(_process_map_item, item) for item in input_data] @@ -375,17 +358,17 @@ def process_prompt(item, prompt_config): [{"role": "user", "content": prompt}], local_output_schema, tools=prompt_config.get("tools", None), - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + bypass_cache=self.config.get("bypass_cache", False), ) output = self.runner.api.parse_llm_response( - response, + response.response, schema=local_output_schema, tools=prompt_config.get("tools", None), manually_fix_errors=self.manually_fix_errors, )[0] - return output, completion_cost(response) + return output, response.total_cost with ThreadPoolExecutor(max_workers=self.max_threads) as executor: if "prompts" in self.config: diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index b4091865..682b5d70 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -12,7 +12,7 @@ from collections import deque from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import jinja2 import numpy as np @@ -404,7 +404,7 @@ def _cluster_based_sampling( return group_list, 0 clusters, cost = cluster_documents( - group_list, value_sampling, sample_size, self.api + group_list, value_sampling, sample_size, self.runner.api ) sampled_items = [] @@ -444,7 +444,7 @@ def _semantic_similarity_sampling( ) embeddings, cost = get_embeddings_for_clustering( - group_list, 
value_sampling, self.api + group_list, value_sampling, self.runner.api ) query_response = self.runner.api.gen_embedding(embedding_model, [query_text]) @@ -684,6 +684,15 @@ def _incremental_reduce( return current_output, total_cost + def validation_fn(self, response: Dict[str, Any]): + output = self.runner.api.parse_llm_response( + response, + schema=self.config["output"]["schema"], + )[0] + if self.runner.api.validate_output(self.config, output, self.console): + return output, True + return output, False + def _increment_fold( self, key: Tuple, @@ -715,29 +724,43 @@ def _increment_fold( output=current_output, reduce_key=dict(zip(self.config["reduce_key"], key)), ) + response = self.runner.api.call_llm( self.config.get("model", self.default_model), "reduce", [{"role": "user", "content": fold_prompt}], self.config["output"]["schema"], scratchpad=scratchpad, - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + bypass_cache=self.config.get("bypass_cache", False), + verbose=self.config.get("verbose", False), ) - folded_output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - folded_output.update(dict(zip(self.config["reduce_key"], key))) - fold_cost = completion_cost(response) end_time = time.time() self._update_fold_time(end_time - start_time) - if self.runner.api.validate_output(self.config, folded_output, self.console): + if response.validated: + folded_output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + + folded_output.update(dict(zip(self.config["reduce_key"], key))) + fold_cost = response.total_cost + return folded_output, fold_cost + return None, fold_cost def _merge_results( @@ -766,20 +789,34 @@ def _merge_results( "merge", [{"role": "user", "content": merge_prompt}], self.config["output"]["schema"], - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + bypass_cache=self.config.get("bypass_cache", False), + verbose=self.config.get("verbose", False), ) - merged_output = self.runner.api.parse_llm_response( - response, self.config["output"]["schema"] - )[0] - merged_output.update(dict(zip(self.config["reduce_key"], key))) - merge_cost = completion_cost(response) + end_time = time.time() self._update_merge_time(end_time - start_time) - if self.runner.api.validate_output(self.config, merged_output, self.console): + if response.validated: + merged_output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + merged_output.update(dict(zip(self.config["reduce_key"], key))) + merge_cost = response.total_cost return merged_output, merge_cost + return None, merge_cost def get_fold_time(self) -> Tuple[float, bool]: @@ -854,41 +891,37 @@ def _batch_reduce( ) item_cost = 
0 - if "gleaning" in self.config: - response, gleaning_cost = self.runner.api.call_llm_with_gleaning( - self.config.get("model", self.default_model), - "reduce", - [{"role": "user", "content": prompt}], - self.config["output"]["schema"], - self.config["gleaning"]["validation_prompt"], - self.config["gleaning"]["num_rounds"], - console=self.console, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), - verbose=self.config.get("verbose", False), - ) - item_cost += gleaning_cost - else: - response = self.runner.api.call_llm( - self.config.get("model", self.default_model), - "reduce", - [{"role": "user", "content": prompt}], - self.config["output"]["schema"], - console=self.console, - scratchpad=scratchpad, - timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), - ) + response = self.runner.api.call_llm( + self.config.get("model", self.default_model), + "reduce", + [{"role": "user", "content": prompt}], + self.config["output"]["schema"], + scratchpad=scratchpad, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + bypass_cache=self.config.get("bypass_cache", False), + validation_config=( + { + "num_retries": self.num_retries_on_validate_failure, + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), + gleaning_config=self.config.get("gleaning", None), + verbose=self.config.get("verbose", False), + ) - item_cost += completion_cost(response) + item_cost += response.total_cost - output = self.runner.api.parse_llm_response( - response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - output.update(dict(zip(self.config["reduce_key"], key))) + if response.validated: + output = self.runner.api.parse_llm_response( + response.response, + schema=self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] + output.update(dict(zip(self.config["reduce_key"], key))) - if self.runner.api.validate_output(self.config, output, self.console): return output, item_cost return None, item_cost diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index e5f58121..0184f8b2 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -56,12 +56,13 @@ def compare_pair( {"is_match": "bool"}, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, + bypass_cache=self.config.get("bypass_cache", False), ) output = self.runner.api.parse_llm_response( - response, + response.response, {"is_match": "bool"}, )[0] - return output["is_match"], completion_cost(response) + return output["is_match"], response.total_cost def syntax_check(self) -> None: """ @@ -169,6 +170,15 @@ def syntax_check(self) -> None: if self.config["limit_comparisons"] <= 0: raise ValueError("'limit_comparisons' must be a positive integer") + def validation_fn(self, response: Dict[str, Any]): + output = self.runner.api.parse_llm_response( + response, + schema=self.config["output"]["schema"], + )[0] + if self.runner.api.validate_output(self.config, output, self.console): + return output, True + return output, False + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ Executes the resolve operation on the provided dataset. 
@@ -401,22 +411,28 @@ def process_cluster(cluster): "reduce", [{"role": "user", "content": resolution_prompt}], self.config["output"]["schema"], - console=self.console, timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get( "max_retries_per_timeout", 2 ), + bypass_cache=self.config.get("bypass_cache", False), + validation_config=( + { + "val_rule": self.config.get("validate", []), + "validation_fn": self.validation_fn, + } + if self.config.get("validate", None) + else None + ), ) - reduction_output = self.runner.api.parse_llm_response( - reduction_response, - self.config["output"]["schema"], - manually_fix_errors=self.manually_fix_errors, - )[0] - reduction_cost = completion_cost(reduction_response) - - if self.runner.api.validate_output( - self.config, reduction_output, self.console - ): + reduction_cost = reduction_response.total_cost + + if reduction_response.validated: + reduction_output = self.runner.api.parse_llm_response( + reduction_response.response, + self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] return ( [ { diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py index ab4c361c..2aceaf5e 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils.py @@ -20,6 +20,7 @@ from rich.console import Console from rich.prompt import Prompt from tqdm import tqdm +from pydantic import BaseModel from docetl.utils import completion_cost, count_tokens import time @@ -36,6 +37,12 @@ cache.close() +class LLMResult(BaseModel): + response: Any + total_cost: float + validated: bool + + def freezeargs(func): """ Decorator to convert mutable dictionary arguments into immutable. @@ -412,8 +419,6 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]: return result - # TODO: optimize this - @freezeargs def _cached_call_llm( self, cache_key: str, @@ -423,7 +428,11 @@ def _cached_call_llm( output_schema: Dict[str, str], tools: Optional[str] = None, scratchpad: Optional[str] = None, - ) -> str: + validation_config: Optional[Dict[str, Any]] = None, + gleaning_config: Optional[Dict[str, Any]] = None, + verbose: bool = False, + bypass_cache: bool = False, + ) -> LLMResult: """ Cached version of the call_llm function. @@ -439,80 +448,163 @@ def _cached_call_llm( output_schema (Dict[str, str]): The output schema dictionary. tools (Optional[str]): The tools to pass to the LLM. scratchpad (Optional[str]): The scratchpad to use for the operation. + validation_config (Optional[Dict[str, Any]]): The validation configuration. + gleaning_config (Optional[Dict[str, Any]]): The gleaning configuration. + verbose (bool): Whether to print verbose output. + bypass_cache (bool): Whether to bypass the cache. Returns: - str: The result from _call_llm_with_cache. + LLMResult: The response from _call_llm_with_cache. 
""" + total_cost = 0.0 + validated = False with cache as c: - result = c.get(cache_key) - if result is None: - result = self._call_llm_with_cache( + response = c.get(cache_key) + if response is not None and not bypass_cache: + validated = True + else: + response = self._call_llm_with_cache( model, op_type, messages, output_schema, tools, scratchpad ) - # Only set the cache if the result tool calls or output is not empty - if ( - result - and "tool_calls" in dir(result.choices[0].message) - and result.choices[0].message.tool_calls - ): - c.set(cache_key, result) + total_cost += completion_cost(response) + + if gleaning_config: + # Retry gleaning prompt + regular LLM + num_gleaning_rounds = gleaning_config.get("num_rounds", 2) + validator_prompt_template = Template(gleaning_config["prompt"]) + + parsed_output = self.parse_llm_response( + response, output_schema, tools + )[0] + + validator_messages = ( + [ + { + "role": "system", + "content": f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation.", + } + ] + + messages + + [ + {"role": "assistant", "content": json.dumps(parsed_output)}, + ] + ) - return result + for rnd in range(num_gleaning_rounds): + # Prepare validator prompt + validator_prompt = validator_prompt_template.render( + output=parsed_output + ) + self.runner.rate_limiter.try_acquire("llm_call", weight=1) + + validator_response = completion( + model=gleaning_config.get("model", model), + messages=truncate_messages( + validator_messages + + [{"role": "user", "content": validator_prompt}], + model, + ), + response_format={ + "type": "json_schema", + "json_schema": { + "name": "response", + "strict": True, + "schema": { + "type": "object", + "properties": { + "should_refine": {"type": "boolean"}, + "improvements": {"type": "string"}, + }, + "required": ["should_refine", "improvements"], + "additionalProperties": False, + }, + }, + }, + ) + total_cost += completion_cost(validator_response) - def call_llm_with_validation( - self, - messages: List[str], - model: str, - operation_type: str, - schema: Dict[str, str], - llm_call_fn: Callable, - validation_fn: Callable, - val_rule: str, - num_retries: int, - console: Console, - scratchpad: Optional[str] = None, - ) -> Tuple[Any, float, bool]: - num_tries = num_retries + 1 - cost = 0.0 + # Parse the validator response + suggestion = json.loads( + validator_response.choices[0].message.content + ) + if not suggestion["should_refine"]: + break - key = cache_key(model, operation_type, messages, schema, scratchpad) + if verbose: + self.runner.console.log( + f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}" + ) - for i in range(num_tries): - response = llm_call_fn(messages) - if isinstance(response, tuple): - response, curr_cost = response - cost += curr_cost + # Prompt for improvement + improvement_prompt = f"""Based on the validation feedback: - cost += completion_cost(response) + ``` + {suggestion['improvements']} + ``` - parsed_output, result = validation_fn(response) + Please improve your previous response. 
Ensure that the output adheres to the required schema and addresses any issues raised in the validation.""" + messages.append({"role": "user", "content": improvement_prompt}) - if result: - return parsed_output, cost, True + # Call LLM again + response = self._call_llm_with_cache( + model, op_type, messages, output_schema, tools, scratchpad + ) + total_cost += completion_cost(response) + + validated = True + + # If there's validation, handle it here + elif validation_config: + num_tries = validation_config.get("num_retries", 2) + validation_fn = validation_config.get("validation_fn") + val_rule = validation_config.get("val_rule") + + # Try validation + i = 0 + validation_result = False + while not validation_result and i < num_tries: + parsed_output, validation_result = validation_fn(response) + if validation_result: + validated = True + break + + # Append the validation result to messages + messages.append( + { + "role": "assistant", + "content": json.dumps(parsed_output), + } + ) + messages.append( + { + "role": "user", + "content": f"Your output {parsed_output} failed my validation rule: {str(val_rule)}\n\nPlease try again.", + } + ) + self.runner.console.log( + f"[bold red]Validation failed:[/bold red] {val_rule}\n" + f"\t[yellow]Output:[/yellow] {parsed_output}\n" + f"\t({i + 1}/{num_tries})" + ) + i += 1 - # Remove from cache - with cache as c: - c.delete(key) + response = self._call_llm_with_cache( + model, op_type, messages, output_schema, tools, scratchpad + ) + total_cost += completion_cost(response) - # Append the validation result to messages - messages.append( - { - "role": "assistant", - "content": json.dumps(parsed_output), - } - ) - messages.append( - { - "role": "user", - "content": f"Your output {parsed_output} failed my validation rule: {str(val_rule)}\n\nPlease try again.", - } - ) - console.log( - f"[bold red]Validation failed:[/bold red] {val_rule}\n" - f"\t[yellow]Output:[/yellow] {parsed_output}\n" - f"\tTrying again... ({i + 1}/{num_tries})" - ) + else: + # No validation, so we assume the result is valid + validated = True - return parsed_output, cost, False + # Only set the cache if the result tool calls or output is not empty + if ( + response + and "tool_calls" in dir(response.choices[0].message) + and response.choices[0].message.tool_calls + ): + c.set(cache_key, response) + + return LLMResult(response=response, total_cost=total_cost, validated=validated) def call_llm( self, @@ -522,10 +614,13 @@ def call_llm( output_schema: Dict[str, str], tools: Optional[List[Dict[str, str]]] = None, scratchpad: Optional[str] = None, - console: Console = Console(), timeout_seconds: int = 120, max_retries_per_timeout: int = 2, - ) -> Any: + validation_config: Optional[Dict[str, Any]] = None, + gleaning_config: Optional[Dict[str, Any]] = None, + verbose: bool = False, + bypass_cache: bool = False, + ) -> LLMResult: """ Wrapper function that uses caching for LLM calls. @@ -541,8 +636,9 @@ def call_llm( scratchpad (Optional[str]): The scratchpad to use for the operation. timeout_seconds (int): The timeout for the LLM call. max_retries_per_timeout (int): The maximum number of retries per timeout. + bypass_cache (bool): Whether to bypass the cache. Returns: - str: The result from the cached LLM call. + LLMResult: The result from the cached LLM call. Raises: TimeoutError: If the call times out after retrying. 
@@ -562,6 +658,10 @@ def call_llm( output_schema, json.dumps(tools) if tools else None, scratchpad, + validation_config=validation_config, + gleaning_config=gleaning_config, + verbose=verbose, + bypass_cache=bypass_cache, ) except RateLimitError: # TODO: this is a really hacky way to handle rate limits @@ -569,18 +669,18 @@ def call_llm( backoff_time = 4 * (2**rate_limited_attempt) # Exponential backoff max_backoff = 120 # Maximum backoff time of 60 seconds sleep_time = min(backoff_time, max_backoff) - console.log( + self.runner.console.log( f"[yellow]Rate limit hit. Retrying in {sleep_time:.2f} seconds...[/yellow]" ) time.sleep(sleep_time) rate_limited_attempt += 1 except TimeoutError: if attempt == max_retries: - console.log( + self.runner.console.log( f"[bold red]LLM call timed out after {max_retries + 1} attempts[/bold red]" ) # TODO: HITL - return {} + return LLMResult(response=None, total_cost=0.0, validated=False) attempt += 1 def _call_llm_with_cache( @@ -591,7 +691,7 @@ def _call_llm_with_cache( output_schema: Dict[str, str], tools: Optional[str] = None, scratchpad: Optional[str] = None, - ) -> str: + ) -> Any: """ Make an LLM call with caching. @@ -680,7 +780,6 @@ def _call_llm_with_cache( Update the 'updated_scratchpad' field in your output with the new scratchpad content. Remember: The scratchpad should contain information necessary for processing future batches, not the final result.""" - messages = json.loads(messages) # Truncate messages if they exceed the model's context length messages = truncate_messages(messages, model) @@ -713,173 +812,6 @@ def _call_llm_with_cache( return response - def call_llm_with_gleaning( - self, - model: str, - op_type: str, - messages: List[Dict[str, str]], - output_schema: Dict[str, str], - validator_prompt_template: str, - num_gleaning_rounds: int, - console: Console = Console(), - timeout_seconds: int = 120, - max_retries_per_timeout: int = 2, - verbose: bool = False, - ) -> Tuple[str, float]: - """ - Call LLM with a gleaning process, including validation and improvement rounds. - - This function performs an initial LLM call, followed by multiple rounds of - validation and improvement based on the validator prompt template. - - Args: - model (str): The model name. - op_type (str): The operation type. - messages (List[Dict[str, str]]): The messages to send to the LLM. - output_schema (Dict[str, str]): The output schema dictionary. - validator_prompt_template (str): Template for the validator prompt. - num_gleaning_rounds (int): Number of gleaning rounds to perform. - timeout_seconds (int): The timeout for the LLM call. - Returns: - Tuple[str, float]: A tuple containing the final LLM response and the total cost. - """ - if not litellm.supports_function_calling(model): - raise ValueError( - f"Model {model} does not support function calling (which we use for structured outputs). Please use a different model." 
- ) - - props = {key: convert_val(value) for key, value in output_schema.items()} - - parameters = {"type": "object", "properties": props} - parameters["required"] = list(props.keys()) - parameters["additionalProperties"] = False - - # Initial LLM call - response = self.call_llm( - model, - op_type, - messages, - output_schema, - console=console, - timeout_seconds=timeout_seconds, - max_retries_per_timeout=max_retries_per_timeout, - ) - - cost = 0.0 - - # Parse the response - parsed_response = self.parse_llm_response(response, output_schema) - output = parsed_response[0] - - messages = ( - [ - { - "role": "system", - "content": f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation.", - } - ] - + messages - + [ - {"role": "assistant", "content": json.dumps(output)}, - ] - ) - - for rnd in range(num_gleaning_rounds): - cost += completion_cost(response) - - # Prepare validator prompt - validator_template = Template(validator_prompt_template) - validator_prompt = validator_template.render(output=output) - - # Call LLM for validation - self.runner.rate_limiter.try_acquire("llm_call", weight=1) - validator_response = completion( - model=model, - messages=truncate_messages( - messages + [{"role": "user", "content": validator_prompt}], model - ), - response_format={ - "type": "json_schema", - "json_schema": { - "name": "response", - "strict": True, - "schema": { - "type": "object", - "properties": { - "should_refine": {"type": "boolean"}, - "improvements": {"type": "string"}, - }, - "required": ["should_refine", "improvements"], - "additionalProperties": False, - }, - }, - }, - ) - cost += completion_cost(validator_response) - - # Parse the validator response - suggestion = json.loads(validator_response.choices[0].message.content) - if not suggestion["should_refine"]: - break - - if verbose: - console.log( - f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}" - ) - - # Prompt for improvement - improvement_prompt = f"""Based on the validation feedback: - - ``` - {suggestion['improvements']} - ``` - - Please improve your previous response. 
Ensure that the output adheres to the required schema and addresses any issues raised in the validation.""" - messages.append({"role": "user", "content": improvement_prompt}) - - # Call LLM for improvement - # TODO: support gleaning and tools - self.runner.rate_limiter.try_acquire("llm_call", weight=1) - response = completion( - model=model, - messages=truncate_messages(messages, model), - # response_format={ - # "type": "json_schema", - # "json_schema": { - # "name": "write_output", - # "description": "Write processing output to a database", - # "strict": True, - # "schema": parameters, - # # "additionalProperties": False, - # }, - # }, - tools=[ - { - "type": "function", - "function": { - "name": "send_output", - "description": "Send output back to the user", - "strict": True, - "parameters": parameters, - "additionalProperties": False, - }, - } - ], - tool_choice={"type": "function", "function": {"name": "send_output"}}, - ) - - # Update messages with the new response - messages.append( - { - "role": "assistant", - "content": json.dumps( - self.parse_llm_response(response, output_schema)[0] - ), - } - ) - - return response, cost - def parse_llm_response( self, response: Any, diff --git a/tests/basic/test_basic_filter_split_gather.py b/tests/basic/test_basic_filter_split_gather.py index dcb65694..1c4a3ca0 100644 --- a/tests/basic/test_basic_filter_split_gather.py +++ b/tests/basic/test_basic_filter_split_gather.py @@ -42,7 +42,6 @@ def test_filter_operation( assert len(results) < len(filter_sample_data) assert all(len(result["text"].split()) > 3 for result in results) - assert cost > 0 def test_filter_operation_empty_input( @@ -192,7 +191,6 @@ def test_equijoin_operation( assert len(results) == 2 # Only 2 matches assert all("name" in result and "email" in result for result in results) - assert cost > 0 def test_equijoin_operation_empty_input( diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py index e8c4ce9f..21a21d95 100644 --- a/tests/basic/test_basic_map.py +++ b/tests/basic/test_basic_map.py @@ -23,14 +23,12 @@ def test_map_operation( map_sample_data, ): results, cost = test_map_operation_instance.execute(map_sample_data) - print(results) assert len(results) == len(map_sample_data) assert all("sentiment" in result for result in results) assert all( result["sentiment"] in ["positive", "negative", "neutral"] for result in results ) - assert cost > 0 def test_map_operation_empty_input(map_config, default_model, max_threads, api_wrapper): @@ -48,6 +46,7 @@ def test_map_operation_with_drop_keys( map_sample_data_with_extra_keys, api_wrapper, ): + map_config_with_drop_keys["bypass_cache"] = True operation = MapOperation( api_wrapper, map_config_with_drop_keys, default_model, max_threads ) @@ -55,11 +54,12 @@ def test_map_operation_with_drop_keys( assert len(results) == len(map_sample_data_with_extra_keys) assert all("sentiment" in result for result in results) - assert all("original_sentiment" not in result for result in results) - assert all("to_be_dropped" in result for result in results) + assert all("original_sentiment" in result for result in results) + assert all("to_be_dropped" not in result for result in results) assert all( result["sentiment"] in ["positive", "negative", "neutral"] for result in results ) + assert cost > 0 @@ -95,7 +95,6 @@ def test_map_operation_with_batching( results, cost = operation.execute(map_sample_data) assert len(results) == len(map_sample_data) - assert cost > 0 assert all("sentiment" in result for result in results) assert all( 
result["sentiment"] in ["positive", "negative", "neutral"] for result in results @@ -128,7 +127,6 @@ def test_map_operation_with_large_max_batch_size( results, cost = operation.execute(map_sample_data) assert len(results) == len(map_sample_data) - assert cost > 0 def test_map_operation_with_word_count_tool( @@ -140,7 +138,6 @@ def test_map_operation_with_word_count_tool( assert len(results) == len(synthetic_data) assert all("word_count" in result for result in results) assert [result["word_count"] for result in results] == [5, 6, 5, 1] - assert cost > 0 # Ensure there was some cost associated with the operation @pytest.fixture @@ -185,8 +182,8 @@ def test_map_operation_with_timeout(simple_map_config, simple_sample_data, api_w operation = MapOperation(api_wrapper, map_config_with_timeout, "gpt-4o-mini", 4) # Execute the operation and expect empty results - with pytest.raises(docetl.operations.utils.InvalidOutputError): - operation.execute(simple_sample_data) + results, cost = operation.execute(simple_sample_data) + assert len(results) == 0 def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wrapper): @@ -215,6 +212,3 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra assert all( any(vs in result["sentiment"] for vs in valid_sentiments) for result in results ) - - # Assert that the operation had a cost - assert cost > 0 diff --git a/tests/basic/test_basic_parallel_map.py b/tests/basic/test_basic_parallel_map.py index 926fe20e..5edcaa2c 100644 --- a/tests/basic/test_basic_parallel_map.py +++ b/tests/basic/test_basic_parallel_map.py @@ -22,6 +22,7 @@ def test_parallel_map_operation( parallel_map_sample_data, api_wrapper, ): + parallel_map_config["bypass_cache"] = True operation = ParallelMapOperation( api_wrapper, parallel_map_config, default_model, max_threads ) diff --git a/tests/basic/test_basic_reduce_resolve.py b/tests/basic/test_basic_reduce_resolve.py index a2f8ab59..a0a93c2f 100644 --- a/tests/basic/test_basic_reduce_resolve.py +++ b/tests/basic/test_basic_reduce_resolve.py @@ -42,6 +42,7 @@ def reduce_sample_data_with_list_key(): def test_reduce_operation( reduce_config, default_model, max_threads, reduce_sample_data, api_wrapper ): + reduce_config["bypass_cache"] = True operation = ReduceOperation(api_wrapper, reduce_config, default_model, max_threads) results, cost = operation.execute(reduce_sample_data) @@ -61,7 +62,6 @@ def test_reduce_operation_with_all_key( results, cost = operation.execute(reduce_sample_data) assert len(results) == 1 - assert cost > 0 def test_reduce_operation_with_list_key( @@ -84,7 +84,6 @@ def test_reduce_operation_with_list_key( and "avg" in result for result in results ) - assert cost > 0 def test_reduce_operation_empty_input( @@ -134,7 +133,6 @@ def test_resolve_operation( distinct_names = set(result["name"] for result in results) assert len(distinct_names) < len(results) - assert cost > 0 def test_resolve_operation_empty_input(resolve_config, max_threads, api_wrapper): diff --git a/tests/basic/test_cluster.py b/tests/basic/test_cluster.py index 3db8424b..23835e8e 100644 --- a/tests/basic/test_cluster.py +++ b/tests/basic/test_cluster.py @@ -70,6 +70,7 @@ def sample_data(): def test_cluster_operation( cluster_config, sample_data, api_wrapper, default_model, max_threads ): + cluster_config["bypass_cache"] = True operation = ClusterOperation( api_wrapper, cluster_config, default_model, max_threads ) diff --git a/tests/conftest.py b/tests/conftest.py index 231775c7..f92c65dc 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def map_config_with_drop_keys(): "prompt": "Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.", "output": {"schema": {"sentiment": "string"}}, "model": "gpt-4o-mini", - "drop_keys": ["original_sentiment"], + "drop_keys": ["to_be_dropped"], } diff --git a/tests/test_api.py b/tests/test_api.py index 921c2de9..64cc0243 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -257,7 +257,6 @@ def test_pipeline_execution( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_parallel_map_pipeline( @@ -283,7 +282,6 @@ def test_parallel_map_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_filter_pipeline( @@ -309,7 +307,6 @@ def test_filter_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_reduce_pipeline( @@ -333,7 +330,6 @@ def test_reduce_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_resolve_pipeline( @@ -359,7 +355,6 @@ def test_resolve_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 def test_equijoin_pipeline( @@ -404,4 +399,3 @@ def test_equijoin_pipeline( cost = pipeline.run(max_threads=4) assert isinstance(cost, float) - assert cost > 0 diff --git a/tests/test_eugene.py b/tests/test_eugene.py index dd08f6dc..b9928c17 100644 --- a/tests/test_eugene.py +++ b/tests/test_eugene.py @@ -185,5 +185,3 @@ def test_database_survey_pipeline( assert all("summary" in result for result in summarized_results) total_cost = extract_cost + unnest_cost + resolve_cost + summarize_cost - assert total_cost > 0 - print(total_cost) diff --git a/tests/test_reduce_scale.py b/tests/test_reduce_scale.py index 2d054388..412511a7 100644 --- a/tests/test_reduce_scale.py +++ b/tests/test_reduce_scale.py @@ -90,7 +90,6 @@ def test_reduce_operation( results, cost = operation.execute(reduce_sample_data) assert len(results) == 3, "Should have results for 3 unique categories" - assert cost > 0, "Cost should be greater than 0" for result in results: assert "category" in result, "Each result should have a 'category' key" @@ -112,7 +111,6 @@ def test_reduce_operation_pass_through( results, cost = operation.execute(reduce_sample_data) assert len(results) == 3, "Should have results for 3 unique categories" - assert cost > 0, "Cost should be greater than 0" for result in results: assert "category" in result, "Each result should have a 'category' key" @@ -176,7 +174,6 @@ def test_reduce_operation_non_associative(api_wrapper, default_model, max_thread results, cost = operation.execute(sample_data) assert len(results) == 1, "Should have one result for the 'story' sequence" - assert cost > 0, "Cost should be greater than 0" result = results[0] assert "combined_result" in result, "Result should have a 'combined_result' key" @@ -231,7 +228,6 @@ def test_reduce_operation_persist_intermediates( results, cost = operation.execute(sample_data) assert len(results) == 1, "Should have one result for the 'numbers' group" - assert cost > 0, "Cost should be greater than 0" result = results[0] assert "summary" in result, "Result should have a 'summary' key" diff --git a/tests/test_validation.py b/tests/test_validation.py index ed64af4c..fc6d924b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -31,6 +31,7 @@ def sample_data(): def 
test_map_operation_with_validation( map_config_with_validation, sample_data, api_wrapper, default_model, max_threads ): + map_config_with_validation["bypass_cache"] = True operation = MapOperation( api_wrapper, map_config_with_validation, default_model, max_threads )
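Reviewer note, not part of the patch: every operation in this diff now follows the same calling convention, build an optional validation_config, make a single call_llm call that returns an LLMResult, and only parse when validated is true. The sketch below distills that convention from the map/reduce/resolve hunks above; process_item, runner, config, prompt, and item are illustrative stand-ins for what a concrete operation already has in scope, and the "gpt-4o-mini" default mirrors the models used in the tests.

```python
# Illustrative sketch of the new call_llm / LLMResult calling convention.
# `runner` is assumed to expose the same `api` wrapper and `console` that the
# operations above use; all other names are placeholders.
from typing import Any, Dict, Optional, Tuple


def process_item(
    runner, config: Dict[str, Any], prompt: str, item: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], float]:
    def validation_fn(response: Any):
        # Parse first, then apply the operation's `validate` rules.
        output = runner.api.parse_llm_response(
            response, config["output"]["schema"]
        )[0]
        if runner.api.validate_output(config, output, runner.console):
            return output, True
        return output, False

    llm_result = runner.api.call_llm(
        config.get("model", "gpt-4o-mini"),
        "map",
        [{"role": "user", "content": prompt}],
        config["output"]["schema"],
        timeout_seconds=config.get("timeout", 120),
        max_retries_per_timeout=config.get("max_retries_per_timeout", 2),
        bypass_cache=config.get("bypass_cache", False),
        # Validation and gleaning are now plain config dicts handled inside
        # _cached_call_llm instead of separate wrapper methods.
        validation_config=(
            {
                "num_retries": config.get("num_retries_on_validate_failure", 0),
                "val_rule": config.get("validate", []),
                "validation_fn": validation_fn,
            }
            if config.get("validate")
            else None
        ),
        gleaning_config=config.get("gleaning", None),
        verbose=config.get("verbose", False),
    )

    if not llm_result.validated:
        # Validation failed or the call timed out; cost is still reported.
        return None, llm_result.total_cost

    output = runner.api.parse_llm_response(
        llm_result.response, schema=config["output"]["schema"]
    )[0]
    # Augment the parsed output with the original item, as the map operation does.
    return {**item, **output}, llm_result.total_cost
```

The same three-field LLMResult also covers the failure path: on a timeout, call_llm returns LLMResult(response=None, total_cost=0.0, validated=False) instead of raising, which is why test_map_operation_with_timeout now asserts an empty result list rather than expecting InvalidOutputError. Likewise, a cache hit adds no cost unless bypass_cache is set, which is why the blanket "assert cost > 0" checks were removed and only the tests that set bypass_cache keep them.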