Merge pull request #4 from nestauk/add-real-linker
Add NLPLinker class and tests
Showing 4 changed files with 342 additions and 20 deletions.
nlp_link/linker.py

@@ -1,12 +1,264 @@
"""
Class to link two datasets.

Example usage:

from nlp_link.linker import NLPLinker

nlp_link = NLPLinker()

# dict inputs
comparison_data = {'a': 'cats', 'b': 'dogs', 'd': 'rats', 'e': 'birds'}
input_data = {'x': 'owls', 'y': 'feline', 'z': 'doggies', 'za': 'dogs', 'zb': 'chair'}
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)
# Top match output
print(matches)

# list inputs
comparison_data = ['cats', 'dogs', 'rats', 'birds']
input_data = ['owls', 'feline', 'doggies', 'dogs', 'chair']
nlp_link.load(comparison_data)
matches = nlp_link.link_dataset(input_data)
# Top match output
print(matches)
"""

from sentence_transformers import SentenceTransformer
import torch
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
import random

from typing import Union, Optional
import logging

from nlp_link.linker_utils import chunk_list

logger = logging.getLogger(__name__)

# TO DO: cosine or euclidean?


class NLPLinker(object):
    """docstring for NLPLinker"""

    def __init__(self, batch_size=32, embed_chunk_size=500, match_chunk_size=10000):
        super(NLPLinker, self).__init__()
        self.batch_size = batch_size
        self.embed_chunk_size = embed_chunk_size
        self.match_chunk_size = match_chunk_size
        ## Cleaning?

    def _process_dataset(
        self,
        input_data: Union[list, dict, pd.DataFrame],
        id_column: Optional[str] = None,
        text_column: Optional[str] = None,
    ) -> dict:
        """Check and process a dataset according to the input type
        Args:
            input_data (Union[list, dict, pd.DataFrame]):
                A list of texts or a dictionary of texts where the key is the unique id.
                If a list is given then a unique id will be assigned with the index order.
        Returns:
            dict: key is the id and the value is the text
        """

        if isinstance(input_data, list):
            return {ix: text for ix, text in enumerate(input_data)}
        elif isinstance(input_data, dict):
            return input_data
        elif isinstance(input_data, pd.DataFrame):
            try:
                return dict(zip(input_data[id_column], input_data[text_column]))
            except:
                logger.warning(
                    "Input is a dataframe, please specify id_column and text_column"
                )
        else:
            logger.warning(
                "The input_data input must be a dictionary, a list or pandas dataframe"
            )

        if not isinstance(input_data[0], str):
            logger.warning(
                "The input_data input must be a list of texts, or a dictionary where the values are texts"
            )

    def load(
        self,
        comparison_data: Union[list, dict],
    ):
        """
        Load the embedding model and embed the comparison dataset
        Args:
            comparison_data (Union[list, dict]): The comparison texts to find links to.
                A list of texts or a dictionary of texts where the key is the unique id.
                If a list is given then a unique id will be assigned with the index order.
        """
        logger.info("Loading model")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.bert_model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2", device=device
        )
        self.bert_model.max_seq_length = 512

        self.comparison_data = self._process_dataset(comparison_data)
        self.comparison_data_texts = list(self.comparison_data.values())
        self.comparison_data_ids = list(self.comparison_data.keys())

        self.comparison_embeddings = self._get_embeddings(self.comparison_data_texts)

    def _get_embeddings(self, text_list: list) -> np.array:
        """
        Get embeddings for a list of texts
        Args:
            text_list (list): A list of texts
        Returns:
            np.array: The embeddings for the input list of texts
        """

        logger.info(
            f"Finding embeddings for {len(text_list)} texts chunked into {round(len(text_list)/self.embed_chunk_size)} chunks"
        )
        all_embeddings = []
        for batch_texts in tqdm(chunk_list(text_list, self.embed_chunk_size)):
            all_embeddings.append(
                self.bert_model.encode(
                    np.array(batch_texts), batch_size=self.batch_size
                )
            )
        all_embeddings = np.concatenate(all_embeddings)

        return all_embeddings

    def get_matches(
        self,
        input_data_ids: list,
        input_embeddings: np.array,
        comparison_data_ids: list,
        comparison_embeddings: np.array,
        top_n: int,
        drop_most_similar: bool = False,
    ) -> dict:
        """
        Find top matches across two datasets using their embeddings.
        Args:
            input_data_ids (list): The ids of the input texts.
            input_embeddings (np.array): Embeddings for the input texts.
            comparison_data_ids (list): The ids of the comparison texts.
            comparison_embeddings (np.array): Embeddings for the comparison texts.
            top_n (int): The number of top links to return in the output.
            drop_most_similar (bool, default = False): Whether to leave out the most similar match;
                set this to True if you are matching a list with itself.
        Returns:
            dict: The top matches for each input id.
        """

        logger.info(
            f"Finding the top dataset matches for {len(input_data_ids)} input texts chunked into {round(len(input_data_ids)/self.match_chunk_size)} chunks"
        )

        if drop_most_similar:
            top_n = top_n + 1
            start_n = 1
        else:
            start_n = 0

        # We chunk up the comparisons, otherwise it can crash
        matches_topn = {}
        for batch_indices in tqdm(
            chunk_list(range(len(input_data_ids)), n_chunks=self.match_chunk_size)
        ):
            batch_input_ids = [input_data_ids[i] for i in batch_indices]
            batch_input_embeddings = [input_embeddings[i] for i in batch_indices]

            batch_similarities = cosine_similarity(
                batch_input_embeddings, comparison_embeddings
            )

            # Top links for each input text
            for input_ix, similarities in enumerate(batch_similarities):
                top_links = []
                for comparison_ix in np.flip(np.argsort(similarities))[start_n:top_n]:
                    # comparison data id + cosine similarity score
                    top_links.append(
                        [
                            comparison_data_ids[comparison_ix],
                            similarities[comparison_ix],
                        ]
                    )
                matches_topn[batch_input_ids[input_ix]] = top_links
        return matches_topn

    def link_dataset(
        self,
        input_data: Union[list, dict],
        top_n: int = 3,
        format_output: bool = True,
        drop_most_similar: bool = False,
    ) -> dict:
        """
        Link a dataset to the comparison dataset.
        Args:
            input_data (Union[list, dict]): The main dictionary to be linked to texts in the loaded comparison_data.
                A list of texts or a dictionary of texts where the key is the unique id.
                If a list is given then a unique id will be assigned with the index order.
            top_n (int, default = 3): The number of top links to return in the output.
            format_output (bool, default = True): Whether to format the output to include the texts of
                the matched datasets (otherwise only the ids are given).
            drop_most_similar (bool, default = False): Whether to leave out the most similar match;
                set this to True if you are matching a list with itself.
        Returns:
            dict: The keys are the ids of the input_data and the values are a list of lists of the top_n most similar
                ids from the comparison_data and a similarity score,
                e.g. {'x': [['a', 0.75], ['c', 0.7]], 'y': [...]}
        """

        try:
            logger.info(
                f"Comparing {len(input_data)} input texts to {len(self.comparison_embeddings)} comparison texts"
            )
        except:
            logger.warning(
                "self.comparison_embeddings does not exist - you may have not run load()"
            )

        input_data = self._process_dataset(input_data)
        input_data_texts = list(input_data.values())
        input_data_ids = list(input_data.keys())

        input_embeddings = self._get_embeddings(input_data_texts)

        self.matches_topn = self.get_matches(
            input_data_ids,
            input_embeddings,
            self.comparison_data_ids,
            self.comparison_embeddings,
            top_n,
            drop_most_similar,
        )

        if format_output:
            # Format the output into a user friendly pandas format with the top link only
            df_output = pd.DataFrame(
                [
                    {
                        "input_id": input_id,
                        "input_text": input_data[input_id],
                        "link_id": link_data[0][0],
                        "link_text": self.comparison_data[link_data[0][0]],
                        "similarity": link_data[0][1],
                    }
                    for input_id, link_data in self.matches_topn.items()
                ]
            )
            return df_output
        else:
            return self.matches_topn
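For orientation, here is a minimal usage sketch of the raw (unformatted) output described in the docstring above, assuming the package is installed and the sentence-transformers model can be downloaded; the data values are illustrative only:

from nlp_link.linker import NLPLinker

nlp_link = NLPLinker()
nlp_link.load({"a": "cats", "b": "dogs"})

# With format_output=False the result is {input_id: [[comparison_id, similarity], ...]},
# ordered from most to least similar comparison text.
matches = nlp_link.link_dataset({"x": "feline", "y": "puppy"}, top_n=2, format_output=False)

for input_id, links in matches.items():
    top_id, top_score = links[0]
    print(input_id, top_id, round(float(top_score), 2))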
nlp_link/linker_utils.py

@@ -0,0 +1,3 @@
def chunk_list(orig_list, n_chunks):
    for i in range(0, len(orig_list), n_chunks):
        yield orig_list[i : i + n_chunks]
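As a quick aside (not part of the commit): chunk_list yields successive slices of length n_chunks, so despite its name the parameter sets the chunk size rather than the number of chunks:

list(chunk_list([1, 2, 3, 4, 5], n_chunks=2))
# -> [[1, 2], [3, 4], [5]]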
@@ -1,10 +1,77 @@
# Needed for Github Actions to not fail (see torch bug https://github.com/pytorch/pytorch/issues/121101)
import torch

torch.set_num_threads(1)

import numpy as np

from nlp_link.linker import NLPLinker

def test_NLPLinker_dict_input():

    nlp_link = NLPLinker()

    comparison_data = {"a": "cats", "b": "dogs", "c": "rats", "d": "birds"}
    input_data = {
        "x": "owls",
        "y": "feline",
        "z": "doggies",
        "za": "dogs",
        "zb": "chair",
    }
    nlp_link.load(comparison_data)
    matches = nlp_link.link_dataset(input_data)

    assert len(matches) == len(input_data)
    assert len(set(matches["link_id"]).difference(set(comparison_data.keys()))) == 0


def test_NLPLinker_list_input():

    nlp_link = NLPLinker()

    comparison_data = ["cats", "dogs", "rats", "birds"]
    input_data = ["owls", "feline", "doggies", "dogs", "chair"]
    nlp_link.load(comparison_data)
    matches = nlp_link.link_dataset(input_data)

    assert len(matches) == len(input_data)
    assert (
        len(set(matches["link_id"]).difference(set(range(len(comparison_data))))) == 0
    )


def test_get_matches():

    nlp_link = NLPLinker()

    matches_topn = nlp_link.get_matches(
        input_data_ids=["x", "y", "z"],
        input_embeddings=np.array(
            [[0.1, 0.13, 0.14], [0.12, 0.18, 0.15], [0.5, 0.9, 0.91]]
        ),
        comparison_data_ids=["a", "b"],
        comparison_embeddings=np.array([[0.51, 0.99, 0.9], [0.1, 0.13, 0.14]]),
        top_n=1,
    )

    assert matches_topn["x"][0][0] == "b"
    assert matches_topn["y"][0][0] == "b"
    assert matches_topn["z"][0][0] == "a"


def test_same_input():

    nlp_link = NLPLinker()

    comparison_data = {"a": "cats", "b": "dogs", "c": "rats", "d": "birds"}
    input_data = comparison_data
    nlp_link.load(comparison_data)
    matches = nlp_link.link_dataset(input_data, drop_most_similar=False)

    assert all(matches["input_id"] == matches["link_id"])

    matches = nlp_link.link_dataset(input_data, drop_most_similar=True)

    assert all(matches["input_id"] != matches["link_id"])
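As a sanity check on the hand-picked embeddings in test_get_matches (a sketch, not part of the commit), computing the cosine similarities directly reproduces the expected nearest neighbours:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

input_embeddings = np.array([[0.1, 0.13, 0.14], [0.12, 0.18, 0.15], [0.5, 0.9, 0.91]])  # x, y, z
comparison_embeddings = np.array([[0.51, 0.99, 0.9], [0.1, 0.13, 0.14]])  # a, b

similarities = cosine_similarity(input_embeddings, comparison_embeddings)
# Index of the closest comparison embedding for each input: expected [1, 1, 0], i.e. b, b, a.
print(np.argmax(similarities, axis=1))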