refactor: move validation and gleaning into call llm
shreyashankar committed Oct 12, 2024
1 parent 22d3a40 commit c158ae1
Showing 14 changed files with 362 additions and 475 deletions.
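Every file in this commit follows one pattern: the per-call wrappers (`call_llm_with_validation`, `call_llm_with_gleaning`) are collapsed into a single `call_llm` that accepts optional `validation_config` and `gleaning_config` arguments and returns a result object instead of a raw completion. Below is a minimal sketch of the consolidated interface reconstructed from the call sites in the diffs — the parameter names come from the diffs themselves, while the result-class name and the default values are assumptions, not the actual docetl definitions.

# Sketch only: reconstructed from the call sites in this commit. The class
# name "LLMResult" and the defaults are assumptions.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LLMResult:
    response: Any       # raw completion, later fed to parse_llm_response
    total_cost: float   # cost accumulated across retries and gleaning rounds
    validated: bool     # False only if a configured validation never passed


def call_llm(
    model: str,
    op_type: str,                 # e.g. "map", "filter", "cluster"
    messages: List[Dict[str, str]],
    output_schema: Dict[str, str],
    tools: Optional[List[Dict[str, Any]]] = None,
    scratchpad: Optional[str] = None,
    timeout_seconds: int = 120,
    max_retries_per_timeout: int = 2,
    validation_config: Optional[Dict[str, Any]] = None,  # num_retries, val_rule, validation_fn
    gleaning_config: Optional[Dict[str, Any]] = None,    # validation_prompt, num_rounds
    verbose: bool = False,
    bypass_cache: bool = False,
) -> LLMResult:
    ...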
46 changes: 24 additions & 22 deletions docetl/operations/cluster.py
@@ -167,31 +167,33 @@ def validation_fn(response: Dict[str, Any]):
                     return output, True
                 return output, False

-            output, cost, success = self.runner.api.call_llm_with_validation(
-                [{"role": "user", "content": prompt}],
-                model=self.config.get("model", self.default_model),
-                operation_type="cluster",
-                schema=self.config["summary_schema"],
-                llm_call_fn=lambda messages: self.runner.api.call_llm(
-                    self.config.get("model", self.default_model),
-                    "cluster",
-                    messages,
-                    self.config["summary_schema"],
-                    tools=self.config.get("tools", None),
-                    console=self.console,
-                    timeout_seconds=self.config.get("timeout", 120),
-                    max_retries_per_timeout=self.config.get(
-                        "max_retries_per_timeout", 2
-                    ),
-                ),
-                validation_fn=validation_fn,
-                val_rule=self.config.get("validate", []),
-                num_retries=self.num_retries_on_validate_failure,
-                console=self.console,
+            response = self.runner.api.call_llm(
+                model=self.config.get("model", self.default_model),
+                op_type="cluster",
+                messages=[{"role": "user", "content": prompt}],
+                output_schema=self.config["summary_schema"],
+                timeout_seconds=self.config.get("timeout", 120),
+                bypass_cache=self.config.get("bypass_cache", False),
+                max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+                validation_config=(
+                    {
+                        "num_retries": self.num_retries_on_validate_failure,
+                        "val_rule": self.config.get("validate", []),
+                        "validation_fn": validation_fn,
+                    }
+                    if self.config.get("validate", None)
+                    else None
+                ),
+                verbose=self.config.get("verbose", False),
             )
-            total_cost += cost

-            t.update(output)
+            total_cost += response.total_cost
+            if response.validated:
+                output = self.runner.api.parse_llm_response(
+                    response.response,
+                    schema=self.config["summary_schema"],
+                    manually_fix_errors=self.manually_fix_errors,
+                )[0]
+                t.update(output)

             return total_cost
         return 0
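The `validation_config` dict built here carries the same three values that call sites previously passed as separate keyword arguments to `call_llm_with_validation`. A hypothetical sketch of how `call_llm` might consume it — the real loop lives inside docetl's utils and may differ:

# Hypothetical sketch of the retry-on-validation-failure loop now inside
# call_llm; "llm_call" stands in for the actual completion request.
def run_with_validation(llm_call, validation_config):
    response = llm_call()
    output, ok = validation_config["validation_fn"](response)
    for _ in range(validation_config["num_retries"]):
        if ok:
            break
        response = llm_call()  # retry until valid or retries exhausted
        output, ok = validation_config["validation_fn"](response)
    return response, ok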
7 changes: 5 additions & 2 deletions docetl/operations/equijoin.py
@@ -82,9 +82,12 @@ def compare_pair(
             {"is_match": "bool"},
             timeout_seconds=timeout_seconds,
             max_retries_per_timeout=max_retries_per_timeout,
+            bypass_cache=self.config.get("bypass_cache", False),
         )
-        output = self.runner.api.parse_llm_response(response, {"is_match": "bool"})[0]
-        return output["is_match"], completion_cost(response)
+        output = self.runner.api.parse_llm_response(
+            response.response, {"is_match": "bool"}
+        )[0]
+        return output["is_match"], response.total_cost

     def syntax_check(self) -> None:
         """
78 changes: 5 additions & 73 deletions docetl/operations/filter.py
@@ -5,13 +5,13 @@

 from jinja2 import Template

-from docetl.operations.base import BaseOperation
+from docetl.operations.map import MapOperation
 from docetl.operations.utils import (
     RichLoopBar,
 )


-class FilterOperation(BaseOperation):
+class FilterOperation(MapOperation):
     def syntax_check(self) -> None:
         """
         Checks the configuration of the FilterOperation for required keys and valid structure.
@@ -110,77 +110,9 @@ def execute(
             )
         )

-        if self.status:
-            self.status.start()
-
-        def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
-            prompt_template = Template(self.config["prompt"])
-            prompt = prompt_template.render(input=item)
-
-            def validation_fn(response: Dict[str, Any]):
-                output = self.runner.api.parse_llm_response(
-                    response,
-                    self.config["output"]["schema"],
-                    manually_fix_errors=self.manually_fix_errors,
-                )[0]
-                for key, value in item.items():
-                    if key not in self.config["output"]["schema"]:
-                        output[key] = value
-                if self.runner.api.validate_output(self.config, output, self.console):
-                    return output, True
-                return output, False
-
-            output, cost, is_valid = self.runner.api.call_llm_with_validation(
-                [{"role": "user", "content": prompt}],
-                model=self.config.get("model", self.default_model),
-                operation_type="filter",
-                schema=self.config["output"]["schema"],
-                llm_call_fn=lambda messages: self.runner.api.call_llm(
-                    self.config.get("model", self.default_model),
-                    "filter",
-                    messages,
-                    self.config["output"]["schema"],
-                    console=self.console,
-                    timeout_seconds=self.config.get("timeout", 120),
-                    max_retries_per_timeout=self.config.get(
-                        "max_retries_per_timeout", 2
-                    ),
-                ),
-                validation_fn=validation_fn,
-                val_rule=self.config.get("validate", []),
-                num_retries=self.num_retries_on_validate_failure,
-                console=self.console,
-            )
-
-            if is_valid:
-                return output, cost
-
-            return None, cost
-
-        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
-            futures = [
-                executor.submit(_process_filter_item, item) for item in input_data
-            ]
-            results = []
-            total_cost = 0
-            pbar = RichLoopBar(
-                range(len(futures)),
-                desc=f"Processing {self.config['name']} (filter) on all documents",
-                console=self.console,
-            )
-            for i in pbar:
-                future = futures[i]
-                result, item_cost = future.result()
-                total_cost += item_cost
-                if result is not None:
-                    if is_build:
-                        results.append(result)
-                    else:
-                        if result.get(filter_key, False):
-                            results.append(result)
-                pbar.update(1)
-
-        if self.status:
-            self.status.start()
+        results, total_cost = super().execute(input_data)
+
+        # Drop records with filter_key values that are False
+        results = [result for result in results if result[filter_key]]

         return results, total_cost
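The net effect of this file's diff: FilterOperation stops duplicating the map loop and inherits it instead, so filtering becomes `MapOperation.execute` plus a list comprehension. A condensed sketch of the new control flow — not the verbatim method body; in particular the derivation of `filter_key` (the boolean field in the output schema) is an assumption here, and the `is_build` handling is omitted:

# Condensed sketch; not the verbatim method body.
from docetl.operations.map import MapOperation  # as imported in the new filter.py


class FilterOperation(MapOperation):
    def execute(self, input_data, is_build=False):
        # MapOperation.execute now does the prompting, validation, and
        # cost accounting that the removed _process_filter_item duplicated.
        results, total_cost = super().execute(input_data)
        # Drop records whose boolean filter field came back False. The
        # derivation of filter_key below is assumed, not copied from source.
        filter_key = next(iter(self.config["output"]["schema"]))
        results = [result for result in results if result[filter_key]]
        return results, total_cost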
91 changes: 37 additions & 54 deletions docetl/operations/map.py
@@ -153,59 +153,42 @@ def validation_fn(response: Dict[str, Any]):
                 return output, False

             self.runner.rate_limiter.try_acquire("call", weight=1)
-            if "gleaning" in self.config:
-                output, cost, success = self.runner.api.call_llm_with_validation(
-                    [{"role": "user", "content": prompt}],
-                    model=self.config.get("model", self.default_model),
-                    operation_type="map",
-                    schema=self.config["output"]["schema"],
-                    llm_call_fn=lambda messages: self.runner.api.call_llm_with_gleaning(
-                        self.config.get("model", self.default_model),
-                        "map",
-                        messages,
-                        self.config["output"]["schema"],
-                        self.config["gleaning"]["validation_prompt"],
-                        self.config["gleaning"]["num_rounds"],
-                        self.console,
-                        timeout_seconds=self.config.get("timeout", 120),
-                        max_retries_per_timeout=self.config.get(
-                            "max_retries_per_timeout", 2
-                        ),
-                        verbose=self.config.get("verbose", False),
-                    ),
-                    validation_fn=validation_fn,
-                    val_rule=self.config.get("validate", []),
-                    num_retries=self.num_retries_on_validate_failure,
-                    console=self.console,
-                )
-            else:
-                output, cost, success = self.runner.api.call_llm_with_validation(
-                    [{"role": "user", "content": prompt}],
-                    model=self.config.get("model", self.default_model),
-                    operation_type="map",
-                    schema=self.config["output"]["schema"],
-                    llm_call_fn=lambda messages: self.runner.api.call_llm(
-                        self.config.get("model", self.default_model),
-                        "map",
-                        messages,
-                        self.config["output"]["schema"],
-                        tools=self.config.get("tools", None),
-                        console=self.console,
-                        timeout_seconds=self.config.get("timeout", 120),
-                        max_retries_per_timeout=self.config.get(
-                            "max_retries_per_timeout", 2
-                        ),
-                    ),
-                    validation_fn=validation_fn,
-                    val_rule=self.config.get("validate", []),
-                    num_retries=self.num_retries_on_validate_failure,
-                    console=self.console,
-                )
+            llm_result = self.runner.api.call_llm(
+                self.config.get("model", self.default_model),
+                "map",
+                [{"role": "user", "content": prompt}],
+                self.config["output"]["schema"],
+                tools=self.config.get("tools", None),
+                scratchpad=None,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+                validation_config=(
+                    {
+                        "num_retries": self.num_retries_on_validate_failure,
+                        "val_rule": self.config.get("validate", []),
+                        "validation_fn": validation_fn,
+                    }
+                    if self.config.get("validate", None)
+                    else None
+                ),
+                gleaning_config=self.config.get("gleaning", None),
+                verbose=self.config.get("verbose", False),
+                bypass_cache=self.config.get("bypass_cache", False),
+            )

-            if success:
-                return output, cost
+            if llm_result.validated:
+                # Parse the response
+                output = self.runner.api.parse_llm_response(
+                    llm_result.response,
+                    schema=self.config["output"]["schema"],
+                    tools=self.config.get("tools", None),
+                    manually_fix_errors=self.manually_fix_errors,
+                )[0]
+                # Augment the output with the original item
+                output = {**item, **output}
+                return output, llm_result.total_cost

-            return None, cost
+            return None, llm_result.total_cost

         with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
             futures = [executor.submit(_process_map_item, item) for item in input_data]
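Gleaning follows the same pattern as validation in this hunk: the `validation_prompt` and `num_rounds` values that `call_llm_with_gleaning` used to take positionally now travel inside the operation's config dict, forwarded untouched via `gleaning_config=self.config.get("gleaning", None)`. A sketch of the expected shape — the two keys are the ones the removed call site read from `self.config["gleaning"]`; the prompt text is an invented example:

# The keys below are the ones the removed call site read from
# self.config["gleaning"]; the prompt text is an invented example.
gleaning_config = {
    "validation_prompt": "Check the draft output: is every schema field "
                         "filled in and supported by the document?",
    "num_rounds": 2,  # maximum critique-and-refine iterations
}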
@@ -375,17 +358,17 @@ def process_prompt(item, prompt_config):
                 [{"role": "user", "content": prompt}],
                 local_output_schema,
                 tools=prompt_config.get("tools", None),
-                console=self.console,
                 timeout_seconds=self.config.get("timeout", 120),
                 max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+                bypass_cache=self.config.get("bypass_cache", False),
             )
             output = self.runner.api.parse_llm_response(
-                response,
+                response.response,
                 schema=local_output_schema,
                 tools=prompt_config.get("tools", None),
                 manually_fix_errors=self.manually_fix_errors,
             )[0]
-            return output, completion_cost(response)
+            return output, response.total_cost

         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             if "prompts" in self.config: