Skip to content

Commit

Permalink
refactor: move validation and gleaning into call llm
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Oct 12, 2024
1 parent c158ae1 commit 0fdeb8f
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _cluster_based_sampling(
return group_list, 0

clusters, cost = cluster_documents(
group_list, value_sampling, sample_size, self.api
group_list, value_sampling, sample_size, self.runner.api
)

sampled_items = []
Expand Down Expand Up @@ -444,7 +444,7 @@ def _semantic_similarity_sampling(
)

embeddings, cost = get_embeddings_for_clustering(
group_list, value_sampling, self.api
group_list, value_sampling, self.runner.api
)

query_response = self.runner.api.gen_embedding(embedding_model, [query_text])
Expand Down
2 changes: 0 additions & 2 deletions tests/test_eugene.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,3 @@ def test_database_survey_pipeline(
assert all("summary" in result for result in summarized_results)

total_cost = extract_cost + unnest_cost + resolve_cost + summarize_cost
assert total_cost > 0
print(total_cost)
4 changes: 0 additions & 4 deletions tests/test_reduce_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def test_reduce_operation(
results, cost = operation.execute(reduce_sample_data)

assert len(results) == 3, "Should have results for 3 unique categories"
assert cost > 0, "Cost should be greater than 0"

for result in results:
assert "category" in result, "Each result should have a 'category' key"
Expand All @@ -112,7 +111,6 @@ def test_reduce_operation_pass_through(
results, cost = operation.execute(reduce_sample_data)

assert len(results) == 3, "Should have results for 3 unique categories"
assert cost > 0, "Cost should be greater than 0"

for result in results:
assert "category" in result, "Each result should have a 'category' key"
Expand Down Expand Up @@ -176,7 +174,6 @@ def test_reduce_operation_non_associative(api_wrapper, default_model, max_thread
results, cost = operation.execute(sample_data)

assert len(results) == 1, "Should have one result for the 'story' sequence"
assert cost > 0, "Cost should be greater than 0"

result = results[0]
assert "combined_result" in result, "Result should have a 'combined_result' key"
Expand Down Expand Up @@ -231,7 +228,6 @@ def test_reduce_operation_persist_intermediates(
results, cost = operation.execute(sample_data)

assert len(results) == 1, "Should have one result for the 'numbers' group"
assert cost > 0, "Cost should be greater than 0"

result = results[0]
assert "summary" in result, "Result should have a 'summary' key"
Expand Down
1 change: 1 addition & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def sample_data():
def test_map_operation_with_validation(
map_config_with_validation, sample_data, api_wrapper, default_model, max_threads
):
map_config_with_validation["bypass_cache"] = True
operation = MapOperation(
api_wrapper, map_config_with_validation, default_model, max_threads
)
Expand Down

0 comments on commit 0fdeb8f

Please sign in to comment.