Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"Enhancements: Batch Processing, Crosswalk Table, and Matching Toggle" #73

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/crosswalk_table.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Item 1,Item 2,Similarity Score
What is your age?,What is your age?,1.0
How old are you?,How old are you?,1.0
What is your name?,What is your name?,1.0
202 changes: 105 additions & 97 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
MIT License

Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk).
Project: Harmony (https://harmonydata.ac.uk)
Maintainer: Thomas Wood (https://fastdatascience.com)
Copyright (c) 2023 Ulster University
Project: Harmony
Maintainer: Thomas Wood

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,50 +28,24 @@
import heapq
from collections import Counter, OrderedDict
from typing import List, Callable
import pandas as pd # ADDED for Task 3

import numpy as np
from numpy import dot, matmul, ndarray, matrix
from numpy.linalg import norm
import os

from harmony.matching.negator import negate
from harmony.schemas.catalogue_instrument import CatalogueInstrument
from harmony.schemas.catalogue_question import CatalogueQuestion
from harmony.schemas.requests.text import (
Instrument,
Question,
)
from harmony.schemas.requests.text import Instrument, Question
from harmony.schemas.text_vector import TextVector

import os


def get_batch_size(default=50):
try:
batch_size = int(os.getenv("BATCH_SIZE", default))
return max(batch_size, 0)
except (ValueError, TypeError):
return default
def process_items_in_batches(items, llm_function):
batch_size = get_batch_size()

if batch_size == 0:
return llm_function(items)


batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

results = []
for batch in batches:
batch_results = llm_function(batch)
results.extend(batch_results)
return results


def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
dp = dot(vec1, vec2.T)
m1 = matrix(norm(vec1, axis=1))
m2 = matrix(norm(vec2.T, axis=0))

return np.asarray(dp / matmul(m1.T, m2))


Expand Down Expand Up @@ -104,115 +78,145 @@ def process_questions(questions, texts_cached_vectors):
return text_vectors


def vectorise_texts(text_vectors, vectorisation_function):
for index, text_dict in enumerate(text_vectors):
if not text_dict.vector:
text_vectors[index].vector = vectorisation_function([text_dict.text]).tolist()[0]
return text_vectors


def vectors_pos_neg(text_vectors):
vectors_pos = np.array(
[
x.vector
for x in text_vectors
if (x.is_negated is False and x.is_query is False)
]
)

# Create numpy array of negated texts vectors
vectors_neg = np.array(
[
x.vector
for x in text_vectors
if (x.is_negated is True and x.is_query is False)
]
)
return vectors_pos, vectors_neg


def create_full_text_vectors(
all_questions: List[str],
query: str | None,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, list[float]],
all_questions: List[str],
query: str | None,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, list[float]],
) -> tuple[List[TextVector], dict]:
"""
Create full text vectors.
"""

# Create a list of text vectors
text_vectors = process_questions(all_questions, texts_cached_vectors)

# Add query
if query:
text_vectors = add_text_to_vec(query, texts_cached_vectors, text_vectors, False, True)

# Texts with no cached vector
texts_not_cached = [x.text for x in text_vectors if not x.vector]



# Get vectors for all texts not cached
new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function)

new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()

# Create a dictionary with new vectors
new_vectors_dict = {}
for vector, text in zip(new_vectors_list, texts_not_cached):
new_vectors_dict[text] = vector

# Add new vectors to all_texts
for index, text_dict in enumerate(text_vectors):
if not text_dict.vector:
text_vectors[index].vector = new_vectors_list.pop(0)

return text_vectors, new_vectors_dict


# ADDED: Crosswalk Table Function
def generate_crosswalk_table(matches, similarity_scores):
"""
Generate a crosswalk table from matched item pairs and their similarity scores.

Args:
matches (list of tuple): List of matched item pairs as (item1, item2).
similarity_scores (list of float): List of similarity scores for each pair.

Returns:
pd.DataFrame: A DataFrame representing the crosswalk table.
"""
if len(matches) != len(similarity_scores):
raise ValueError("The length of matches and similarity_scores must be the same.")

crosswalk_table = pd.DataFrame({
"Item 1": [pair[0] for pair in matches],
"Item 2": [pair[1] for pair in matches],
"Similarity Score": similarity_scores
})
return crosswalk_table


# MODIFIED: match_instruments_with_catalogue_instruments with new parameters
def match_instruments_with_catalogue_instruments(
instruments: List[Instrument],
catalogue_data: dict,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, List[float]],
instruments: List[Instrument],
catalogue_data: dict,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, List[float]],
within_instrument=True, # ADDED
save_crosswalk=True # ADDED
) -> tuple[List[Instrument], List[CatalogueInstrument]]:
"""
Match instruments with catalogue instruments.

:param instruments: The instruments.
:param catalogue_data: The catalogue data.
:param vectorisation_function: A function to vectorize a text.
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector).
:return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog.
Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments.
Match instruments with catalogue instruments, with optional within-instrument matching
and crosswalk table generation.

Args:
instruments (list): List of instruments to match.
catalogue_data (dict): Catalogue data for matching.
vectorisation_function (callable): Function to vectorize text data.
texts_cached_vectors (dict): Cached vectors for efficiency.
within_instrument (bool): Whether to allow within-instrument matches. # ADDED
save_crosswalk (bool): Whether to save the crosswalk table. # ADDED

Returns:
list, list: Matched item pairs and their similarity scores.
"""
matches = []
similarity_scores = []

# Gather all questions
all_questions: List[str] = []
for instrument in instruments:
all_questions.extend([q.question_text for q in instrument.questions])
all_questions = list(set(all_questions))

# Create text vectors for all questions in all the uploaded instruments
# Create text vectors
all_instruments_text_vectors, _ = create_full_text_vectors(
all_questions=all_questions,
query=None,
vectorisation_function=vectorisation_function,
texts_cached_vectors=texts_cached_vectors,
)

# For each instrument, find the best instrument matches for it in the catalogue
# Matching logic
for instrument in instruments:
instrument.closest_catalogue_instrument_matches = (
match_questions_with_catalogue_instruments(
questions=instrument.questions,
catalogue_data=catalogue_data,
all_instruments_text_vectors=all_instruments_text_vectors,
questions_are_from_one_instrument=True,
)
for question in instrument.questions:
for catalogue_question in catalogue_data.get("questions", []):
if not within_instrument and question.instrument_id == catalogue_question.instrument_id:
continue

# Vectorize question text if vector not already present
vector1 = texts_cached_vectors.get(question.question_text)
if vector1 is None:
vector1 = vectorisation_function([question.question_text])[0]
texts_cached_vectors[question.question_text] = vector1

vector2 = texts_cached_vectors.get(catalogue_question.question_text)
if vector2 is None:
vector2 = vectorisation_function([catalogue_question.question_text])[0]
texts_cached_vectors[catalogue_question.question_text] = vector2

texts_cached_vectors[question.question_text] = vector1
texts_cached_vectors[catalogue_question.question_text] = vector2

# Calculate similarity
score = cosine_similarity(
np.array([vector1]),
np.array([vector2])
)[0][0]

if score > 0.8:
matches.append((question.question_text, catalogue_question.question_text))
similarity_scores.append(score)


# Assign matches to the instrument
instrument.closest_catalogue_instrument_matches = match_questions_with_catalogue_instruments(
questions=instrument.questions,
catalogue_data=catalogue_data,
all_instruments_text_vectors=all_instruments_text_vectors,
questions_are_from_one_instrument=True,
)

# Gather all questions from all instruments and find the best instrument matches in the catalogue
# Save crosswalk table if required
if save_crosswalk:
crosswalk_table = generate_crosswalk_table(matches, similarity_scores)
crosswalk_table.to_csv("crosswalk_table.csv", index=False)
print("Crosswalk table saved as 'crosswalk_table.csv'")

# Find matches across all instruments
all_instrument_questions: List[Question] = []
for instrument in instruments:
all_instrument_questions.extend(instrument.questions)
Expand All @@ -226,6 +230,8 @@ def match_instruments_with_catalogue_instruments(
return instruments, closest_catalogue_instrument_matches




def match_questions_with_catalogue_instruments(
questions: List[Question],
catalogue_data: dict,
Expand Down Expand Up @@ -667,3 +673,5 @@ def match_instruments_with_function(
query_similarity,
new_vectors_dict
)


46 changes: 46 additions & 0 deletions src/test_batch_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
import numpy as np
from harmony.matching.matcher import batch_process, vectorize_items_with_batching

class TestBatchProcessing(unittest.TestCase):

def setUp(self):
self.vectorization_function = lambda texts: np.array([[len(text)] for text in texts])

def test_batch_process(self):
items = ["item1", "item2", "item3", "item4", "item5"]
batch_size = 2
batches = batch_process(items, batch_size)
expected_batches = [["item1", "item2"], ["item3", "item4"], ["item5"]]
self.assertEqual(batches, expected_batches)

def test_vectorize_items_with_batching(self):
items = ["short", "medium length", "a bit longer", "longest item in the list"]
batch_size = 2
vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
expected_vectors = np.array([[5], [13], [12], [24]])
np.testing.assert_array_equal(vectors, expected_vectors)

def test_edge_case_single_item(self):
items = ["single item"]
batch_size = 2
vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
expected_vectors = np.array([[11]])
np.testing.assert_array_equal(vectors, expected_vectors)

def test_edge_case_empty_list(self):
items = []
batch_size = 2
vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
expected_vectors = np.array([])
np.testing.assert_array_equal(vectors, expected_vectors)

def test_large_batch_size(self):
items = ["item1", "item2", "item3"]
batch_size = 10
batches = batch_process(items, batch_size)
expected_batches = [["item1", "item2", "item3"]]
self.assertEqual(batches, expected_batches)

if __name__ == "__main__":
unittest.main()
Loading