
Commit

Merge pull request FlagOpen#1221 from 545999961/master
update stop pool
545999961 authored Nov 14, 2024
2 parents 1b971d0 + d681574 commit db20af3
Showing 10 changed files with 90 additions and 50 deletions.
4 changes: 2 additions & 2 deletions FlagEmbedding/abc/evaluation/evaluator.py
@@ -210,11 +210,11 @@ def __call__(
dataset_name=dataset_name,
)
no_reranker_search_results_dict[split] = search_results
retriever.stop_multi_process_pool()
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)

retriever.stop_multi_process_pool()
# Reranking Stage
if reranker is not None:
reranker_search_results_save_dir = os.path.join(
@@ -254,10 +254,10 @@ def __call__(
split=split,
dataset_name=dataset_name,
)
reranker.stop_multi_process_pool()
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
reranker.stop_multi_process_pool()

@staticmethod
def save_search_results(
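
Note on the hunks above: both the retrieval and the reranking stage now call stop_multi_process_pool() once their search loop over splits has finished, so worker processes and cached GPU memory are released before the next stage runs. A minimal sketch of that pattern (retrieve_one_split and the surrounding function are made-up helpers for illustration; only stop_multi_process_pool comes from the diff):

```python
# Minimal sketch, not the evaluator's actual code; retrieve_one_split is a made-up helper.
def run_search_stage(retriever, splits, retrieve_one_split):
    search_results_dict = {}
    for split in splits:
        # each call may lazily start a multi-process encoding pool inside the retriever
        search_results_dict[split] = retrieve_one_split(retriever, split)

    # free the pool's worker processes and cached GPU memory as soon as searching is done,
    # before metrics are computed and before any reranking stage starts
    retriever.stop_multi_process_pool()
    return search_results_dict
```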
26 changes: 14 additions & 12 deletions FlagEmbedding/abc/evaluation/searcher.py
@@ -31,12 +31,13 @@ def __str__(self) -> str:
return os.path.basename(self.embedder.model.config._name_or_path)

def stop_multi_process_pool(self):
if self.embedder.pool is not None:
self.embedder.stop_multi_process_pool(self.embedder.pool)
self.embedder.pool = None
self.embedder.model.to('cpu')
gc.collect()
torch.cuda.empty_cache()
self.embedder.stop_self_pool()
# if self.embedder.pool is not None:
# self.embedder.stop_multi_process_pool(self.embedder.pool)
# self.embedder.pool = None
# self.embedder.model.to('cpu')
# gc.collect()
# torch.cuda.empty_cache()

@abstractmethod
def __call__(
@@ -168,12 +169,13 @@ def __str__(self) -> str:
return os.path.basename(self.reranker.model.config._name_or_path)

def stop_multi_process_pool(self):
if self.reranker.pool is not None:
self.reranker.stop_multi_process_pool(self.reranker.pool)
self.reranker.pool = None
self.reranker.model.to('cpu')
gc.collect()
torch.cuda.empty_cache()
self.reranker.stop_self_pool()
# if self.reranker.pool is not None:
# self.reranker.stop_multi_process_pool(self.reranker.pool)
# self.reranker.pool = None
# self.reranker.model.to('cpu')
# gc.collect()
# torch.cuda.empty_cache()

def __call__(
self,
11 changes: 11 additions & 0 deletions FlagEmbedding/abc/inference/AbsEmbedder.py
@@ -8,6 +8,7 @@
from multiprocessing import Queue

import math
import gc
import torch
import numpy as np
from transformers import is_torch_npu_available
@@ -53,6 +54,7 @@ def __init__(
convert_to_numpy: bool = True,
**kwargs: Any,
):
query_instruction_format = query_instruction_format.replace('\\n', '\n')
self.model_name_or_path = model_name_or_path
self.normalize_embeddings = normalize_embeddings
self.use_fp16 = use_fp16
@@ -74,6 +76,14 @@ def __init__(
self.tokenizer = None
self.model = None
self.pool = None

def stop_self_pool(self):
if self.pool is not None:
self.stop_multi_process_pool(self.pool)
self.pool = None
self.model.to('cpu')
gc.collect()
torch.cuda.empty_cache()

@staticmethod
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
@@ -355,6 +365,7 @@ def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"],

pool["input"].close()
pool["output"].close()
pool = None

# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
def encode_multi_process(
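
One small addition worth calling out: AbsEmbedder.__init__ (and ICLLLMEmbedder.__init__ further down) now turns literal backslash-n sequences in the instruction format into real newlines, which matters when the template is passed through a shell or a config file. A tiny illustration (the instruction text below is made up):

```python
# Illustration of the '\\n' -> '\n' normalization; the instruction text is made up.
query_instruction_format = "<instruct>{}\\n<query>{}"   # arrives with a literal backslash-n
query_instruction_format = query_instruction_format.replace('\\n', '\n')

print(query_instruction_format.format("Given a question, retrieve relevant passages.",
                                       "what is a multi-process pool?"))
# <instruct>Given a question, retrieve relevant passages.
# <query>what is a multi-process pool?
```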
9 changes: 9 additions & 0 deletions FlagEmbedding/abc/inference/AbsReranker.py
@@ -6,6 +6,7 @@
from multiprocessing import Queue

import math
import gc
import torch
import numpy as np
from tqdm import tqdm, trange
@@ -77,6 +78,14 @@ def __init__(
self.tokenizer = None
self.pool = None

def stop_self_pool(self):
if self.pool is not None:
self.stop_multi_process_pool(self.pool)
self.pool = None
self.model.to('cpu')
gc.collect()
torch.cuda.empty_cache()

@staticmethod
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
"""
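
AbsEmbedder and AbsReranker now share the same stop_self_pool helper, so callers (such as the evaluation searchers above) no longer reach into .pool directly. A rough usage sketch; run_inference stands in for encode()/compute_score() style calls, and only stop_self_pool and its side effects come from the diff:

```python
# Rough usage sketch; `run_inference` is a placeholder for encode()/compute_score() calls.
def run_inference_then_cleanup(model_wrapper, run_inference, inputs):
    outputs = run_inference(model_wrapper, inputs)  # may lazily start a multi-process pool
    # one call now closes the pool, moves the model to CPU, runs gc, and empties the CUDA cache
    model_wrapper.stop_self_pool()
    return outputs
```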
3 changes: 1 addition & 2 deletions FlagEmbedding/evaluation/beir/evaluator.py
@@ -141,12 +141,11 @@ def __call__(
sub_dataset_name=sub_dataset_name,
)
no_reranker_search_results_dict[split] = search_results
retriever.stop_multi_process_pool()
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)

retriever.stop_multi_process_pool()

# Reranking Stage
if reranker is not None:
reranker_search_results_save_dir = os.path.join(
5 changes: 5 additions & 0 deletions FlagEmbedding/evaluation/mteb/runner.py
@@ -14,6 +14,10 @@

logger = logging.getLogger(__name__)

def ensure_dir(file_path):
directory = os.path.dirname(file_path)
if not os.path.exists(directory):
os.makedirs(directory)

class MTEBEvalRunner(AbsEvalRunner):
def __init__(
@@ -147,6 +151,7 @@ def run(self):
evaluation = mteb.MTEB(tasks=[task])
results = evaluation.run(self.retriever, output_folder=f"{output_folder}/{str(self.retriever)}")

ensure_dir(self.eval_args.eval_output_path)
logger.info("Start computing metrics. Only save results as json.")
tasks_results = self.read_results(f"{output_folder}/{str(self.retriever)}/no_model_name_available/no_revision_available", new_tasks)
self.output_json(tasks_results, self.eval_args.eval_output_path)
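
The new ensure_dir helper simply creates the parent directory of the JSON output path before output_json writes to it. Illustrative usage (the path value below is made up); for paths that include a directory component, os.makedirs(os.path.dirname(path), exist_ok=True) would behave the same:

```python
# Illustrative only; the path is a made-up example value.
eval_output_path = "outputs/mteb/eval_results.json"
ensure_dir(eval_output_path)   # creates outputs/mteb/ if it does not already exist
# roughly equivalent standard-library call:
# os.makedirs(os.path.dirname(eval_output_path), exist_ok=True)
```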
4 changes: 2 additions & 2 deletions FlagEmbedding/inference/embedder/decoder_only/base.py
@@ -1,4 +1,4 @@
from tqdm import tqdm
from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional

import torch
@@ -224,7 +224,7 @@ def encode_single_device(

# tokenize without padding to get the correct length
all_inputs = []
for start_index in range(0, len(sentences), batch_size):
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
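
The only functional change in this file (and in the matching hunks in icl.py, encoder_only/base.py, and encoder_only/m3.py below) is swapping range for tqdm's trange so the pre-tokenization pass shows a progress bar. A self-contained illustration with toy data:

```python
from tqdm import trange

sentences = [f"sentence {i}" for i in range(1000)]   # toy data for illustration
batch_size = 32

all_batches = []
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
    # the loop body is unchanged; trange only adds the progress bar
    all_batches.append(sentences[start_index:start_index + batch_size])
```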
72 changes: 43 additions & 29 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -1,4 +1,4 @@
from tqdm import tqdm
from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional

import queue
@@ -69,6 +69,7 @@ def __init__(
use_fp16: bool = True,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "<instruct>{}\n<query>{}", # specify the format of query_instruction_for_retrieval
suffix: str = '\n<response>',
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
# Additional parameters for ICLLLMEmbedder
examples_for_task: Optional[List[dict]] = None,
@@ -82,6 +83,8 @@ def __init__(
convert_to_numpy: bool = True,
**kwargs: Any,
):
query_instruction_format = query_instruction_format.replace('\\n', '\n')
examples_instruction_format = examples_instruction_format.replace('\\n', '\n')
super().__init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
@@ -113,7 +116,15 @@ def __init__(
raise ValueError("Pooling method must be 'last_token' for LLM-based models.")

self.set_examples()
self.suffix = '\n<response>'
self.suffix = suffix

self.query_pool = None

def __del__(self):
if self.pool is not None:
self.stop_multi_process_pool(self.pool)
if self.query_pool is not None:
self.stop_multi_process_pool(self.query_pool)

def set_examples(self, examples_for_task: Optional[List[dict]] = None):
"""Set the prefix to the provided examples.
@@ -198,16 +209,19 @@ def encode_queries(
**kwargs
)

pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
if self.pool is not None:
self.stop_multi_process_pool(self.pool)
self.pool = None
if self.query_pool is None:
self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
embeddings = self.encode_multi_process(
queries,
pool,
self.query_pool,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=convert_to_numpy,
**kwargs
)
self.stop_multi_process_pool(pool)
return embeddings

def encode_corpus(
@@ -230,6 +244,9 @@ def encode_corpus(
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
"""
if self.query_pool is not None:
self.stop_multi_process_pool(self.query_pool)
self.query_pool = None
return super().encode_corpus(
corpus,
batch_size=batch_size,
@@ -338,16 +355,27 @@ def encode_queries_single_device(
suffix_ids = self.tokenizer(self.suffix, add_special_tokens=False)['input_ids']

_len_1 = len(self.tokenizer('<s>', add_special_tokens=False)['input_ids'])
_len_2 = len(self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids'])
_len_2 = len(self.tokenizer(f'{self.suffix}</s>', add_special_tokens=False)['input_ids'])
new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length + 8) // 8 * 8 + 8

# tokenize without padding to get the correct length
all_inputs = []
for start_index in range(0, len(input_texts), batch_size):
for start_index in trange(0, len(input_texts), batch_size, desc='pre tokenize'):
sentences_batch = input_texts[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
truncation=True,
max_length=max_length,
max_length=max_length - _len_1 - _len_2,
add_special_tokens=False,
**kwargs
)
sentences_batch = self.tokenizer.batch_decode(inputs_batch['input_ids'])
for i in range(len(sentences_batch)):
sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
inputs_batch = self.tokenizer(
sentences_batch,
truncation=True,
max_length=new_max_length,
**kwargs
)
inputs_batch = [{
@@ -385,30 +413,16 @@ def encode_queries_single_device(
all_embeddings = []
for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
disable=len(sentences_sorted) < 256):
sentences_batch = sentences_sorted[start_index:start_index + batch_size]
inputs = self.tokenizer(
sentences_batch,
max_length=max_length - _len_1 - _len_2,
return_token_type_ids=False,
truncation=True,
return_tensors=None,
add_special_tokens=False
)
new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length + 8) // 8 * 8 + 8
sentences_batch = self.tokenizer.batch_decode(inputs['input_ids'])
for i in range(len(sentences_batch)):
sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
inputs = self.tokenizer(
sentences_batch,
inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
inputs_batch = self.tokenizer.pad(
inputs_batch,
padding=True,
truncation=True,
return_tensors='pt',
max_length=new_max_length,
add_special_tokens=True
**kwargs
).to(device)

last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
embeddings = last_token_pool(last_hidden_state, inputs['attention_mask'])
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
if self.normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
embeddings = cast(torch.Tensor, embeddings)
@@ -469,7 +483,7 @@ def encode_single_device(

# tokenize without padding to get the correct length
all_inputs = []
for start_index in range(0, len(sentences), batch_size):
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
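
The largest change is in ICLLLMEmbedder: queries now get a dedicated query_pool that is started once and reused across encode_queries calls, shut down when encode_corpus takes over, and cleaned up in __del__; query pre-tokenization also truncates the raw text first and only then wraps it with the prefix/suffix before padding at inference time. A condensed, self-contained sketch of the pool lifecycle (start_pool/stop_pool are stand-ins for the real start/stop_multi_process_pool helpers, and the returned "embeddings" are placeholders):

```python
# Condensed sketch of the new pool lifecycle; not the real ICLLLMEmbedder implementation.
class PooledQueryEncoder:
    def __init__(self):
        self.pool = None        # pool used for corpus encoding
        self.query_pool = None  # dedicated pool reused across encode_queries calls

    def start_pool(self, kind):
        return {"kind": kind}   # stand-in for start_multi_process_pool

    def stop_pool(self, pool):
        pass                    # stand-in for stop_multi_process_pool

    def encode_queries(self, queries):
        if self.pool is not None:        # the corpus pool is no longer needed
            self.stop_pool(self.pool)
            self.pool = None
        if self.query_pool is None:      # start the query pool once, then reuse it
            self.query_pool = self.start_pool("query")
        return [f"emb({q})" for q in queries]   # stand-in for encode_multi_process

    def encode_corpus(self, corpus):
        if self.query_pool is not None:  # switching workloads: drop the query pool
            self.stop_pool(self.query_pool)
            self.query_pool = None
        if self.pool is None:
            self.pool = self.start_pool("corpus")
        return [f"emb({d})" for d in corpus]

    def __del__(self):
        # best-effort cleanup of whichever pools are still alive
        if self.pool is not None:
            self.stop_pool(self.pool)
        if self.query_pool is not None:
            self.stop_pool(self.query_pool)
```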
4 changes: 2 additions & 2 deletions FlagEmbedding/inference/embedder/encoder_only/base.py
@@ -1,4 +1,4 @@
from tqdm import tqdm
from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional

import torch
@@ -205,7 +205,7 @@ def encode_single_device(

# tokenize without padding to get the correct length
all_inputs = []
for start_index in range(0, len(sentences), batch_size):
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -369,7 +369,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):

# tokenize without padding to get the correct length
all_inputs = []
for start_index in range(0, len(sentences), batch_size):
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize'):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
