From 1b446b05a786525b0f0bb927bb3b509092e07345 Mon Sep 17 00:00:00 2001 From: lizgzil Date: Tue, 3 Sep 2024 10:30:14 +0100 Subject: [PATCH] set torch threads to try to fix github actions test runs - 2 --- nlp_link/linker.py | 6 +++++- tests/test_linker.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/nlp_link/linker.py b/nlp_link/linker.py index 1feabc9..6a60dd5 100644 --- a/nlp_link/linker.py +++ b/nlp_link/linker.py @@ -26,6 +26,7 @@ """ from sentence_transformers import SentenceTransformer +import torch from tqdm import tqdm from sklearn.metrics.pairwise import cosine_similarity import numpy as np @@ -100,7 +101,10 @@ def load( If a list is given then a unique id will be assigned with the index order. """ logger.info("Loading model") - self.bert_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") + device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu") + self.bert_model = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2", device=device + ) self.bert_model.max_seq_length = 512 self.comparison_data = self._process_dataset(comparison_data) diff --git a/tests/test_linker.py b/tests/test_linker.py index 89551a3..d7b35cf 100644 --- a/tests/test_linker.py +++ b/tests/test_linker.py @@ -1,12 +1,12 @@ -from nlp_link.linker import NLPLinker - -import numpy as np - # Needed for Github Actions to not fail (see torch bug https://github.com/pytorch/pytorch/issues/121101) import torch torch.set_num_threads(1) +from nlp_link.linker import NLPLinker + +import numpy as np + def test_NLPLinker_dict_input():