Skip to content

Commit

Permalink
Add NLPLinker class and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lizgzil committed Aug 6, 2024
1 parent 1c5cb4b commit 7361002
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 14 deletions.
264 changes: 256 additions & 8 deletions nlp_link/linker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,260 @@
"""
Class to link two datasets.
Example usage:
from nlp_link.linker import NLPLinker
nlp_link = NLPLinker()
# dict inputs
comparison_data = {'a': 'cats', 'b': 'dogs', 'd': 'rats', 'e': 'birds'}
input_data = {'x': 'owls', 'y': 'feline', 'z': 'doggies', 'za': 'dogs', 'zb': 'chair'}
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)
# Top match output
print(matches)
# list inputs
comparison_data = ['cats', 'dogs', 'rats', 'birds']
input_data = ['owls', 'feline', 'doggies', 'dogs','chair']
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)
# Top match output
print(matches)
"""

from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
import random

from typing import Union, Optional
import logging

from nlp_link.utils import chunk_list

logger = logging.getLogger(__name__)

# TO DO: cosine or euclidean?


class NLPLinker(object):
"""docstring for NLPLinker"""

def __init__(self, batch_size=32, embed_chunk_size=500, match_chunk_size=10000):
super(NLPLinker, self).__init__()
self.batch_size = batch_size
self.embed_chunk_size = embed_chunk_size
self.match_chunk_size = match_chunk_size
## Cleaning?

def _process_dataset(
self,
input_data: Union[list, dict, pd.DataFrame],
id_column: Optional[str] = None,
text_column: Optional[str] = None,
) -> dict:
"""Check and process a dataset according to the input type
Args:
input_data (Union[list, dict, pd.DataFrame])
A list of texts or a dictionary of texts where the key is the unique id.
If a list is given then a unique id will be assigned with the index order.
Returns:
dict: key is the id and the value is the text
"""

if isinstance(input_data, list):
return {ix: text for ix, text in enumerate(input_data)}
elif isinstance(input_data, dict):
return input_data
elif isinstance(input_data, pd.DataFrame):
try:
return dict(zip(input_data[id_column], input_data[text_column]))
except:
logger.warning(
"Input is a dataframe, please specify id_column and text_column"
)
else:
logger.warning(
"The input_data input must be a dictionary, a list or pandas dataframe"
)

if not isinstance(input_data[0], str):
logger.warning(
"The input_data input must be a list of texts, or a dictionary where the values are texts"
)

def load(
self,
comparison_data: Union[list, dict],
):
"""
Load the embedding model and embed the comparison dataset
Args:
comparison_data (Union[list, dict]): The comparison texts to find links to.
A list of texts or a dictionary of texts where the key is the unique id.
If a list is given then a unique id will be assigned with the index order.
"""
logger.info("Loading model")
self.bert_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
self.bert_model.max_seq_length = 512

self.comparison_data = self._process_dataset(comparison_data)
self.comparison_data_texts = list(self.comparison_data.values())
self.comparison_data_ids = list(self.comparison_data.keys())

self.comparison_embeddings = self._get_embeddings(self.comparison_data_texts)

def _get_embeddings(self, text_list: list) -> np.array:
"""
Get embeddings for a list of texts
Args:
text_list (list): A lists of texts
Returns:
np.array: The embeddings for the input list of texts
"""

logger.info(
f"Finding embeddings for {len(text_list)} texts chunked into {round(len(text_list)/self.embed_chunk_size)} chunks"
)
all_embeddings = []
for batch_texts in tqdm(chunk_list(text_list, self.embed_chunk_size)):
all_embeddings.append(
self.bert_model.encode(
np.array(batch_texts), batch_size=self.batch_size
)
)
all_embeddings = np.concatenate(all_embeddings)

return all_embeddings

def get_matches(
self,
input_data_ids: list,
input_embeddings: np.array,
comparison_data_ids: list,
comparison_embeddings: np.array,
top_n: int,
drop_most_similar: bool = False,
) -> dict:
"""
Find top matches across two datasets using their embeddings.
Args:
input_data_ids (list): The ids of the input texts.
input_embeddings (np.array): Embeddings for the input texts.
comparison_data_ids (list): The ids of the comparison texts.
comparison_embeddings (np.array): Embeddings for the comparison texts.
top_n (int): The number of top links to return in the output.
drop_most_similar (bool, default = False): Whether to not output the most similar match, this would be set to True if you are matching a list with itself.
Returns:
dict: The top matches for each input id.
"""

logger.info(
f"Finding the top dataset matches for {len(input_data_ids)} input texts chunked into {round(len(input_data_ids)/self.match_chunk_size)}"
)

if drop_most_similar:
top_n = top_n + 1
start_n = 1
else:
start_n = 0

# We chunk up comparisons otherwise it can crash
matches_topn = {}
for batch_indices in tqdm(
chunk_list(range(len(input_data_ids)), n_chunks=self.match_chunk_size)
):
batch_input_ids = [input_data_ids[i] for i in batch_indices]
batch_input_embeddings = [input_embeddings[i] for i in batch_indices]

batch_similarities = cosine_similarity(
batch_input_embeddings, comparison_embeddings
)

# Top links for each input text
for input_ix, similarities in enumerate(batch_similarities):
top_links = []
for comparison_ix in np.flip(np.argsort(similarities))[start_n:top_n]:
# comparison data id + cosine similarity score
top_links.append(
[
comparison_data_ids[comparison_ix],
similarities[comparison_ix],
]
)
matches_topn[batch_input_ids[input_ix]] = top_links
return matches_topn

def link_dataset(
self,
input_data: Union[list, dict],
top_n: int = 3,
format_output: bool = True,
drop_most_similar: bool = False,
) -> dict:
"""
Link a dataset to the comparison dataset.
Args:
input_data (Union[list, dict]): The main dictionary to be linked to texts in the loaded comparison_data.
A list of texts or a dictionary of texts where the key is the unique id.
If a list is given then a unique id will be assigned with the index order.
top_n (int, default = 3): The number of top links to return in the output.
format_output (bool, default = True): If you'd like the output to be formatted to include the texts of
the matched datasets or not (will just give the indices).
drop_most_similar (bool, default = False): Whether to not output the most similar match, this would be set to True if you are matching a list with itself.
Returns:
dict: The keys are the ids of the input_data and the values are a list of lists of the top_n most similar
ids from the comparison_data and a probability score.
e.g. {'x': [['a', 0.75], ['c', 0.7]], 'y': [...]}
"""

try:
logger.info(
f"Comparing {len(input_data)} input texts to {len(self.comparison_embeddings)} comparison texts"
)
except:
logger.warning(
"self.comparison_embeddings does not exist - you may have not run load()"
)

input_data = self._process_dataset(input_data)
input_data_texts = list(input_data.values())
input_data_ids = list(input_data.keys())

input_embeddings = self._get_embeddings(input_data_texts)

def link_lists(list_1, list_2):
"""
Mock linker
"""
list_1_index = list(range(len(list_1)))
list_2_index = list(range(len(list_2)))
self.matches_topn = self.get_matches(
input_data_ids,
input_embeddings,
self.comparison_data_ids,
self.comparison_embeddings,
top_n,
drop_most_similar,
)

return [(i, random.choice(list_1_index)) for i in list_2_index]
if format_output:
# Format the output into a user friendly pandas format with the top link only
df_output = pd.DataFrame(
[
{
"input_id": input_id,
"input_text": input_data[input_id],
"link_id": link_data[0][0],
"link_text": self.comparison_data[link_data[0][0]],
"similarity": link_data[0][1],
}
for input_id, link_data in self.matches_topn.items()
]
)
return df_output
else:
return self.matches_topn
74 changes: 68 additions & 6 deletions tests/test_linker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,72 @@
from nlp_link.linker import link_lists
from nlp_link.linker import NLPLinker

import numpy as np

def test_link_lists():

list_1 = ["dog", "cat"]
list_2 = ["kitten", "puppy"]
linked = link_lists(list_1, list_2)
def test_NLPLinker_dict_input():

assert len(linked) == len(list_1)
nlp_link = NLPLinker()

comparison_data = {"a": "cats", "b": "dogs", "c": "rats", "d": "birds"}
input_data = {
"x": "owls",
"y": "feline",
"z": "doggies",
"za": "dogs",
"zb": "chair",
}
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)

assert len(matches) == len(input_data)
assert len(set(matches["link_id"]).difference(set(comparison_data.keys()))) == 0


def test_NLPLinker_list_input():

nlp_link = NLPLinker()

comparison_data = ["cats", "dogs", "rats", "birds"]
input_data = ["owls", "feline", "doggies", "dogs", "chair"]
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)

assert len(matches) == len(input_data)
assert (
len(set(matches["link_id"]).difference(set(range(len(comparison_data))))) == 0
)


def test_get_matches():

nlp_link = NLPLinker()

matches_topn = nlp_link.get_matches(
input_data_ids=["x", "y", "z"],
input_embeddings=np.array(
[[0.1, 0.13, 0.14], [0.12, 0.18, 0.15], [0.5, 0.9, 0.91]]
),
comparison_data_ids=["a", "b"],
comparison_embeddings=np.array([[0.51, 0.99, 0.9], [0.1, 0.13, 0.14]]),
top_n=1,
)

assert matches_topn["x"][0][0] == "b"
assert matches_topn["y"][0][0] == "b"
assert matches_topn["z"][0][0] == "a"


def test_same_input():

nlp_link = NLPLinker()

comparison_data = {"a": "cats", "b": "dogs", "c": "rats", "d": "birds"}
input_data = comparison_data
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data, drop_most_similar=False)

assert all(matches["input_id"] == matches["link_id"])

matches = nlp_link.link_dataset(input_data, drop_most_similar=True)

assert all(matches["input_id"] != matches["link_id"])

0 comments on commit 7361002

Please sign in to comment.