From 5635bc12dd78bb242a2e411a3187eafdca25d2a2 Mon Sep 17 00:00:00 2001 From: lizgzil Date: Fri, 8 Dec 2023 12:05:49 +0000 Subject: [PATCH] Use gpu option in bervectorizer --- ojd_daps_skills/utils/bert_vectorizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ojd_daps_skills/utils/bert_vectorizer.py b/ojd_daps_skills/utils/bert_vectorizer.py index 0b401bf9..4647a1e6 100644 --- a/ojd_daps_skills/utils/bert_vectorizer.py +++ b/ojd_daps_skills/utils/bert_vectorizer.py @@ -2,6 +2,7 @@ import time from ojd_daps_skills import logger import logging +import torch class BertVectorizer: @@ -13,7 +14,7 @@ class BertVectorizer: def __init__( self, bert_model_name="sentence-transformers/all-MiniLM-L6-v2", - multi_process=True, + multi_process=False, batch_size=32, verbose=True, ): @@ -27,7 +28,8 @@ def __init__( logger.setLevel(logging.ERROR) def fit(self, *_): - self.bert_model = SentenceTransformer(self.bert_model_name) + device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu") + self.bert_model = SentenceTransformer(self.bert_model_name, device=device) self.bert_model.max_seq_length = 512 return self