Skip to content

Commit

Permalink
Merge pull request #69 from makrianast/main
Browse files Browse the repository at this point in the history
added batching to matcher.py and unit tests
  • Loading branch information
woodthom2 authored Nov 30, 2024
2 parents 62c8ba2 + da76dfa commit 3631da3
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@
)
from harmony.schemas.text_vector import TextVector

import os


def get_batch_size(default=50):
    """Return the batch size configured via the BATCH_SIZE environment variable.

    Falls back to *default* when the variable is unset or not a valid
    integer. Negative values are clamped to 0, which callers interpret
    as "no batching" (process everything in a single call).
    """
    try:
        # os.getenv returns a string (or *default* when unset); int() raises
        # ValueError for non-numeric strings and TypeError for e.g. a None default.
        return max(int(os.getenv("BATCH_SIZE", default)), 0)
    except (ValueError, TypeError):
        return default


def process_items_in_batches(items, llm_function):
    """Apply *llm_function* to *items* in batches of BATCH_SIZE.

    A batch size of 0 disables batching and passes all items through in
    a single call. Always returns a plain list so both code paths yield
    the same type: the batched path accumulates via ``list.extend``, so
    the unbatched path must not leak the raw return type of
    *llm_function* (which may be e.g. an ndarray, as with the
    vectorisation function used by create_full_text_vectors).
    """
    batch_size = get_batch_size()

    if batch_size == 0:
        # Normalise to list for type consistency with the batched path.
        return list(llm_function(items))

    results = []
    for start in range(0, len(items), batch_size):
        results.extend(llm_function(items[start:start + batch_size]))
    return results


def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
dp = dot(vec1, vec2.T)
Expand Down Expand Up @@ -127,8 +151,11 @@ def create_full_text_vectors(
# Texts with no cached vector
texts_not_cached = [x.text for x in text_vectors if not x.vector]



# Get vectors for all texts not cached
new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()
new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function)


# Create a dictionary with new vectors
new_vectors_dict = {}
Expand Down Expand Up @@ -382,7 +409,7 @@ def match_questions_with_catalogue_instruments(

instrument_idx_to_score = {}
for instrument_idx, average_sim in instrument_idx_to_cosine_similarities_average.items():
score = average_sim * (0.1+instrument_idx_to_top_matches_ct.get(instrument_idx, 0))
score = average_sim * (0.1 + instrument_idx_to_top_matches_ct.get(instrument_idx, 0))
instrument_idx_to_score[instrument_idx] = score

# Find the top 10 best instrument idx matches, index 0 containing the best match etc.
Expand Down Expand Up @@ -432,7 +459,8 @@ def match_questions_with_catalogue_instruments(
"info": info,
"num_matched_questions": num_top_match_questions,
"num_ref_instrument_questions": num_questions_in_ref_instrument,
"mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get(top_catalogue_instrument_idx)
"mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get(
top_catalogue_instrument_idx)
},
))

Expand Down
84 changes: 84 additions & 0 deletions tests/test_batching_in_matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import sys
import os
import unittest
import numpy

sys.path.append("../src")
from unittest import TestCase, mock
from harmony.matching.matcher import get_batch_size
from harmony.matching.matcher import process_items_in_batches


# Mock LLM function
def mock_llm_function(batch):
    """Stand-in for an LLM call: tags each element of *batch* as processed."""
    tagged = []
    for element in batch:
        tagged.append(f"Processed: {element}")
    return tagged


class TestMatcherBatching(TestCase):
    """Exercises BATCH_SIZE-driven batching in process_items_in_batches."""

    @staticmethod
    def _items(count):
        # Canonical input list: item0, item1, ...
        return [f"item{i}" for i in range(count)]

    @staticmethod
    def _processed(count):
        # What mock_llm_function yields for _items(count).
        return [f"Processed: item{i}" for i in range(count)]

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"})
    def test_batched_processing(self):
        """Test that 10 items are divided into 2 batches of 5 each."""
        outcome = process_items_in_batches(self._items(10), mock_llm_function)
        self.assertEqual(len(outcome), 10)
        self.assertEqual(outcome, self._processed(10))

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"})
    def test_large_batch_size(self):
        """Test batch size greater than input size."""
        outcome = process_items_in_batches(self._items(3), mock_llm_function)
        self.assertEqual(len(outcome), 3)
        self.assertEqual(outcome, self._processed(3))

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "0"})
    def test_no_batching(self):
        """Test no batching (all items processed in one batch)."""
        outcome = process_items_in_batches(self._items(10), mock_llm_function)
        self.assertEqual(len(outcome), 10)
        self.assertEqual(outcome, self._processed(10))

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "-5"})
    def test_negative_batch_size(self):
        """Test when BATCH_SIZE is negative, it defaults to 0."""
        outcome = process_items_in_batches(self._items(10), mock_llm_function)
        self.assertEqual(len(outcome), 10)

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_default_batch_size(self):
        """Test when BATCH_SIZE is not set, it defaults to 50."""
        outcome = process_items_in_batches(self._items(10), mock_llm_function)
        self.assertEqual(len(outcome), 10)

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "invalid"})
    def test_invalid_batch_size(self):
        """Test when BATCH_SIZE is invalid, it defaults to 50."""
        outcome = process_items_in_batches(self._items(10), mock_llm_function)
        self.assertEqual(len(outcome), 10)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3631da3

Please sign in to comment.