Skip to content

Commit

Permalink
Use gpu option in bervectorizer
Browse files Browse the repository at this point in the history
  • Loading branch information
lizgzil committed Dec 8, 2023
1 parent 66bd2e1 commit 5635bc1
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ojd_daps_skills/utils/bert_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from ojd_daps_skills import logger
import logging
import torch


class BertVectorizer:
Expand All @@ -13,7 +14,7 @@ class BertVectorizer:
def __init__(
self,
bert_model_name="sentence-transformers/all-MiniLM-L6-v2",
multi_process=True,
multi_process=False,
batch_size=32,
verbose=True,
):
Expand All @@ -27,7 +28,8 @@ def __init__(
logger.setLevel(logging.ERROR)

def fit(self, *_):
self.bert_model = SentenceTransformer(self.bert_model_name)
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
self.bert_model = SentenceTransformer(self.bert_model_name, device=device)
self.bert_model.max_seq_length = 512
return self

Expand Down

0 comments on commit 5635bc1

Please sign in to comment.