diff --git a/FlagEmbedding/inference/embedder/decoder_only/base.py b/FlagEmbedding/inference/embedder/decoder_only/base.py
index 9c3deb6d..2912ca05 100644
--- a/FlagEmbedding/inference/embedder/decoder_only/base.py
+++ b/FlagEmbedding/inference/embedder/decoder_only/base.py
@@ -224,7 +224,8 @@ def encode_single_device(
 
         # tokenize without padding to get the correct length
         all_inputs = []
-        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
+        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
+                                  disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer(
                 sentences_batch,
diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py
index 790829a1..1206b8a4 100644
--- a/FlagEmbedding/inference/embedder/decoder_only/icl.py
+++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -178,9 +178,12 @@ def stop_self_query_pool(self):
         if self.query_pool is not None:
             self.stop_multi_process_pool(self.query_pool)
             self.query_pool = None
-        self.model.to('cpu')
+        try:
+            self.model.to('cpu')
+            torch.cuda.empty_cache()
+        except:
+            pass
         gc.collect()
-        torch.cuda.empty_cache()
 
     def encode_queries(
         self,
@@ -483,7 +486,8 @@ def encode_single_device(
 
         # tokenize without padding to get the correct length
         all_inputs = []
-        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
+        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
+                                  disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer(
                 sentences_batch,
diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py
index fe11b228..6b27ec14 100644
--- a/FlagEmbedding/inference/embedder/encoder_only/base.py
+++ b/FlagEmbedding/inference/embedder/encoder_only/base.py
@@ -205,7 +205,8 @@ def encode_single_device(
 
         # tokenize without padding to get the correct length
         all_inputs = []
-        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
+        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
+                                  disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer(
                 sentences_batch,
diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py
index 8a8aa041..42c207d7 100644
--- a/FlagEmbedding/inference/embedder/encoder_only/m3.py
+++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -369,7 +369,8 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
 
         # tokenize without padding to get the correct length
         all_inputs = []
-        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
+        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
+                                  disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer(
                 sentences_batch,
diff --git a/FlagEmbedding/inference/reranker/decoder_only/base.py b/FlagEmbedding/inference/reranker/decoder_only/base.py
index 7d0d1645..b50a697e 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/base.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/base.py
@@ -309,7 +309,8 @@ def compute_score_single_gpu(
         # tokenize without padding to get the correct length
         all_queries_inputs = []
         all_passages_inputs = []
-        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
+        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize",
+                                  disable=len(sentence_pairs) < 128):
             sentences_batch = sentence_pairs[start_index:start_index + batch_size]
             queries = [s[0] for s in sentences_batch]
             passages = [s[1] for s in sentences_batch]
diff --git a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
index dce7393c..8e50271f 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/layerwise.py
@@ -191,7 +191,8 @@ def compute_score_single_gpu(
         # tokenize without padding to get the correct length
         all_queries_inputs = []
         all_passages_inputs = []
-        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
+        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize",
+                                  disable=len(sentence_pairs) < 128):
             sentences_batch = sentence_pairs[start_index:start_index + batch_size]
             queries = [s[0] for s in sentences_batch]
             passages = [s[1] for s in sentences_batch]
diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
index 87c0027d..a9bea311 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py
@@ -262,7 +262,8 @@ def compute_score_single_gpu(
         # tokenize without padding to get the correct length
         all_queries_inputs = []
         all_passages_inputs = []
-        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
+        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize",
+                                  disable=len(sentence_pairs) < 128):
             sentences_batch = sentence_pairs[start_index:start_index + batch_size]
             queries = [s[0] for s in sentences_batch]
             passages = [s[1] for s in sentences_batch]
diff --git a/FlagEmbedding/inference/reranker/encoder_only/base.py b/FlagEmbedding/inference/reranker/encoder_only/base.py
index 9a1abebe..9af0b2d8 100644
--- a/FlagEmbedding/inference/reranker/encoder_only/base.py
+++ b/FlagEmbedding/inference/reranker/encoder_only/base.py
@@ -121,7 +121,8 @@ def compute_score_single_gpu(
 
         # tokenize without padding to get the correct length
         all_inputs = []
-        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize"):
+        for start_index in trange(0, len(sentence_pairs), batch_size, desc="pre tokenize",
+                                  disable=len(sentence_pairs) < 128):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
             queries = [s[0] for s in sentences_batch]
             passages = [s[1] for s in sentences_batch]