diff --git a/README.md b/README.md index 6ec33eee..a37d8d77 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ BGE (BAAI General Embedding) focuses on retrieval-augmented LLMs, consisting of ## News -- 05/12/2024: :book: We built the BGE documentation for centralized BGE information and materials. +- 05/12/2024: :book: We built the [BGE documentation](https://www.bge-model.com) for centralized BGE information and materials! - 10/29/2024: :earth_asia: We created WeChat group for BGE. Scan the [QR code](./imgs/BGE_WeChat_Group.png) to join the group chat! To get the first hand message about our updates and new release, or having any questions or ideas, join us now! - bge_wechat_group diff --git a/Tutorials/7_Fine-tuning/7.1.1_Data_preparation.ipynb b/Tutorials/7_Fine-tuning/7.1.1_Data_preparation.ipynb new file mode 100644 index 00000000..89cffa05 --- /dev/null +++ b/Tutorials/7_Fine-tuning/7.1.1_Data_preparation.ipynb @@ -0,0 +1,723 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data preparation for fine-tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will show an example of the first step for fine-tuning: dataset preparation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 0. Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install -U datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"HF_ENDPOINT\"]=\"https://hf-mirror.com\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Suppose we want to fine-tune our model for financial tasks. We found an open-source dataset that could be useful: [financial-qa-10k](https://huggingface.co/datasets/virattt/financial-qa-10K). 
Let's see how to properly prepare our dataset for fine-tuning." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The raw dataset has the following structure:\n", + "- 5 columns: 'question', 'answer', 'context', 'ticker', and 'filing'.\n", + "- 7000 rows." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['question', 'answer', 'context', 'ticker', 'filing'],\n", + " num_rows: 7000\n", + "})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"virattt/financial-qa-10K\", split=\"train\")\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Data for Fine-tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Construct the dataset in the following format:\n", + "\n", + "``` python\n", + "{\"query\": str, \"pos\": List[str], \"neg\":List[str], \"pos_scores\": List[int], \"neg_scores\": List[int], \"prompt\": str, \"type\": str}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts. `pos_scores` is a list of scores corresponding to the `query` and `pos`, and `neg_scores` is a list of scores corresponding to the `query` and `neg`; both can be ignored if you don't use knowledge distillation. `prompt` is the prompt used for the query; it overrides `query_instruction_for_retrieval`. 
`type` is used for bge-en-icl; it includes `normal`, `symmetric_class`, `symmetric_clustering`, etc. If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We select the columns 'question' and 'context' as our query and positive text (`pos`), and rename the columns. Then we add an 'id' column for later use in evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n", + " 'pos': 'Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.',\n", + " 'id': '0'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds = ds.select_columns(column_names=[\"question\", \"context\"])\n", + "ds = ds.rename_column(\"question\", \"query\")\n", + "ds = ds.rename_column(\"context\", \"pos\")\n", + "ds = ds.add_column(\"id\", [str(i) for i in range(len(ds))])\n", + "ds[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Negative examples are important during the training of embedding models. Our initial dataset does not come with negative texts, so we randomly sample a few from the whole corpus." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 7000/7000 [00:00<00:00, 22336.83 examples/s]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "np.random.seed(520)\n", + "neg_num = 10\n", + "\n", + "def str_to_lst(data):\n", + " data[\"pos\"] = [data[\"pos\"]]\n", + " return data\n", + "\n", + "# sample negative texts, re-drawing whenever the row's own index is picked\n", + "new_col = []\n", + "for i in range(len(ds)):\n", + " ids = np.random.randint(0, len(ds), size=neg_num)\n", + " while i in ids:\n", + " ids = np.random.randint(0, len(ds), size=neg_num)\n", + " neg = [ds[j.item()][\"pos\"] for j in ids]\n", + " new_col.append(neg)\n", + "ds = ds.add_column(\"neg\", new_col)\n", + "\n", + "# wrap the value of 'pos' in a list\n", + "ds = ds.map(str_to_lst)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, we add the prompt used for queries. It will be the `query_instruction_for_retrieval` during inference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "instruction = \"Represent this sentence for searching relevant passages: \"\n", + "ds = ds.add_column(\"prompt\", [instruction]*len(ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now a single row of the dataset is:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'query': 'What area did NVIDIA initially focus on before expanding to other computationally intensive fields?',\n", + " 'pos': ['Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields.'],\n", + " 'id': '0',\n", + " 'neg': ['Kroger expects that its value creation model will deliver total shareholder return within a target range of 8% to 11% over time.',\n", + " 'CSB purchased First Mortgages of $2.9 billion during 2023.',\n", + " 'See Note 13 to our Consolidated Financial Statements for information on certain legal proceedings for which there are contingencies.',\n", + " 'Diluted earnings per share were $16.69 in fiscal 2022 compared to $15.53 in fiscal 2021.',\n", + " 'In the year ended December 31, 2023, Total net sales and revenue increased primarily due to: (1) increased net wholesale volumes primarily due to increased sales of crossover vehicles and full-size pickup trucks, partially offset by decreased sales of mid-size pickup trucks; (2) favorable Price as a result of low dealer inventory levels and strong demand for our products; (3) favorable Mix associated with increased sales of full-size pickup trucks and full-size SUVs and decreased sales of vans, passenger cars and mid-size pickup trucks, partially offset by increased sales of crossover vehicles; and (4) favorable Other due to increased sales of parts and accessories.',\n", + " 'As of December 31, 2023, we had 3,157 full-time employees.',\n", + " 
'Item 3. Legal Proceedings. The information contained in Note 18 ‘‘Commitments and Contingencies’’ included in Item 8 of this 10-K is incorporated herein by reference.',\n", + " 'Under the amended 2019 Secured Facility, the maturity date is set to July 20, 2026.',\n", + " 'Accounts receivable for Las Vegas Sands Corp. on December 31, 2023, totaled $685 million, with a provision for credit losses of $201 million, resulting in a net balance of $484 million.',\n", + " 'Operating expenses as a percentage of segment net sales decreased 25 basis points for fiscal 2023 when compared to the previous fiscal year, primarily driven by strong sales growth and lower incremental COVID-19 related costs, partially offset by increased wage costs.'],\n", + " 'prompt': 'Represent this sentence for searching relevant passages: '}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we split the dataset into training set and testing set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)\n", + "train = split[\"train\"]\n", + "test = split[\"test\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we are ready to store the data for later fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 39.73ba/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "16583481" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train.to_json(\"ft_data/training.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Test Data for Evaluation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The last step is to construct the testing dataset following the [format](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/evaluation#8-custom-dataset) for evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['query', 'pos', 'id', 'neg', 'prompt'],\n", + " num_rows: 700\n", + "})" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First select the columns for queries:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': '1289',\n", + " 'text': 'How does Starbucks recognize the interest and penalties related to income tax matters on their financial statements?'}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "queries = test.select_columns(column_names=[\"id\", \"query\"])\n", + "queries = queries.rename_column(\"query\", \"text\")\n", + "queries[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then select the columns for corpus:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "corpus = ds.select_columns(column_names=[\"id\", \"pos\"])\n", + "corpus = corpus.rename_column(\"pos\", \"text\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, make the qrels that indicate which corpus entries are relevant to each query:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Flattening the indices: 100%|██████████| 700/700 [00:00<00:00, 180956.10 examples/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'qid': '1289', 'docid': '1289', 'relevance': 1}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qrels = 
test.select_columns([\"id\"])\n", + "qrels = qrels.rename_column(\"id\", \"qid\")\n", + "qrels = qrels.add_column(\"docid\", list(test[\"id\"]))\n", + "qrels = qrels.add_column(\"relevance\", [1]*len(test))\n", + "qrels[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Store the test queries, corpus, and qrels:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 210.42ba/s]\n", + "Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 261.19ba/s]\n", + "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 591.08ba/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "30574" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "queries.to_json(\"ft_data/test_queries.jsonl\")\n", + "corpus.to_json(\"ft_data/corpus.jsonl\")\n", + "qrels.to_json(\"ft_data/test_qrels.jsonl\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finetune" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from FlagEmbedding import FlagModel\n", + "\n", + "finetuned_path = \"test_encoder_only_base_bge-large-en-v1.5\"\n", + "model_name = \"BAAI/bge-large-en-v1.5\"\n", + "model = FlagModel(finetuned_path, \n", + "# model = FlagModel(model_name,\n", + " query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n", + " devices=[0,1],\n", + " use_fp16=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "initial target device: 100%|██████████| 2/2 [00:30<00:00, 15.31s/it]\n", + "pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 116.32it/s]\n", + "You're using a BertTokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", + "pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 123.47it/s]\n", + "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", + " warnings.warn(\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. 
Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", + " warnings.warn(\n", + "Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s]\n", + "Inference Embeddings: 100%|██████████| 2/2 [00:00<00:00, 13.14it/s]\n", + "Chunks: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n", + "pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 55.58it/s]\n", + "pre tokenize: 100%|██████████| 14/14 [00:00<00:00, 27.82it/s]\n", + "Inference Embeddings: 100%|██████████| 14/14 [00:02<00:00, 6.24it/s]\n", + "Inference Embeddings: 100%|██████████| 14/14 [00:03<00:00, 4.07it/s]\n", + "Chunks: 100%|██████████| 2/2 [00:04<00:00, 2.05s/it]\n" + ] + } + ], + "source": [ + "queries_text = [q[1] for q in queries.items()]\n", + "corpus_text = [corpus[str(i)][0] for i in range(len(corpus))]\n", + "\n", + "queries_embeddings = model.encode_queries(queries_text)\n", + "corpus_embeddings = model.encode_corpus(corpus_text)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total number of vectors: 7000\n" + ] + } + ], + "source": [ + "import faiss\n", + "import numpy as np\n", + "\n", + "# get the length of our embedding vectors, vectors by bge-base-en-v1.5 have length 768\n", + "dim = corpus_embeddings.shape[-1]\n", + "\n", + "# create the faiss index and store the corpus embeddings into the vector space\n", + "index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n", + "# corpus_embeddings = corpus_embeddings.astype(np.float32)\n", + "# train and add the embeddings to the index\n", + "index.train(corpus_embeddings)\n", + "index.add(corpus_embeddings)\n", + "\n", + "print(f\"total number of vectors: {index.ntotal}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Searching: 100%|██████████| 22/22 
[00:00<00:00, 31.84it/s]\n" + ] + } + ], + "source": [ + "from tqdm import tqdm\n", + "\n", + "query_size = len(queries_embeddings)\n", + "\n", + "all_scores = []\n", + "all_indices = []\n", + "\n", + "for i in tqdm(range(0, query_size, 32), desc=\"Searching\"):\n", + " j = min(i + 32, query_size)\n", + " query_embedding = queries_embeddings[i: j]\n", + " score, indice = index.search(query_embedding.astype(np.float32), k=100)\n", + " all_scores.append(score)\n", + " all_indices.append(indice)\n", + "\n", + "all_scores = np.concatenate(all_scores, axis=0)\n", + "all_indices = np.concatenate(all_indices, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "results = {}\n", + "for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):\n", + " results[queries_ids[idx]] = {}\n", + " for score, index in zip(scores, indices):\n", + " if index != -1:\n", + " results[queries_ids[idx]][corpus_ids[index]] = float(score)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'NDCG@10': 0.84061, 'NDCG@100': 0.85484})\n", + "defaultdict(, {'MAP@10': 0.81157, 'MAP@100': 0.81471})\n", + "defaultdict(, {'Recall@10': 0.93, 'Recall@100': 0.99429})\n", + "defaultdict(, {'P@10': 0.093, 'P@100': 0.00994})\n", + "defaultdict(, {'MRR@10': 0.81157, 'MRR@100': 0.81471})\n" + ] + } + ], + "source": [ + "from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr\n", + "\n", + "k_values = [10,100]\n", + "eval_res = evaluate_metrics(qrels, results, k_values)\n", + "mrr = evaluate_mrr(qrels, results, k_values)\n", + "\n", + "for res in eval_res:\n", + " print(res)\n", + "print(mrr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'NDCG@1': 0.58286, 
'NDCG@5': 0.68588, 'NDCG@10': 0.70405})\n", + "defaultdict(, {'Recall@1': 0.58286, 'Recall@5': 0.76714, 'Recall@10': 0.82286})\n" + ] + } + ], + "source": [ + "# Original test result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'NDCG@1': 0.75571, 'NDCG@5': 0.84706, 'NDCG@10': 0.85623})\n", + "defaultdict(, {'Recall@1': 0.75571, 'Recall@5': 0.92286, 'Recall@10': 0.95143})\n" + ] + } + ], + "source": [ + "# Fake test result" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[6.453125]\n" + ] + } + ], + "source": [ + "from FlagEmbedding import FlagReranker\n", + "\n", + "reranker = FlagReranker(\n", + " 'BAAI/bge-reranker-base', \n", + " query_max_length=256,\n", + " use_fp16=True,\n", + " devices=['cuda:1'],\n", + ")\n", + "\n", + "score = reranker.compute_score(['I am happy to help', 'Assisting you is my pleasure'])\n", + "print(score)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Tutorials/7_Fine-tuning/7.1.2_Fine-tune.ipynb b/Tutorials/7_Fine-tuning/7.1.2_Fine-tune.ipynb new file mode 100644 index 00000000..c8025630 --- /dev/null +++ 
b/Tutorials/7_Fine-tuning/7.1.2_Fine-tune.ipynb @@ -0,0 +1,3734 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous section, we went through how to construct training and testing data properly. In this tutorial, we will actually fine-tune the model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that to fine-tune BGE models using FlagEmbedding, we need to install the package with the finetune dependency:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U FlagEmbedding[finetune]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine-tune" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below are the arguments for fine-tuning:\n", + "\n", + "The following arguments are for model:\n", + "- `model_name_or_path`: The model checkpoint for initialization.\n", + "- `config_name`: Pretrained config name or path if not the same as model_name.\n", + "- `tokenizer_name`: Pretrained tokenizer name or path if not the same as model_name.\n", + "- `cache_dir`: Where to store the pre-trained models downloaded from s3.\n", + "- `trust_remote_code`: Whether to trust remote code.\n", + "- `token`: The token to use when accessing the model.\n", + "\n", + "The following arguments are for data:\n", + "- `train_data`: One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data. Argument type: multiple.\n", + "- `cache_path`: Where to store the cached data.\n", + "- `train_group_size`: (No metadata provided)\n", + "- `query_max_len`: The maximum total input sequence length after tokenization for the query. 
Sequences longer than this will be truncated.\n", + "- `passage_max_len`: The maximum total input sequence length after tokenization for the passage. Sequences longer than this will be truncated.\n", + "- `pad_to_multiple_of`: If set, pads the sequence to a multiple of the provided value.\n", + "- `max_example_num_per_dataset`: The max number of examples for each dataset.\n", + "- `query_instruction_for_retrieval`: Instruction for query.\n", + "- `query_instruction_format`: Format for query instruction.\n", + "- `knowledge_distillation`: Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are in the features of the training data.\n", + "- `passage_instruction_for_retrieval`: Instruction for passage.\n", + "- `passage_instruction_format`: Format for passage instruction.\n", + "- `shuffle_ratio`: The ratio of shuffling the text.\n", + "- `same_dataset_within_batch`: All samples in the same batch come from the same dataset.\n", + "- `small_threshold`: The threshold for small datasets. All small datasets in the same directory will be merged into one dataset.\n", + "- `drop_threshold`: The threshold for dropping the merged small dataset. If the number of examples in the merged small dataset is less than this threshold, it will be dropped.\n", + "\n", + "And the following extra arguments:\n", + "- `negatives_cross_device`: Share negatives across devices.\n", + "- `temperature`: Temperature used for similarity score.\n", + "- `fix_position_embedding`: Freeze the parameters of position embeddings.\n", + "- `sentence_pooling_method`: The pooling method. Available options: cls, mean, last_token. Default: cls.\n", + "- `normalize_embeddings`: Whether to normalize the embeddings.\n", + "- `sub_batch_size`: Sub batch size for training.\n", + "- `kd_loss_type`: The loss type for knowledge distillation. Available options: kl_div, m3_kd_loss. Default: kl_div." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] \n", + "W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] *****************************************\n", + "W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", + "W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] *****************************************\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", + " warnings.warn(\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. 
Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-12-23 06:27:31,423] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[2024-12-23 06:27:31,424] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n", + "[2024-12-23 06:27:40,529] [INFO] [comm.py:652:init_distributed] cdb=None\n", + "[2024-12-23 06:27:40,529] [INFO] [comm.py:652:init_distributed] cdb=None\n", + "[2024-12-23 06:27:40,529] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12/23/2024 06:27:40 - WARNING - FlagEmbedding.abc.finetune.embedder.AbsRunner - Process rank: 0, device: cuda:0, n_gpu: 1, distributed training: True, 16-bits training: True\n", + "12/23/2024 06:27:40 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner - Training/evaluation parameters AbsEmbedderTrainingArguments(\n", + "_n_gpu=1,\n", + "accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},\n", + "adafactor=False,\n", + "adam_beta1=0.9,\n", + "adam_beta2=0.999,\n", + "adam_epsilon=1e-08,\n", + "auto_find_batch_size=False,\n", + "batch_eval_metrics=False,\n", + "bf16=False,\n", + "bf16_full_eval=False,\n", + "data_seed=None,\n", + "dataloader_drop_last=True,\n", + "dataloader_num_workers=0,\n", + "dataloader_persistent_workers=False,\n", + "dataloader_pin_memory=True,\n", + "dataloader_prefetch_factor=None,\n", + "ddp_backend=None,\n", + "ddp_broadcast_buffers=None,\n", + "ddp_bucket_cap_mb=None,\n", + "ddp_find_unused_parameters=None,\n", + "ddp_timeout=1800,\n", + "debug=[],\n", + 
"deepspeed=config/ds_stage0.json,\n", + "disable_tqdm=False,\n", + "dispatch_batches=None,\n", + "do_eval=False,\n", + "do_predict=False,\n", + "do_train=False,\n", + "eval_accumulation_steps=None,\n", + "eval_delay=0,\n", + "eval_do_concat_batches=True,\n", + "eval_on_start=False,\n", + "eval_steps=None,\n", + "eval_strategy=IntervalStrategy.NO,\n", + "eval_use_gather_object=False,\n", + "evaluation_strategy=None,\n", + "fix_position_embedding=False,\n", + "fp16=True,\n", + "fp16_backend=auto,\n", + "fp16_full_eval=False,\n", + "fp16_opt_level=O1,\n", + "fsdp=[],\n", + "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},\n", + "fsdp_min_num_params=0,\n", + "fsdp_transformer_layer_cls_to_wrap=None,\n", + "full_determinism=False,\n", + "gradient_accumulation_steps=1,\n", + "gradient_checkpointing=True,\n", + "gradient_checkpointing_kwargs=None,\n", + "greater_is_better=None,\n", + "group_by_length=False,\n", + "half_precision_backend=auto,\n", + "hub_always_push=False,\n", + "hub_model_id=None,\n", + "hub_private_repo=False,\n", + "hub_strategy=HubStrategy.EVERY_SAVE,\n", + "hub_token=,\n", + "ignore_data_skip=False,\n", + "include_inputs_for_metrics=False,\n", + "include_num_input_tokens_seen=False,\n", + "include_tokens_per_second=False,\n", + "jit_mode_eval=False,\n", + "kd_loss_type=kl_div,\n", + "label_names=None,\n", + "label_smoothing_factor=0.0,\n", + "learning_rate=1e-05,\n", + "length_column_name=length,\n", + "load_best_model_at_end=False,\n", + "local_rank=0,\n", + "log_level=passive,\n", + "log_level_replica=warning,\n", + "log_on_each_node=True,\n", + "logging_dir=./test_encoder_only_base_bge-large-en-v1.5/runs/Dec23_06-27-30_job-40fb0ce3-8bfb-46ea-b409-0a2e2a1a3163-master-0,\n", + "logging_first_step=False,\n", + "logging_nan_inf_filter=True,\n", + "logging_steps=1.0,\n", + "logging_strategy=IntervalStrategy.STEPS,\n", + "lr_scheduler_kwargs={},\n", + "lr_scheduler_type=SchedulerType.LINEAR,\n", + 
"max_grad_norm=1.0,\n", + "max_steps=-1,\n", + "metric_for_best_model=None,\n", + "mp_parameters=,\n", + "neftune_noise_alpha=None,\n", + "negatives_cross_device=True,\n", + "no_cuda=False,\n", + "normalize_embeddings=True,\n", + "num_train_epochs=2.0,\n", + "optim=OptimizerNames.ADAMW_TORCH,\n", + "optim_args=None,\n", + "optim_target_modules=None,\n", + "output_dir=./test_encoder_only_base_bge-large-en-v1.5,\n", + "overwrite_output_dir=True,\n", + "past_index=-1,\n", + "per_device_eval_batch_size=8,\n", + "per_device_train_batch_size=2,\n", + "prediction_loss_only=False,\n", + "push_to_hub=False,\n", + "push_to_hub_model_id=None,\n", + "push_to_hub_organization=None,\n", + "push_to_hub_token=,\n", + "ray_scope=last,\n", + "remove_unused_columns=True,\n", + "report_to=[],\n", + "restore_callback_states_from_checkpoint=False,\n", + "resume_from_checkpoint=None,\n", + "run_name=./test_encoder_only_base_bge-large-en-v1.5,\n", + "save_on_each_node=False,\n", + "save_only_model=False,\n", + "save_safetensors=True,\n", + "save_steps=1000,\n", + "save_strategy=IntervalStrategy.STEPS,\n", + "save_total_limit=None,\n", + "seed=42,\n", + "sentence_pooling_method=cls,\n", + "skip_memory_metrics=True,\n", + "split_batches=None,\n", + "sub_batch_size=None,\n", + "temperature=0.02,\n", + "tf32=None,\n", + "torch_compile=False,\n", + "torch_compile_backend=None,\n", + "torch_compile_mode=None,\n", + "torch_empty_cache_steps=None,\n", + "torchdynamo=None,\n", + "tpu_metrics_debug=False,\n", + "tpu_num_cores=None,\n", + "use_cpu=False,\n", + "use_ipex=False,\n", + "use_legacy_prediction_loop=False,\n", + "use_mps_device=False,\n", + "warmup_ratio=0.1,\n", + "warmup_steps=0,\n", + "weight_decay=0.0,\n", + ")\n", + "12/23/2024 06:27:40 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner - Model parameters AbsEmbedderModelArguments(model_name_or_path='BAAI/bge-large-en-v1.5', config_name=None, tokenizer_name=None, cache_dir='./cache/model', trust_remote_code=False, token=None)\n", 
+ "12/23/2024 06:27:40 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner - Data parameters AbsEmbedderDataArguments(train_data=['./ft_data/training.json'], cache_path='./cache/data', train_group_size=8, query_max_len=512, passage_max_len=512, pad_to_multiple_of=8, max_example_num_per_dataset=100000000, query_instruction_for_retrieval='Represent this sentence for searching relevant passages: ', query_instruction_format='{}{}', knowledge_distillation=False, passage_instruction_for_retrieval=None, passage_instruction_format='{}{}', shuffle_ratio=0.0, same_dataset_within_batch=False, small_threshold=0, drop_threshold=0)\n", + "12/23/2024 06:27:40 - WARNING - FlagEmbedding.abc.finetune.embedder.AbsRunner - Process rank: 1, device: cuda:1, n_gpu: 1, distributed training: True, 16-bits training: True\n", + "12/23/2024 06:35:01 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.runner - Config: BertConfig {\n", + " \"_name_or_path\": \"BAAI/bge-large-en-v1.5\",\n", + " \"architectures\": [\n", + " \"BertModel\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"classifier_dropout\": null,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 1024,\n", + " \"id2label\": {\n", + " \"0\": \"LABEL_0\"\n", + " },\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 4096,\n", + " \"label2id\": {\n", + " \"LABEL_0\": 0\n", + " },\n", + " \"layer_norm_eps\": 1e-12,\n", + " \"max_position_embeddings\": 512,\n", + " \"model_type\": \"bert\",\n", + " \"num_attention_heads\": 16,\n", + " \"num_hidden_layers\": 24,\n", + " \"pad_token_id\": 0,\n", + " \"position_embedding_type\": \"absolute\",\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.44.2\",\n", + " \"type_vocab_size\": 2,\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 30522\n", + "}\n", + "\n", + "12/23/2024 06:35:01 - INFO - 
FlagEmbedding.abc.finetune.embedder.AbsDataset - loading data from ./ft_data/training.json ...\n", + "Generating train split: 6300 examples [00:00, 46043.95 examples/s]\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations\n", + " warnings.warn(\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations\n", + " warnings.warn(\n", + "12/23/2024 06:35:02 - WARNING - accelerate.utils.other - Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1734935704.354551] [job-40fb0ce3-8bfb-46ea-b409-0a2e2a1a3163-master-0:1362491:f] vfs_fuse.c:281 UCX ERROR inotify_add_watch(/tmp) failed: No space left on device\n", + "[1734935704.383634] [job-40fb0ce3-8bfb-46ea-b409-0a2e2a1a3163-master-0:1362492:f] vfs_fuse.c:281 UCX ERROR inotify_add_watch(/tmp) failed: No space left on device\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...\n", + "Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/fused_adam/build.ninja...\n", + "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
\n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n", + "Building extension module fused_adam...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading extension module fused_adam...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time to load fused_adam op: 1.1966907978057861 seconds\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading extension module fused_adam...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time to load fused_adam op: 1.2037739753723145 seconds\n", + "[2024-12-23 06:35:06,883] [WARNING] [lr_schedules.py:683:get_lr] Attempting to get learning rate from scheduler before it has started\n", + "[2024-12-23 06:35:06,888] [WARNING] [lr_schedules.py:683:get_lr] Attempting to get learning rate from scheduler before it has started\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", + " 0%| | 0/3150 [00:00`_ for the inference, evaluation, and fine-tuning of BGE series models. -Besides that, there are abundant resources of `tutorials `_ and `examples `_ for users to quickly get a hands-on experience. +Besides that, there are abundant resources of and for users to quickly get a hands-on experience. .. 
figure:: https://raw.githubusercontent.com/FlagOpen/FlagEmbedding/refs/heads/master/imgs/projects.png
    :width: 700
@@ -10,4 +10,8 @@ Besides that, there are abundant resources of `tutorials `_
-Our repository provides well-structured contents
\ No newline at end of file
+Our repository provides well-structured contents for information retrieval and RAG:
+
+- The core `APIs <../API>`_ for embedding models' inference, evaluation, and fine-tuning.
+- Hands-on `examples `_ for the three mentioned use cases.
+- Detailed `tutorials `_ covering topics in retrieval to help you learn from scratch.
\ No newline at end of file