Commit

Merge branch 'ucbepic:main' into main

staru09 authored Nov 12, 2024
2 parents 70cb41b + 10925d6 commit 8e3edfb
Showing 20 changed files with 306 additions and 65 deletions.
1 change: 1 addition & 0 deletions docetl/operations/cluster.py
@@ -219,6 +219,7 @@ def validation_fn(response: Dict[str, Any]):
else None
),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
total_cost += response.total_cost
if response.validated:
4 changes: 4 additions & 0 deletions docetl/operations/equijoin.py
@@ -19,6 +19,8 @@
rich_as_completed,
)
from docetl.utils import completion_cost
from pydantic import Field


# Global variables to store shared data
_right_data = None
@@ -66,6 +68,7 @@ class schema(BaseOperation.schema):
limit_comparisons: Optional[int] = None
blocking_keys: Optional[Dict[str, List[str]]] = None
timeout: Optional[int] = None
litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)

def compare_pair(
self,
@@ -101,6 +104,7 @@ def compare_pair(
timeout_seconds=timeout_seconds,
max_retries_per_timeout=max_retries_per_timeout,
bypass_cache=self.config.get("bypass_cache", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
output = self.runner.api.parse_llm_response(
response.response, {"is_match": "bool"}
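A note on the schema additions in this commit: the new field is declared with pydantic's `Field(default_factory=dict)`, which builds a fresh dict per instance rather than sharing one default across instances. A minimal sketch of the idiom (`OpSchema` is a hypothetical stand-in, not one of the operation schemas above):

```python
from typing import Any, Dict

from pydantic import BaseModel, Field

class OpSchema(BaseModel):
    # default_factory builds a fresh dict per instance, so kwargs set on
    # one operation can never leak into another via a shared default
    litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)

a, b = OpSchema(), OpSchema()
a.litellm_completion_kwargs["max_tokens"] = 10
assert b.litellm_completion_kwargs == {}  # b still has its own empty dict
```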
1 change: 1 addition & 0 deletions docetl/operations/link_resolve.py
@@ -175,6 +175,7 @@ def validation_fn(response: Dict[str, Any]):
else None
),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)

if response.validated:
6 changes: 5 additions & 1 deletion docetl/operations/map.py
@@ -49,6 +49,7 @@ class schema(BaseOperation.schema):
batch_size: Optional[int] = None
clustering_method: Optional[str] = None
batch_prompt: Optional[str] = None
litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)
@field_validator("drop_keys")
def validate_drop_keys(cls, v):
if isinstance(v, str):
@@ -213,6 +214,7 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
verbose=self.config.get("verbose", False),
bypass_cache=self.config.get("bypass_cache", False),
initial_result=initial_result,
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)

if llm_result.validated:
@@ -249,7 +251,8 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]:
verbose=self.config.get("verbose", False),
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
bypass_cache=self.config.get("bypass_cache", False)
bypass_cache=self.config.get("bypass_cache", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
total_cost += llm_result.total_cost

@@ -460,6 +463,7 @@ def process_prompt(item, prompt_config):
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
bypass_cache=self.config.get("bypass_cache", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
output = self.runner.api.parse_llm_response(
response.response,
17 changes: 15 additions & 2 deletions docetl/operations/reduce.py
@@ -25,6 +25,7 @@
)
from docetl.operations.utils import rich_as_completed
from docetl.utils import completion_cost
from pydantic import Field


class ReduceOperation(BaseOperation):
@@ -51,7 +52,8 @@ class schema(BaseOperation.schema):
value_sampling: Optional[Dict[str, Any]] = None
verbose: Optional[bool] = None
timeout: Optional[int] = None

litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)

def __init__(self, *args, **kwargs):
"""
Initialize the ReduceOperation.
@@ -323,7 +325,15 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
else:
# Group the input data by the reduce key(s) while maintaining original order
def get_group_key(item):
return tuple(item[key] for key in reduce_keys)
key_values = []
for key in reduce_keys:
value = item[key]
# Special handling for list-type values
if isinstance(value, list):
key_values.append(tuple(sorted(value))) # Convert list to sorted tuple
else:
key_values.append(value)
return tuple(key_values)

grouped_data = {}
for item in input_data:
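The reworked `get_group_key` is the one behavioral change in this file beyond kwargs plumbing: list-valued reduce keys are unhashable and cannot appear inside a dict key, so they are converted to sorted tuples, which also makes grouping insensitive to element order. A self-contained sketch of the resulting behavior, using hypothetical sample data:

```python
# Two rows whose list-valued key differs only in element order
items = [
    {"tags": ["b", "a"], "text": "first"},
    {"tags": ["a", "b"], "text": "second"},
]

def get_group_key(item, reduce_keys=("tags",)):
    key_values = []
    for key in reduce_keys:
        value = item[key]
        if isinstance(value, list):
            # lists are unhashable; a sorted tuple is hashable and
            # order-insensitive, so both rows map to the same key
            key_values.append(tuple(sorted(value)))
        else:
            key_values.append(value)
    return tuple(key_values)

grouped = {}
for item in items:
    grouped.setdefault(get_group_key(item), []).append(item)

assert list(grouped) == [(("a", "b"),)]  # one group, not two
```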
@@ -789,6 +799,7 @@ def _increment_fold(
),
bypass_cache=self.config.get("bypass_cache", False),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)

end_time = time.time()
@@ -847,6 +858,7 @@ def _merge_results(
),
bypass_cache=self.config.get("bypass_cache", False),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)

end_time = time.time()
@@ -956,6 +968,7 @@ def _batch_reduce(
),
gleaning_config=self.config.get("gleaning", None),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)

item_cost += response.total_cost
5 changes: 5 additions & 0 deletions docetl/operations/resolve.py
@@ -16,6 +16,8 @@
from docetl.operations.utils import RichLoopBar, rich_as_completed
from docetl.utils import completion_cost, extract_jinja_variables

from pydantic import Field


def find_cluster(item, cluster_map):
while item != cluster_map[item]:
@@ -42,6 +44,7 @@ class schema(BaseOperation.schema):
limit_comparisons: Optional[int] = None
optimize: Optional[bool] = None
timeout: Optional[int] = None
litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)

def compare_pair(
self,
@@ -84,6 +87,7 @@ def compare_pair(
timeout_seconds=timeout_seconds,
max_retries_per_timeout=max_retries_per_timeout,
bypass_cache=self.config.get("bypass_cache", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
output = self.runner.api.parse_llm_response(
response.response,
@@ -545,6 +549,7 @@ def process_cluster(cluster):
if self.config.get("validate", None)
else None
),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
)
reduction_cost = reduction_response.total_cost

16 changes: 12 additions & 4 deletions docetl/operations/utils.py
@@ -434,12 +434,13 @@ def call_llm_batch(
timeout_seconds: int = 120,
max_retries_per_timeout: int = 2,
bypass_cache: bool = False,
litellm_completion_kwargs: Dict[str, Any] = {},
) -> LLMResult:
# Turn the output schema into a list of schemas
output_schema = convert_dict_schema_to_list_schema(output_schema)

# Invoke the LLM call
return self.call_llm(model, op_type, messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache)
return self.call_llm(model, op_type, messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache, litellm_completion_kwargs=litellm_completion_kwargs)


def _cached_call_llm(
@@ -456,6 +457,7 @@ def _cached_call_llm(
verbose: bool = False,
bypass_cache: bool = False,
initial_result: Optional[Any] = None,
litellm_completion_kwargs: Dict[str, Any] = {},
) -> LLMResult:
"""
Cached version of the call_llm function.
@@ -489,7 +491,7 @@ def _cached_call_llm(
else:
if not initial_result:
response = self._call_llm_with_cache(
model, op_type, messages, output_schema, tools, scratchpad
model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs
)
total_cost += completion_cost(response)
else:
@@ -556,6 +558,7 @@ def _cached_call_llm(
}
],
tool_choice="required",
**litellm_completion_kwargs,
)
total_cost += completion_cost(validator_response)

@@ -583,7 +586,7 @@ def _cached_call_llm(

# Call LLM again
response = self._call_llm_with_cache(
model, op_type, messages, output_schema, tools, scratchpad
model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs
)
parsed_output = self.parse_llm_response(
response, output_schema, tools
@@ -633,7 +636,7 @@ def _cached_call_llm(
i += 1

response = self._call_llm_with_cache(
model, op_type, messages, output_schema, tools, scratchpad
model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs
)
total_cost += completion_cost(response)

@@ -662,6 +665,7 @@ def call_llm(
verbose: bool = False,
bypass_cache: bool = False,
initial_result: Optional[Any] = None,
litellm_completion_kwargs: Dict[str, Any] = {},
) -> LLMResult:
"""
Wrapper function that uses caching for LLM calls.
@@ -706,6 +710,7 @@ def call_llm(
verbose=verbose,
bypass_cache=bypass_cache,
initial_result=initial_result,
litellm_completion_kwargs=litellm_completion_kwargs,
)
except RateLimitError:
# TODO: this is a really hacky way to handle rate limits
@@ -735,6 +740,7 @@ def _call_llm_with_cache(
output_schema: Dict[str, str],
tools: Optional[str] = None,
scratchpad: Optional[str] = None,
litellm_completion_kwargs: Dict[str, Any] = {},
) -> Any:
"""
Make an LLM call with caching.
@@ -841,6 +847,7 @@ def _call_llm_with_cache(
+ messages,
tools=tools,
tool_choice=tool_choice,
**litellm_completion_kwargs,
)
else:
response = completion(
Expand All @@ -852,6 +859,7 @@ def _call_llm_with_cache(
},
]
+ messages,
**litellm_completion_kwargs,
)


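Taken together, the utils.py changes are straightforward kwargs threading: each layer, from `call_llm_batch` through `call_llm` and `_cached_call_llm` down to `_call_llm_with_cache`, accepts `litellm_completion_kwargs` and forwards it unchanged until it is splatted into the underlying `completion(...)` call. A stripped-down sketch of the pattern, with a stub standing in for `litellm.completion` (hypothetical names, not the real DocETL call chain):

```python
from typing import Any, Dict

def completion(model: str, messages: list, **kwargs: Any) -> dict:
    # stub standing in for litellm.completion; echoes what it received
    return {"model": model, "extra": kwargs}

def _call_llm_with_cache(model: str, messages: list,
                         litellm_completion_kwargs: Dict[str, Any]) -> dict:
    # innermost layer: splat the user-supplied dict into the provider call
    return completion(model, messages, **litellm_completion_kwargs)

def call_llm(model: str, messages: list,
             litellm_completion_kwargs: Dict[str, Any] = {}) -> dict:
    # outer layer: accept the dict and hand it down unchanged
    # (the shared {} default is safe here only because it is never mutated)
    return _call_llm_with_cache(model, messages, litellm_completion_kwargs)

print(call_llm("gpt-4o-mini", [], {"max_tokens": 500, "temperature": 0.7}))
# -> {'model': 'gpt-4o-mini', 'extra': {'max_tokens': 500, 'temperature': 0.7}}
```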
9 changes: 8 additions & 1 deletion docs/concepts/operators.md
@@ -24,13 +24,20 @@ LLM-based operators have additional attributes:
- `prompt`: A Jinja2 template that defines the instruction for the language model.
- `output`: Specifies the schema for the output from the LLM call.
- `model` (optional): Allows specifying a different model from the pipeline default.
- `litellm_completion_kwargs` (optional): Additional parameters to pass to LiteLLM completion calls.

DocETL uses [LiteLLM](https://docs.litellm.ai) to execute all LLM calls, with support for 100+ LLM providers including OpenAI, Anthropic, and Azure. You can pass any LiteLLM completion arguments through the `litellm_completion_kwargs` field.

Example:

```yaml
- name: extract_insights
type: map
model: gpt-4o
model: gpt-4o-mini
litellm_completion_kwargs:
max_tokens: 500 # limit response length
temperature: 0.7 # control randomness
top_p: 0.9 # nucleus sampling parameter
prompt: |
Analyze the following user interaction log:
{{ input.log }}
1 change: 1 addition & 0 deletions docs/operators/cluster.md
@@ -184,3 +184,4 @@ and a description, and groups them into a tree of categories.
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of items to sample for this operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls | {} |
1 change: 1 addition & 0 deletions docs/operators/map.md
@@ -147,6 +147,7 @@ This example demonstrates how the Map operation can transform long, unstructured
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls | {} |

Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.

Expand Down
1 change: 1 addition & 0 deletions docs/operators/parallel-map.md
@@ -37,6 +37,7 @@ Each prompt configuration in the `prompts` list should contain:
| `sample` | Number of samples to use for the operation | Processes all data |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls | {} |

??? question "Why use Parallel Map instead of multiple Map operations?"

1 change: 1 addition & 0 deletions docs/operators/reduce.md
@@ -64,6 +64,7 @@ This Reduce operation processes customer feedback grouped by department:
| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls | {} |

## Advanced Features

4 changes: 3 additions & 1 deletion docs/operators/resolve.md
@@ -126,7 +126,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls | {} |

## Best Practices

1. **Anticipate Resolve Needs**: If you anticipate needing a Resolve operation and want to control the prompts, create it in your pipeline and let the optimizer find the appropriate blocking rules and thresholds.
32 changes: 32 additions & 0 deletions tests/basic/test_basic_map.py
@@ -289,3 +289,35 @@ def test_map_operation_with_larger_batch(simple_map_config, map_sample_data_with
assert all(
any(vs in result["sentiment"] for vs in valid_sentiments) for result in results
)

def test_map_operation_with_max_tokens(simple_map_config, map_sample_data, api_wrapper):
# Add litellm_completion_kwargs configuration with max_tokens
map_config_with_max_tokens = {
**simple_map_config,
"litellm_completion_kwargs": {
"max_tokens": 10
},
"bypass_cache": True
}

operation = MapOperation(api_wrapper, map_config_with_max_tokens, "gpt-4o-mini", 4)

# Execute the operation
results, cost = operation.execute(map_sample_data)

# Assert that we have results for all input items
assert len(results) == len(map_sample_data)

# Check that all results have a sentiment
assert all("sentiment" in result for result in results)

# Verify that all sentiments are valid
valid_sentiments = ["positive", "negative", "neutral"]
assert all(
any(vs in result["sentiment"] for vs in valid_sentiments) for result in results
)

# Since we limited max_tokens to 10, each response should be relatively short
# The sentiment field should contain just the sentiment value without much extra text
assert all(len(result["sentiment"]) <= 20 for result in results)
