From 4192036725028745c6712656a47651d16b69f896 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 May 2024 18:20:16 +0000 Subject: [PATCH] Release --- LICENSE | 201 ++ README.md | 497 ++++ example/7B.yaml | 30 + finetune/__init__.py | 0 finetune/args.py | 116 + finetune/checkpointing.py | 246 ++ finetune/data/__init__.py | 0 finetune/data/args.py | 61 + finetune/data/data_loader.py | 126 + finetune/data/dataset.py | 475 ++++ finetune/data/exceptions.py | 56 + finetune/data/tokenize.py | 355 +++ finetune/distributed.py | 59 + finetune/eval.py | 77 + finetune/loss.py | 16 + finetune/mixed_precision.py | 47 + finetune/monitoring/__init__.py | 0 finetune/monitoring/metrics_logger.py | 226 ++ finetune/monitoring/utils.py | 34 + finetune/utils.py | 83 + finetune/wrapped_model.py | 227 ++ model/__init__.py | 0 model/args.py | 39 + model/lora.py | 96 + model/moe.py | 31 + model/rope.py | 27 + model/transformer.py | 230 ++ pyproject.toml | 51 + requirements.dev.txt | 5 + requirements.txt | 11 + tests/__init__.py | 0 tests/fixtures | 1 + tests/test_checkpointer.py | 212 ++ tests/test_data.py | 729 ++++++ tests/test_mixed_precision.py | 136 ++ tests/test_model.py | 550 +++++ tests/test_train_loop.py | 93 + tests/test_utils.py | 227 ++ train.py | 323 +++ tutorials/mistral_finetune_7b.ipynb | 3135 +++++++++++++++++++++++++ utils/__init__.py | 0 utils/extend_model_vocab.py | 144 ++ utils/merge_lora.py | 90 + utils/reformat_data.py | 88 + utils/reformat_data_glaive.py | 127 + utils/validate_data.py | 366 +++ 46 files changed, 9643 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 example/7B.yaml create mode 100644 finetune/__init__.py create mode 100644 finetune/args.py create mode 100644 finetune/checkpointing.py create mode 100644 finetune/data/__init__.py create mode 100644 finetune/data/args.py create mode 100644 finetune/data/data_loader.py create mode 100644 finetune/data/dataset.py create mode 100644 finetune/data/exceptions.py create mode 100644 finetune/data/tokenize.py create mode 100644 finetune/distributed.py create mode 100644 finetune/eval.py create mode 100644 finetune/loss.py create mode 100644 finetune/mixed_precision.py create mode 100644 finetune/monitoring/__init__.py create mode 100644 finetune/monitoring/metrics_logger.py create mode 100644 finetune/monitoring/utils.py create mode 100644 finetune/utils.py create mode 100644 finetune/wrapped_model.py create mode 100644 model/__init__.py create mode 100644 model/args.py create mode 100644 model/lora.py create mode 100644 model/moe.py create mode 100644 model/rope.py create mode 100644 model/transformer.py create mode 100644 pyproject.toml create mode 100644 requirements.dev.txt create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 120000 tests/fixtures create mode 100644 tests/test_checkpointer.py create mode 100644 tests/test_data.py create mode 100644 tests/test_mixed_precision.py create mode 100644 tests/test_model.py create mode 100644 tests/test_train_loop.py create mode 100644 tests/test_utils.py create mode 100644 train.py create mode 100644 tutorials/mistral_finetune_7b.ipynb create mode 100644 utils/__init__.py create mode 100644 utils/extend_model_vocab.py create mode 100644 utils/merge_lora.py create mode 100755 utils/reformat_data.py create mode 100644 utils/reformat_data_glaive.py create mode 100644 utils/validate_data.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5bec4bd --- /dev/null +++ b/README.md @@ -0,0 +1,497 @@ +# Mistral-finetune + + + Open In Colab + + + +`mistral-finetune` is a light-weight codebase that enables memory-efficient and performant finetuning of Mistral's models. +It is based on [LoRA](https://arxiv.org/abs/2106.09685), a training paradigm where most weights are frozen and only 1-2% additional weights in the form of low-rank matrix perturbations are trained. + +For maximum efficiency it is recommended to use a A100 or H100 GPU. The codebase is optimized +for multi-GPU-single-node training setups, but for smaller models, such as the 7B a single GPU suffices. + +> **Note** +> +> - The goal of this repository is to provide a simple, guided entrypoint to finetune Mistral models. +> As such, it is fairly opinionated (especially around data formatting) and does not aim at being exhaustive +> across multiple model architecture or hardware types. +> For more generic approaches, you can check out some other great projects like +> [torchtune](https://pytorch.org/torchtune/stable/overview.html). + +## Installation + +To get started with Mistral LoRA fine-tuning, follow these steps: + +1. Clone this repository: +``` +cd $HOME && git clone https://github.com/mistralai/mistral-finetune.git +``` + +2. Install all required dependencies: +``` +cd mistral-finetune +pip install -r requirements.txt +``` + +## Model download + +We recommend fine-tuning one of the official Mistral models which you can download here: + +| Model | Link | Checksum | +|----------------|---------------------------------------------------------------------------------------------------------|-----------------------------------| +| 7B Base V3 | [7B Base](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar) | `0663b293810d7571dad25dae2f2a5806`| +| 7B Instruct v3 | [7B Instruct v3](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar) | `80b71fcb6416085bcb4efad86dfb4d52`| +| 8x7B Base V1 | [8x7B Base](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | (HF link) | +| 8x7B Instruct V1 | [8x7B Instruct](https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar) | `8e2d3930145dc43d3084396f49d38a3f` | +| 8x22 Instruct V3 | [8x22 Instruct](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-Instruct-v0.3.tar) | `471a02a6902706a2f1e44a693813855b`| +| 8x22B Base V3 | [8x22B Base](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-v0.3.tar) | `a2fa75117174f87d1197e3a4eb50371a`| + +**Important Notice**: For 8x7B Base V1 and 8x7B Instruct V1, it is necessary to use our v3 tokenizer and extend the vocabulary size to 32768 prior to fine-tuning. For detailed instructions on this process, please refer to the ["Model extension"](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#model-extension) section. + +E.g., to download the 7B-base model you can run the following command: +```sh +mkdir -p ~/${HOME}/mistral_models +cd ${HOME} && wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar +tar -xf mistral-7B-v0.3.tar -C mistral_models +``` + +Make sure to modify your training script and add the path to the downloaded +folder as `model_id_or_path`. + +E.g., modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/mistral_models/7B`: + +``` +model_id_or_path: "/Users/johndoe/mistral_models/7B" +``` + +## Prepare dataset + +To ensure effective training, `mistral-finetune` has strict +requirements for how the training data has to be formatted. + +All data files must be stored in jsonl format files. + +You can build two types of data files: + +### _Pretrain_: + +Pretrain data correpsodns to plain text data stored in the `"text"` key. E.g: + +```jsonl +{"text": "Text contained in document n°1"} +{"text": "Text contained in document n°2"} +``` + +### _Instruct_: + +Currently two different types of instruction following data are supported: + +- _Instruct_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"content"` and `"role"` keys. `"role"` is a string being one of "user", "assistant" or "system_prompt". The loss will only be computed if "role" == "assistant". E.g.: + +```jsonl +{ + "messages": [ + { + "role": "user", + "content": "User interaction n°1 contained in document n°1" + }, + { + "role": "assistant", + "content": "Bot interaction n°1 contained in document n°1" + }, + { + "role": "user", + "content": "User interaction n°2 contained in document n°1" + }, + { + "role": "assistant", + "content": "Bot interaction n°2 contained in document n°1" + } + ] +} +{ + "messages": [ + { + "role": "user", + "content": "User interaction n°1 contained in document n°2" + }, + { + "role": "assistant", + "content": "Bot interaction n°1 contained in document n°2" + }, + { + "role": "user", + "content": "User interaction n°2 contained in document n°2" + }, + { + "role": "assistant", + "content": "Bot interaction n°2 contained in document n°2", + "weight": 0, # don't train on n°2 + }, + { + "role": "user", + "content": "User interaction n°3 contained in document n°2" + }, + { + "role": "assistant", + "content": "Bot interaction n°3 contained in document n°2" + } + ] +} +``` + +- _Function calling_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"role"` and `"role"` keys. `"role"` is a string being one of "user", "assistant", "system_prompt", or "tool". The loss will only be computed if "role" == "assistant". + +**Note**: In function calling the `"id"` of `"tool_calls"` and the `"tool_call_id"` are randomly generated strings of exactly 9 chars. We recommend to generate this automatically +in a data preparation script as is done [here](https://github.com/mistralai/mistral-finetune/blob/a612d665897a28a3a7c44c98f88e1f7775f81a1d/utils/reformat_data_glaive.py#L74). + +E.g.: + +```jsonl +{ + "messages": [ + { + "role": "system", + "content": "You are an helpful assistant who has access to the following functions to help the user, you can use the functions if needed" + }, + { + "role": "user", + "content": "Can you help me generate an anagram of the word \"listen\"?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "TX92Jm8Zi", + "type": "function", + "function": { + "name": "generate_anagram", + "arguments": "{\"word\": \"listen\"}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"anagram\": \"silent\"}", + "tool_call_id": "TX92Jm8Zi" + }, + { + "role": "assistant", + "content": "The anagram of the word \"listen\" is \"silent\"." + }, + { + "role": "user", + "content": "That's amazing! Can you generate an anagram for the word \"race\"?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "3XhQnxLsT", + "type": "function", + "function": { + "name": "generate_anagram", + "arguments": "{\"word\": \"race\"}" + } + } + ] + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "generate_anagram", + "description": "Generate an anagram of a given word", + "parameters": { + "type": "object", + "properties": { + "word": { + "type": "string", + "description": "The word to generate an anagram of" + } + }, + "required": [ + "word" + ] + } + } + } + ] +} +``` + +## Verify dataset + +Before starting a training run you should verify that your dataset is correctly formatted and get an +estimation of the training time. You can do so by using the [./utils/validate_data](https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py) script. + +Note that this step is crucial to ensure that the data is correctly formatted. + +### Instruction following + +Let's go over a simple example to train a model in instruction following: + +- 1. **Load a chunk of [Ultachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)** + +Create the data folder and navigate to the folder. +```sh +cd $HOME && mkdir -p data && cd $HOME/data +``` + +Load the data into a Pandas Dataframe. + +**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`). + +```py +import pandas as pd + +df = pd.read_parquet('https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/data/test_gen-00000-of-00001-3d4cd8309148a71f.parquet') +``` +- 2. Split into train and eval + +```py +df_train=df.sample(frac=0.95,random_state=200) +df_eval=df.drop(df_train.index) +``` + +- 3. Save data to jsonl + +```py +df_train.to_json("ultrachat_chunk_train.jsonl", orient="records", lines=True) +df_eval.to_json("ultrachat_chunk_eval.jsonl", orient="records", lines=True) +``` + +- 4. Modify your training yaml to include the ultrachat dataset and verify the yaml + +Modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/data/ultrachat_chunk_train.jsonl` as well as a dataset mixing weight for training and `$HOME/data/ultrachat_chunk_eval.jsonl` for eval, *e.g.* + +``` +data: + instruct_data: "/Users/johndoe/data/ultrachat_chunk_train.jsonl" + eval_instruct_data: "/Users/johndoe/data/ultrachat_chunk_eval.jsonl" +``` + +Now you can verify your training yaml to make sure the data is correctly formatted and to get an estimate of your training time. + +``` +cd $HOME/mistral-finetune +python -m utils.validate_data --train_yaml example/7B.yaml +``` + +Upon completion you should see an error report with many of the following errors: + +``` +The data in line 1412 of dataset /mnt/vast/home/patrick/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user +The data in line 1413 of dataset /mnt/vast/home/patrick/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user +The data in line 1414 of dataset /mnt/vast/home/patrick/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user +The data in line 1415 of dataset /mnt/vast/home/patrick/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user +``` + +Many conversations seem to end with the 'user' role which is unnecessary as we only train on 'assistant' messages and would thus would unnecessarily process data. + +You can make use of [./utils/reformat_data.py](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data.py) to correct the data: + +``` +cd $HOME/mistral-finetune +python -m utils.reformat_data $HOME/data/ultrachat_chunk_train.jsonl +python -m utils.reformat_data $HOME/data/ultrachat_chunk_eval.jsonl +``` + +You should see that a couple of samples will be skipped. + +- 5. Potentially change number of training steps + +Upon correction of the dataset, run the script again + +``` +cd $HOME/mistral-finetune +python -m utils.validate_data --train_yaml example/7B.yaml +``` + +You should get a summary of the data input and training parameters: + +``` +Train States + -------------------- +{ + "expected": { + "eta": "00:52:44", + "data_tokens": 25169147, + "train_tokens": 131072000, + "epochs": "5.21", + "max_steps": 500, + "data_tokens_per_dataset": { + "/mnt/vast/home/patrick/data/ultrachat_chunk_train.jsonl": "25169147.0" + }, + "train_tokens_per_dataset": { + "/mnt/vast/home/patrick/data/ultrachat_chunk_train.jsonl": "131072000.0" + }, + "epochs_per_dataset": { + "/mnt/vast/home/patrick/data/ultrachat_chunk_train.jsonl": "5.2" + } + }, +} +``` + +Having `max_steps` set to 500 would lead to iterating through the dataset roughly 5 times which is reasonable, but might +be a bit too much. A recommended setting is shown below which would only take 30min on a 8xH100 cluster. + +### Function calling + +Next let's go over a more advanced use case to fine-tune a model on function calling. +Function calling requires the data to be in the format as [explained above](#instruct). Let's go over an example. + +- 1. **Load a chat-formatted version of the [Glaive function calling dataset](https://huggingface.co/datasets/Locutusque/function-calling-chatml)** + +Create the data folder and navigate to the folder. +```sh +cd $HOME && mkdir -p data && cd $HOME/data +``` + +Load the data into a Pandas Dataframe. + +**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`). + +```py +import pandas as pd + +df = pd.read_parquet('https://huggingface.co/datasets/Locutusque/function-calling-chatml/resolve/main/data/train-00000-of-00001-f0b56c6983b4a78f.parquet') +``` +- 2. Split into train and eval + +```py +df_train=df.sample(frac=0.95,random_state=200) +df_eval=df.drop(df_train.index) +``` + +- 3. Save data to jsonl + +```py +df_train.to_json("glaive_train.jsonl", orient="records", lines=True) +df_eval.to_json("glaive_eval.jsonl", orient="records", lines=True) +``` + +- 4. Reformat dataset + +As one can see the dataset does not follow the required function calling format, so it will need to be reformatted. Among other things `"from"` should be renamed to `"user"` and superfluous `"\n"` characters should be removed. +For this dataset you can make use of [`./utils/reformat_data_glaive.py`](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data_glaive.py): + +``` +cd $HOME/mistral-finetune +python -m utils.reformat_data_glaive $HOME/data/glaive_train.jsonl +python -m utils.reformat_data_glaive $HOME/data/glaive_eval.jsonl +``` + +Running this command will make sure that most samples are in the correct format. + +**Note**: It is impossible to write reformatting scripts that work for all kinds of datasets. +If you have datasets that don't yet follow the required format above, you will most probably have to +create a reformatting script yourself (mistral-chat or chat-gpt is your best friend here!). + +- 5. Validate dataset + +You can now validate the dataset by setting `data.instruct_data` and `data.eval_instruct_data` to +`$HOME/data/glaive_train.jsonl` and `$HOME/data/glaive_eval.jsonl` in `example/7B.yaml` respectively. + +The reformatted datasets still has some errors which can be removed with `--create_corrected`. For this, make sure to add +`--create_corrected` as follows: + +``` +cd $HOME/mistral-finetune +python -m utils.validate_data --train_yaml example/7B.yaml --create_corrected +``` + +Running this command will show a couple of errors and save two new datasets `$HOME/data/glaive_train.jsonl.corrected` and `$HOME/data/glaive_eval.jsonl.corrected`. Make sure to use these two dataset in `example/7B.yaml` and run the command again. Now the dataset should be correctly formatted! + + +## Start training + +Having followed the [dataset verification section](#verify-dataset), we can now start training. +For faster training, we recommend setting max_steps to only 300. Make sure to define `run_dir` to your experiment folder and optionally set `wandb_project` to a Weights & Biases project for logging`, *e.g.*: +``` +max_steps: 300 +run_dir: "/Users/johndoe/ultra_chat_test" +wandb.project: ultra_chat +``` + +Optionally you can also set `wandb` + +Save the training configuration and start training! Make sure to set `--nproc-per-node` to the number of available GPUs. + +``` +cd $HOME/mistral-finetune +torchrun --nproc-per-node 8 --master_port $RANDOM -m train example/7B.yaml +``` + +Training on ultra-chat should take around 30min on a 8xH100 node and the resulting weights should give an MT Bench score around 6.3. + +Training on glaive should take around 1h on a 8xH100 node and the resulting weights should work nicely for function calling. + +## Customizing training configuration + +The example `mistral-finetune/examples/7B` defines reasonable parameters for learning rate, weight decay, etc... but you are advised to +customize these settings for your use case. + +Generally, a training configuration should fill the following parameters: + +- `model_id_or_path` defines the model to start training from. This can be a path to a pre-trained model or a local model directory. +- `run_dir` defines the directory where training checkpoints and metrics are stored. +- `seq_len` defines the sequence length for training. This is the maximum length of input sequences the model will process. Samples are packed to reach a length of `seq_len` for maximum training efficiency. +- `batch_size` defines the number of training examples used per GPU. **Note**: The overall effective batch_size (in tokens) across all GPUs equals `num_gpus` x `batch_size` x `seq_len`. +- `max_steps` defines the maximum number of training steps. This is the total number of iterations the training process will run. It can be adjusted based on the specific needs of your training scenario. Total number of tokens seen during training is `max_steps` x `num_gpus` x `batch_size` x `seq_len`. +- `optim.lr` defines the learning rate. This is the initial learning rate for the optimizer. +- `optim.weight_decay` defines weight decay. Weight decay is a regularization technique used to prevent overfitting by penalizing large weights. We recommend leaving it at 0.1. +- `optim.pct_start` defines the percentage of the total training steps used for the learning rate warm-up phase before it starts to decrease. It corresponds to pct_start of PyTorch's OneCycleLR. +- `lora.rank` defines the size of the LoRA (Low-Rank Adaptation) adapters. We recommend 64 or less, which adjusts the rank of the low-rank decomposition used in LoRA. +- `seed` defines the random seed for initialization and data shuffling/sampling. Setting a seed ensures reproducibility of results. +- `log_freq` defines the logging frequency. This specifies how often (in steps) to log training metrics. +- `data.instruct_data` is the path to the instruction data used for training. This field has to be filled with one or multiple data sources in the format as explained above. Each data source should either be a path to jsonl file of a path to a directory containing jsonl files followed by a weighting to define the importance of this dataset: `:`. E.g.: `data.instruct_data: "/path/to/data1.jsonl:5.,/path/to/data2.jsonl:1.,/path/to/dir_of_jsonls:1."` +- `data.data` is an optional path to additional pretraining data in the format as explained above. Note that this field can be left blank. +- `data.eval_instruct_data` is an optional path to evaluation instruction data to run cross-validation at every `eval_freq` steps. Cross-validation metrics are displayed as `loss` and `perplexity`. +- `eval_freq` defines how often (in steps) to evaluate the model. This specifies the interval at which the model is evaluated on the validation set. +- `no_eval` is a flag to enable or disable intermediate evaluation. Setting it to False enables periodic evaluation during training. +- `ckpt_freq` defines how often (in steps) to save checkpoints. This specifies the interval at which the model's state is saved. +- `ckpt_only_lora` defines whether to only save the trained LoRA checkpoints or whether the trained LoRA should directly be merged into the base model and saved. **Note**: When setting `ckpt_only_lora=False` make sure that you have enough CPU and GPU memory to save the full model on a single process (this is usually only possible for the 7B model). +- `wandb.key` is used to pass your Weights & Biases (wandb) API key for logging. This allows you to log training metrics to the wandb dashboard. +- `wandb.project` defines the wandb project name. This is where the training run will be logged in the wandb interface. + +## Inference + +Once your model is trained, you should try it out in inference. We recommend using [mistral-inference](https://github.com/mistralai/mistral-inference). + +Make sure to have `mistral_inference` correctly installed: +``` +pip install mistral_inference +``` + +Assuming your `lora.safetensors` is saved under `$HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors`, you can chat with the model using `mistral_inference`, *e.g.*: + +```sh +mistral-chat /mnt/slow/runs/patrick/mistral-finetune/7B/ --max_tokens 256 --temperature 1.0 --instruct --lora_path $HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors +``` + +## Model extension + +**Important**: Note that one can only fine-tune mistral models that are compatible with the v3 tokenizer which entails that the models have a vocabulary size of 32768 - not 32000. One can however easily extend older version of vocabulary size 32000 to have a vocabulary size of 32768 by using: +``` +python -m utils.extend_model_vocab --original_model_ckpt /folder/to/old/model --extended_model_ckpt /folder/to/extended/model +``` + +Once the extension has worked, one can fine-tune using the newly created model checkpoint in `/folder/to/extended/model`. + +## FAQ: + +> - What's the best practice of fine-tuning MoEs? + +We see a higher degree of performance variance in when fine-tuning MoE models. It's not unusual to find that fine-tuning MoEs models with different seeds can lead to a high variance in performance. We did not observe such a high variance with dense models. Therefore, we suggest running multiple instances of the same fine-tuning process on MoEs models and selecting the one that performs best. + +> - How can I determine the number of tokens used during the model training process? + +You can use the following script to find out: https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py. This script accepts a .yaml training file as input and returns the number of tokens the model is being trained on. + +> - What should I do if I encounter a CUDA out-of-memory error? + +One possible solution is to reduce the batch size per GPU. The batch size is equal to `seq_len` x `batch_size`. Try setting `batch_size` to 1 and reduce `seq_len`. You can define the `batch_size` and `seq_len` in the .yaml file. diff --git a/example/7B.yaml b/example/7B.yaml new file mode 100644 index 0000000..3dd4c32 --- /dev/null +++ b/example/7B.yaml @@ -0,0 +1,30 @@ +# data +data: + instruct_data: "" # Fill + data: "" # Optionally fill with pretraining data + eval_instruct_data: "" # Optionally fill + +# model +model_id_or_path: "" # Change to downloaded path +lora: + rank: 64 + +# optim +seq_len: 32768 +batch_size: 1 +max_steps: 300 +optim: + lr: 6.e-5 + weight_decay: 0.1 + pct_start: 0.05 + +# other +seed: 0 +log_freq: 1 +eval_freq: 100 +no_eval: False +ckpt_freq: 100 + +ckpt_only_lora: True # save only trained LoRA adapters. Set to `False` to merge LoRA adapter into the base model and save full fine-tuned model + +run_dir: "" # Fill diff --git a/finetune/__init__.py b/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/args.py b/finetune/args.py new file mode 100644 index 0000000..5290917 --- /dev/null +++ b/finetune/args.py @@ -0,0 +1,116 @@ +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +from simple_parsing.helpers import Serializable + +from model.args import LoraArgs + +from .data.args import DataArgs + + +@dataclass +class OptimArgs(Serializable): + lr: float = 3e-4 + weight_decay: float = 0.1 + pct_start: float = 0.3 + + +@dataclass +class WandbArgs(Serializable): + project: Optional[str] = None # Fill this argument to use wandb. + offline: bool = False + key: Optional[str] = None + run_name: Optional[str] = None + + def __post_init__(self) -> None: + if self.project is not None: + try: + import wandb # noqa: F401 + except ImportError: + raise ImportError("`wandb` not installed. Either make sure `wandb` is installed or set `wandb:project` to None.") + + if len(self.project) == 0: + raise ValueError("`wandb.project` must not be an empty string.") + +@dataclass +class MLFlowArgs(Serializable): + tracking_uri: Optional[str] = None + experiment_name: Optional[str] = None + + def __post_init__(self) -> None: + if self.tracking_uri is not None: + try: + import mlflow # noqa: F401 + except ImportError: + raise ImportError("`mlflow` not installed. Either make sure `mlflow` is installed or set `mlflow.tracking_uri` to None.") + + if self.experiment_name is None: + raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.") + + + +@dataclass +class TrainArgs(Serializable): + data: DataArgs + + # if specified, instruct_tokenizer and model will be loaded + model_id_or_path: str # Path to the directory containing the initial model or model id: "mistral-small" + + run_dir: str # Path to the directory where everything will be saved. It needs to be empty. + # Name of the wandb run, if None it will be set to the name of the run_dir. + + optim: OptimArgs = field(default_factory=OptimArgs) + seed: int = 0 + # Number of steps to accumulate gradients before calling doing an optimizer step. + num_microbatches: int = 1 + + seq_len: int = 2048 # Number of tokens per batch per device. + batch_size: int = 1 + max_norm: float = 1.0 # Gradient clipping. + max_steps: int = 100 # Number of training steps. + log_freq: int = 1 # Number of steps between each logging. + + # Number of steps between each checkpoint saving. If inferior to 1, only the last checkpoint will be saved. + ckpt_freq: int = 0 + ckpt_only_lora: bool = True + # If True, no checkpoint will be saved. This is useful for development. + no_ckpt: bool = False + num_ckpt_keep: Optional[int] = 3 + eval_freq: int = 0 + no_eval: bool = True + + # Efficiency + # Determines whether gradient checkpointing should be utilized or not during the training process. Gradient checkpointing can be beneficial in reducing memory usage at the cost of slightly longer training times. + checkpoint: bool = True + + world_size: Optional[int] = field(init=False, default=None) + + # logging + wandb: WandbArgs = field(default_factory=WandbArgs) + mlflow: MLFlowArgs = field(default_factory=MLFlowArgs) + + # LoRA + lora: Optional[LoraArgs] = field(default_factory=LoraArgs) + + def __post_init__(self) -> None: + assert getattr(self, "world_size", None) is None + self.world_size = int(os.environ.get("WORLD_SIZE", -1)) + + if self.wandb.offline: + command = f"cd {self.run_dir}; wandb sync --sync-all" + logging.info(f"to sync wandb offline, run: {command}") + + assert self.num_microbatches >= 1 + + assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1 + + if self.model_id_or_path is not None: + Path(self.model_id_or_path).exists() + + if not self.ckpt_only_lora: + logging.warning( + "You are have disabled `ckpt_only_lora` and are thus merging the trained LoRA checkpoint into the base model upon checkpointing. This might lead to OOM erros - make sure you have enough CPU and GPU memory." + ) diff --git a/finetune/checkpointing.py b/finetune/checkpointing.py new file mode 100644 index 0000000..6e118e2 --- /dev/null +++ b/finetune/checkpointing.py @@ -0,0 +1,246 @@ +import json +import logging +import shutil +from pathlib import Path +from typing import Dict, List, Optional, Union + +import safetensors.torch +import torch +from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase +from torch.distributed import barrier +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + +from model.transformer import LoRALinear + +from .distributed import get_rank, get_world_size +from .utils import TrainState + +logger = logging.getLogger("checkpointing") + + +def main_logger_info(message: str) -> None: + if get_rank() == 0: + logger.info(message) + + +class Checkpointer: + """A class to save PyTorch model and optimizer states""" + + def __init__( + self, + model: FullyShardedDataParallel, + state: TrainState, + run_dir: Union[Path, str], + optimizer: Optional[torch.optim.Optimizer] = None, + num_ckpt_keep: Optional[int] = None, + ): + self.model = model + self.optimizer = optimizer + self.state = state + self.run_dir = Path(run_dir) + self.rank = get_rank() + self.num_ckpt_keep = num_ckpt_keep + + @property + def ckpt_dir(self) -> Path: + return self.run_dir / "checkpoints" + + @property + def dst_dir(self) -> Path: + return self.ckpt_dir / f"checkpoint_{self.state.step:06d}" / "consolidated" + + @staticmethod + def consolidated_path( + ckpt_dir: Path, use_safetensors: bool, save_only_lora: Optional[bool] = False + ) -> Path: + suffix = "safetensors" if use_safetensors else "00.pth" + prefix = "lora" if save_only_lora else "consolidated" + + return ckpt_dir / f"{prefix}.{suffix}" + + @staticmethod + def _tmp(ckpt_dir: Path) -> Path: + return ckpt_dir.with_name(f"tmp.{ckpt_dir.name}") + + def write_params_info(self, tmp_dst: Path): + params_path = tmp_dst / "params.json" + with open(params_path, "w") as f: + model_args = self.model.args.to_dict() + + f.write(json.dumps(model_args, indent=4)) + + def delete_old_ckpts(self) -> List[Path]: + all_saved_ckpts = [d for d in self.ckpt_dir.iterdir() if d.is_dir()] + + # Sort directories by creation time (oldest to newest) + all_saved_ckpts.sort(key=lambda x: x.stat().st_ctime, reverse=True) + + ckpts_to_delete = all_saved_ckpts[self.num_ckpt_keep :] + + for ckpt_to_delete in ckpts_to_delete: + try: + shutil.rmtree(ckpt_to_delete) + main_logger_info(f"Deleted ckpt: {ckpt_to_delete}") + except OSError as e: + main_logger_info(f"Error deleting directory {ckpt_to_delete}: {e}") + + return ckpts_to_delete + + @staticmethod + def get_lora_states(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: v for k, v in state_dict.items() if "lora" in k} + + @staticmethod + def get_non_lora_states( + state_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + return { + k: v + for k, v in state_dict.items() + if not any(l_key in k for l_key in ["lora", "frozen"]) + } + + @torch.no_grad() + def retrieve_save_states( + self, save_only_lora: bool, save_dtype: torch.dtype + ) -> Dict[str, torch.Tensor]: + if save_only_lora: + assert ( + self.model.args.lora.enable + ), "Cannot save LoRA checkpoint as LoRA training is not enabled." + + # remove all potential hooks + for module in self.model.modules(): + if isinstance(module, LoRALinear) and hasattr(module, "_merge_lora_handle"): + module._merge_lora_handle.remove() # type: ignore + + # merge weights if we don't just save LoRA + if not save_only_lora: + + def merge_lora( + m: torch.nn.Module, + destination: Dict[str, torch.Tensor], + prefix: str, + *args, + ): + weight = m.merge_weight() # type: ignore + destination[prefix + "weight"] = weight + + for module in self.model.modules(): + if isinstance(module, LoRALinear): + module._merge_lora_handle = module._register_state_dict_hook( + merge_lora + ) + + offload_to_cpu = get_world_size() > 1 + if save_only_lora: + + def is_trainable_fsdp( + module: Union[torch.nn.Module, FullyShardedDataParallel] + ): + is_fsdp = isinstance(module, FullyShardedDataParallel) + all_params_have_grads = is_fsdp and all( + p.requires_grad is True for p in module.parameters() + ) + + # need to make sure only lowest fsdp wrap is used + is_leaf_node = is_fsdp and len(list(module.module.children())) == 0 # type: ignore + + return is_fsdp and all_params_have_grads and is_leaf_node + + # extract all modules with only trainable weights + modules = { + k: m for k, m in self.model.named_modules() if is_trainable_fsdp(m) + } + + states = {} + for key, module in modules.items(): + assert isinstance( + module, FullyShardedDataParallel + ), "`module` should be an instance of `FullyShardedDataParallel`" + parent_prefix = key.replace("_fsdp_wrapped_module.", "").replace( + "_checkpoint_wrapped_module.", "" + ) + with module.summon_full_params( + module, writeback=True, offload_to_cpu=offload_to_cpu + ): + states.update( + { + f"{parent_prefix}.{k}": v.to(dtype=save_dtype) + for k, v in module.state_dict().items() + } + ) + else: + # make sure you have enough CPU RAM available to save the full model + assert isinstance( + self.model, FullyShardedDataParallel + ), "`self.model` should be an instance of `FullyShardedDataParallel`" + with self.model.summon_full_params( + self.model, writeback=True, offload_to_cpu=offload_to_cpu + ): + states = self.get_non_lora_states(self.model.state_dict()) + states = {k: v.to(dtype=save_dtype) for k, v in states.items()} + + states = dict(sorted(states.items())) + return states + + @staticmethod + def save_tokenizer(instruct_tokenizer: InstructTokenizerBase, tmp_dst: Path): + serialized_spm = instruct_tokenizer.tokenizer._model.serialized_model_proto() # type: ignore + + tokenizer_path = tmp_dst / "tokenizer.model.v3" + + with open(tokenizer_path, "wb") as f: + f.write(serialized_spm) + + @torch.no_grad() + def save_checkpoint( + self, + save_only_lora: bool, + dtype: torch.dtype = torch.float16, + instruct_tokenizer: Optional[InstructTokenizerBase] = None, + ): + tmp_dst = self._tmp(self.dst_dir) + main_logger_info( + f"Dumping checkpoint in {self.dst_dir} using tmp name: {tmp_dst.name}" + ) + + assert not self.dst_dir.exists(), f"dst exists {self.dst_dir}" + tmp_dst.mkdir(parents=True, exist_ok=True) + + states: Dict[str, torch.Tensor] = self.retrieve_save_states( + save_only_lora, dtype + ) + + barrier() + + if self.rank == 0: + # save checkpoint in tmp path + safetensors.torch.save_file( + states, + self.consolidated_path( + tmp_dst, use_safetensors=True, save_only_lora=save_only_lora + ), # always use safetensors for checkpointing + ) + + self.write_params_info(tmp_dst) + + # save tokenizer + if instruct_tokenizer is not None: + self.save_tokenizer(instruct_tokenizer, tmp_dst) + + assert not self.dst_dir.exists(), f"should not happen! {self.dst_dir}" + tmp_dst.rename(self.dst_dir) + + logger.info( + f"Done dumping checkpoint in {self.dst_dir} for step: {self.state.step}" + ) + + # delete last n checkpoints + if self.num_ckpt_keep is not None: + ckpts_to_delete = self.delete_old_ckpts() + logger.info( + f"Done deleting checkpoints {', '.join([str(c) for c in ckpts_to_delete])}" + ) + + main_logger_info("Done!") diff --git a/finetune/data/__init__.py b/finetune/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/data/args.py b/finetune/data/args.py new file mode 100644 index 0000000..dd05cc3 --- /dev/null +++ b/finetune/data/args.py @@ -0,0 +1,61 @@ +import logging +from dataclasses import dataclass, field + +from simple_parsing.helpers import Serializable + +logger = logging.getLogger("data") + + +@dataclass() +class InstructArgs(Serializable): + shuffle: bool = True + + # For function calling training examples only the last tool call + # of the assistant message can be used for training. Therefore, + # we chunk longer function calling conversations into multiple + # training samples to not loose any data. E.g.: + # [[ + # UserMessage_1, AssisantToolCallMessage_1, ToolMessage_1, AssisantMessage_1 + # UserMessage_2, AssisantToolCallMessage_2, ToolMessage_2, AssisantMessage_2 + # ]] + # => is chunked into two training samples: + # [[ + # UserMessage_1, AssisantToolCallMessage_1, ToolMessage_1, AssisantMessage_1 + # ], + # [ + # UserMessage_1, AssisantToolCallMessage_1, ToolMessage_1, AssisantMessage_1 + # UserMessage_2, AssisantToolCallMessage_2, ToolMessage_2, AssisantMessage_2 + # ]] + # NOTE: Only if your data is already pre-chunked should this argument be set to False + dynamic_chunk_fn_call: bool = True + + +@dataclass() +class DataArgs(Serializable): + # The data arguments `data` and `instruct_data` are a string in the format + # "data_source_dir_1:weight_1,data_source_dir_2:weight_2,...". The weight + # will be used to sample the data sources. If the sum of the weights is + # not 1 when concatenating the two arguments `data` and `instruct_data`, + # it will be normalized. The data sources folders must contain jsonl files. + # If the value is an empty string, no data will be used for the corresponding + # data type. + data: str = ( + "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "text" key. See Readme for more details. Can be left empty. + ) + shuffle: bool = False + instruct_data: str = ( + "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "interactions" key. See Readme for more details. Can be left empty. + ) + eval_instruct_data: str = ( + "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "interactions" key. See Readme for more details. Can be left empty. + ) + instruct: InstructArgs = field(default_factory=InstructArgs) + + def __post_init__(self) -> None: + if ( + self.instruct.shuffle is False + and self.instruct.dynamic_chunk_fn_call is True + ): + raise ValueError( + "Make sure to either enable `data.instruct.shuffle=True` or `data.instruct.dynamic_chunk_fn_call=False`. Dynamic chunking is only possible if data is loaded and shuffled before training." + ) diff --git a/finetune/data/data_loader.py b/finetune/data/data_loader.py new file mode 100644 index 0000000..0c5be85 --- /dev/null +++ b/finetune/data/data_loader.py @@ -0,0 +1,126 @@ +import dataclasses +from typing import Any, Iterator, List, Optional + +import numpy as np +from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase + +from .args import DataArgs +from .dataset import build_dataset + + +@dataclasses.dataclass +class Batch: + x: np.ndarray + y: np.ndarray + sizes: List[int] + y_mask: Optional[np.ndarray] = None + is_pad_only: bool = False + + def __post_init__(self): + assert self.x.ndim == 1 + assert self.x.shape == self.y.shape + assert self.x.dtype == np.int64 + assert self.y.dtype == np.int64 + assert isinstance(self.sizes, list) + assert sum(self.sizes) == self.x.size == self.y.size + + if self.y_mask is not None: + assert self.y_mask.size == self.y.size, (self.y_mask.shape, self.y.shape) + assert self.y_mask.dtype == bool + assert sum(self.sizes) == self.y_mask.size + assert not self.y_mask.all() + assert self.y_mask.any() + + if self.is_pad_only: + assert np.sum(np.abs(self.y)) == 0 + assert np.sum(np.abs(self.x)) == 0 + assert self.y_mask is None + # create all 0's mask for pad samples + self.y_mask = np.zeros_like(self.x) + + + + +@dataclasses.dataclass +class BatchList: + x: List[List[int]] = dataclasses.field(default_factory=list) + y: List[List[int]] = dataclasses.field(default_factory=list) + sizes: List[List[int]] = dataclasses.field(default_factory=list) + y_mask: List[Optional[List[int]]] = dataclasses.field(default_factory=list) + + def __post_init__(self): + assert self.x == [], "`BatchList` has to be empty at init." + assert self.y == [], "`BatchList` has to be empty at init." + assert self.sizes == [], "`BatchList` has to be empty at init." + assert self.y_mask == [], "`BatchList` has to be empty at init." + + def __len__(self) -> int: + return len(self.x) + + def add(self, x: List[int], y: List[int], sizes: List[int], y_mask: Optional[List[int]] = None): + self.x.append(x) + self.y.append(y) + self.sizes.append(sizes) + self.y_mask.append(y_mask) + + def empty(self): + self.x = [] + self.y = [] + self.sizes = [] + self.y_mask = [] + + @staticmethod + def flatten_to_numpy(list_of_lists: List[List[Any]], dtype: np.dtype) -> np.array: + return np.array([el for sublist in list_of_lists for el in sublist], dtype=dtype) + + def create_batch(self) -> Batch: + x_np: np.array = self.flatten_to_numpy(self.x, dtype=np.int64) + y_np: np.array = self.flatten_to_numpy(self.y, dtype=np.int64) + sizes = sum(self.sizes, []) # noqa + + y_mask_np: Optional[np.array] = self.flatten_to_numpy(self.y_mask, dtype=bool) + y_mask_np = None if y_mask_np.all() else y_mask_np + + return Batch(x_np, y_np, sizes, y_mask_np) + + + + +def build_data_loader( + instruct_tokenizer: InstructTokenizerBase, + args: DataArgs, + batch_size: int, + seq_len: int, + seed: Optional[int], + rank: int, + world_size: int, + is_eval: bool, +) -> Iterator[Batch]: + pretrain_data = args.data if not is_eval else "" + instruct_data = args.instruct_data if not is_eval else args.eval_instruct_data + + dataset = build_dataset( + pretrain_data=pretrain_data, + instruct_data=instruct_data, + instruct_args=args.instruct, + instruct_tokenizer=instruct_tokenizer, + seq_len=seq_len, + seed=seed, + rank=rank, + world_size=world_size, + is_eval=is_eval, + shuffle_pretrain=args.shuffle, + ) + + batch_list = BatchList() + for sample in dataset: + assert all(s >= 0 for s in sample.sizes) + + batch_list.add(sample.x, sample.y, sample.sizes, sample.mask) + + if len(batch_list) == batch_size: + batch: Batch = batch_list.create_batch() + yield batch + + batch_list.empty() + diff --git a/finetune/data/dataset.py b/finetune/data/dataset.py new file mode 100644 index 0000000..516a97d --- /dev/null +++ b/finetune/data/dataset.py @@ -0,0 +1,475 @@ +import dataclasses +import itertools +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union + +import numpy as np +import torch.distributed as dist +from mistral_common.protocol.instruct.messages import ( + FinetuningAssistantMessage, + SystemMessage, +) +from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase + +from finetune.distributed import get_rank + +from .args import InstructArgs +from .tokenize import ( + Mask, + SampleType, + TokenSample, + TrainingInstructSample, + build_instruct_sample, + encode, +) + +logger = logging.getLogger("dataset") + + +_LOADED_DATASETS: Dict[Path, List[str]] = {} + + +def main_logger_info(message: str) -> None: + if dist.is_initialized() and get_rank() == 0: + logger.info(message) + + +def load_file(path: Path, world_size: int, rank: int) -> List[str]: + lines = [] + with path.open() as f: + for idx, line in enumerate(f): + if not idx % world_size == rank: + continue + lines.append(line) + return lines + + +def maybe_load_local_dataset( + path: Path, chunk: bool, rank: int, world_size: int, instruct_tokenizer: InstructTokenizerBase, sample_type: SampleType +) -> List[TokenSample]: + global _LOADED_DATASETS + + if path in _LOADED_DATASETS: + return _LOADED_DATASETS[path] + + main_logger_info(f"Loading {path} ...") + lines: List[str] = load_file(path, rank=rank, world_size=world_size) + + if chunk: + lines += maybe_chunk_lines(lines) + + tokens_list: List[TokenSample] = [] + for line in lines: + data = json.loads(line) + + token_sample: TokenSample = encode( + data, + instruct_tokenizer=instruct_tokenizer, + as_type=sample_type, + ) + tokens_list.append(token_sample) + + main_logger_info(f"{path} loaded and tokenized.") + _LOADED_DATASETS[path] = tokens_list + + return _LOADED_DATASETS[path] + + +@dataclass +class DataDir: + path: Path + sample_type: SampleType + + @property + def jsonl_files(self): + assert self.path.exists(), f"Make sure that {self.path} exists" + jsonl_files = list(self.path.rglob("*jsonl")) + assert ( + len(jsonl_files) > 0 + ), f"{self.path} does not seem to have any files ending with '.jsonl'" + return jsonl_files + + +@dataclass +class DataFile: + path: Path + sample_type: SampleType + + @property + def jsonl_files(self): + assert self.path.exists(), f"Make sure that {self.path} exists" + return [self.path] + + +def parse_data_sources( + pretrain_data: str, + instruct_data: str, +) -> Tuple[List[Union[DataDir, DataFile]], List[float]]: + seen: Set[str] = set() + sources: List[Union[DataDir, DataFile]] = [] + weights: List[float] = [] + for sample_sources, sample_type in [ + (pretrain_data, SampleType.PRETRAIN), + (instruct_data, SampleType.INSTRUCT), + ]: + for source in sample_sources.strip().split(","): + if not source: + continue + + source_items = source.strip().split(":") + if len(source_items) == 1: + path_ = source_items[0] + weight = 1.0 + elif len(source_items) == 2: + path_, weight_ = source_items + weight = float(weight_) + else: + raise ValueError( + f"{source} is not correctly formatted. Make sure to format each data source as : or just " + ) + + assert ( + path_ not in seen + ), f"{path_} seems to be duplicated. Make sure to only add it once." + assert ( + weight > 0 + ), f"Make sure to define strictly positive data sampling weights, not {weight}" + + data: Union[DataDir, DataFile] + if Path(path_).is_dir(): + data = DataDir(path=Path(path_), sample_type=sample_type) + elif Path(path_).is_file(): + data = DataFile(path=Path(path_), sample_type=sample_type) + else: + raise FileNotFoundError( + f"The path {path_} does not exist. Make sure {path_} is either a file or directory that contains training data." + ) + + sources.append(data) + weights.append(weight) + + seen.add(path_) + + sum_weights = sum(weights) + n_weights = [weight / sum_weights for weight in weights] + + assert min(n_weights) > 0 + assert ( + abs(1 - sum(n_weights)) < 1e-8 + ), f"Defined data sampling weights {weights} must sum to 1." + return sources, n_weights + + +@dataclasses.dataclass() +class SequenceMaskAndSizes: + """ + Concatenation of samples to reach a given size + """ + + x: List[int] + y: List[int] + mask: Mask + sizes: List[int] + + def __post_init__(self): + assert sum(self.sizes) == len(self.x) == len(self.y) == len(self.mask) + + +def sequence_iterator( + ds_it: Iterator[TokenSample], + seq_len: int, + is_finite: bool, +) -> Iterator[SequenceMaskAndSizes]: + """ + Creates sequences of length `seq_len` from the dataset iterator by concatenating samples. + """ + x_buffer: List[int] = [] + y_buffer: List[int] = [] + mask_buffer: Mask = [] + + sizes: List[int] = [] + n_missing = seq_len + for sample in ds_it: + assert 0 <= len(x_buffer) < seq_len, len(x_buffer) + assert n_missing == seq_len - len( + x_buffer + ), f"n_missing: {n_missing} | seq_len - len(x_buffer) {seq_len - len(x_buffer)}" + + tokens, mask = sample.tokens, sample.masks[1:] + x, y = tokens[:-1], tokens[1:] + cur_pos = 0 + + while cur_pos < len(x): + size = len(x[cur_pos : cur_pos + n_missing]) + + curr_mask = mask[cur_pos : cur_pos + n_missing] + if not any(curr_mask): + cur_pos += size + # we have a sequence with a mask filled with False + continue + + x_buffer.extend(x[cur_pos : cur_pos + n_missing]) + y_buffer.extend(y[cur_pos : cur_pos + n_missing]) + mask_buffer.extend(curr_mask) + n_missing -= size + sizes.append(size) + + cur_pos += size + + if n_missing == 0: + assert len(mask_buffer) == len(x_buffer) == seq_len == len(y_buffer) + assert sum(sizes) == seq_len + # we don't want to yield sequences with a mask filled with False + if any(mask_buffer): + yield SequenceMaskAndSizes( + x=x_buffer, + y=y_buffer, + mask=mask_buffer, + sizes=sizes, + ) + x_buffer, y_buffer = [], [] + mask_buffer = [] + sizes = [] + n_missing = seq_len + + if is_finite: + # if dataloader is in eval, pad to seq length + if any(mask_buffer): + mask_buffer.extend(n_missing * [False]) + x_buffer.extend(n_missing * [0]) + y_buffer.extend(n_missing * [0]) + sizes.append(n_missing) + + yield SequenceMaskAndSizes( + x=x_buffer, + y=y_buffer, + mask=mask_buffer, + sizes=sizes, + ) + + +def build_dataset( + pretrain_data: str, + instruct_data: str, + instruct_args: InstructArgs, + instruct_tokenizer: InstructTokenizerBase, + seq_len: int, + seed: Optional[int], + rank: int, + world_size: int, + is_eval: bool, + shuffle_pretrain: bool = False, +) -> Iterator[SequenceMaskAndSizes]: + sources, probabilities = parse_data_sources( + pretrain_data=pretrain_data, instruct_data=instruct_data + ) + + def do_shuffle(source: Union[DataDir, DataFile]) -> bool: + shuffle = { + SampleType.PRETRAIN: shuffle_pretrain, + SampleType.INSTRUCT: instruct_args.shuffle, + }[source.sample_type] + + return not is_eval and shuffle + + dataset_iterators = [ + get_dataset_iterator( + source, + instruct_args=instruct_args, + instruct_tokenizer=instruct_tokenizer, + rank=rank, + world_size=world_size, + is_finite=is_eval, + seed=seed, + shuffle_at_epoch=do_shuffle(source), + ) + for source in sources + ] + + sequence_iterators = [ + sequence_iterator( + ds_it=it, + seq_len=seq_len, + is_finite=is_eval, + ) + for it in dataset_iterators + ] + + if is_eval: + combined_iterator = itertools.chain.from_iterable(sequence_iterators) + else: + # make sure random_seed is different per rank and original seed + random_seed = np.array((seed, rank)) + rng = np.random.RandomState(seed=random_seed) + combined_iterator = interleave_iterators( + sequence_iterators, probabilities=probabilities, rng=rng + ) + + return combined_iterator + + +def get_rng(seed: int, rank: int) -> np.random.RandomState: + random_seed = np.array((seed, rank)) + rng = np.random.RandomState(seed=random_seed) + return rng + + +def get_dataset_iterator( + source: Union[DataDir, DataFile], + instruct_args: InstructArgs, + instruct_tokenizer: InstructTokenizerBase, + rank: int, + world_size: int, + is_finite: bool, + seed: Optional[int], + shuffle_at_epoch: bool, +) -> Iterator[TokenSample]: + jsonl_files = source.jsonl_files + rng: Optional[np.random.RandomState] = ( + get_rng(seed, rank) if seed is not None else None + ) + + chunk_dataset = ( + instruct_args.dynamic_chunk_fn_call + and source.sample_type == SampleType.INSTRUCT + ) + + if not is_finite: + # train mode + while True: + for jsonl_file in jsonl_files: + if shuffle_at_epoch: + assert rng is not None, "`seed` has to be passed when shuffling" + # will preload all data into RAM, shuffle and yield + yield from preload_and_yield( + jsonl_file, + chunk_dataset=chunk_dataset, + rank=rank, + world_size=world_size, + rng=rng, + instruct_tokenizer=instruct_tokenizer, + sample_type=source.sample_type, + ) + else: + # will read data on-the-fly and yield + main_logger_info(f"Lazily loading {jsonl_file} ...") + yield from lazy_load_and_yield( + jsonl_file, + rank=rank, + world_size=world_size, + instruct_tokenizer=instruct_tokenizer, + sample_type=source.sample_type, + ) + else: + # eval mode + for jsonl_file in jsonl_files: + # No need to shuffle for eval + yield from lazy_load_and_yield( + jsonl_file, + rank=rank, + world_size=world_size, + instruct_tokenizer=instruct_tokenizer, + sample_type=source.sample_type, + ) + + +def preload_and_yield( + jsonl_file: Path, + chunk_dataset: bool, + rank: int, + world_size: int, + rng: np.random.RandomState, + instruct_tokenizer: InstructTokenizerBase, + sample_type: SampleType, +) -> Iterator[TokenSample]: + # only instruct data has to be chunked + # load dataset if not already loaded. Make sure to only load 1/world_size dataset + tokens_list = maybe_load_local_dataset( + jsonl_file, chunk=chunk_dataset, rank=rank, world_size=world_size, instruct_tokenizer=instruct_tokenizer, sample_type=sample_type + ) + + if sample_type == SampleType.PRETRAIN: + assert chunk_dataset is False, "Pretrain data should not have chunking enabled." + + main_logger_info(f"Shuffling {jsonl_file} ...") + rng.shuffle(tokens_list) + + for token_sample in tokens_list: + yield token_sample + +def lazy_load_and_yield( + jsonl_file: Path, + rank: int, + world_size: int, + instruct_tokenizer: InstructTokenizerBase, + sample_type: SampleType, +): + with jsonl_file.open() as file_handle: + for idx, line in enumerate(file_handle): + if not idx % world_size == rank: + continue + + data = json.loads(line) + yield encode( + data, + instruct_tokenizer=instruct_tokenizer, + as_type=sample_type, + ) + + +def maybe_chunk_lines(lines: List[str]) -> List[str]: + extra_lines: List[str] = [] + for line in lines: + data = json.loads(line) + # mult-turn fn call data will be chunked and shorder conversations are added additionally + maybe_chunked_lines = maybe_chunk_data(data) + extra_lines.extend([json.dumps(line) for line in maybe_chunked_lines]) + + return extra_lines + + +def maybe_chunk_data(data: Dict[str, Any]) -> List[Dict[str, Any]]: + # think about always allowing both open-ai and non-open-ai data + sample = build_instruct_sample(data) + + def num_assistant_messages(sample: TrainingInstructSample) -> int: + return len( + [m for m in sample.messages if isinstance(m, FinetuningAssistantMessage)] + ) + + chunk_data = [] + while sample.only_last is True and num_assistant_messages(sample) > 1: + assert sample == build_instruct_sample(sample.dict()) + last_message = sample.messages.pop() + + # 1. First pop until and including last assistant message + system_message = None + while not isinstance(last_message, FinetuningAssistantMessage): + last_message = sample.messages.pop() + if isinstance(last_message, SystemMessage): + system_message = last_message + + # 2. Second pop until and excluding last assistant message + prev_last_message = sample.messages[-1] + while not isinstance(prev_last_message, FinetuningAssistantMessage): + last_message = sample.messages.pop() + if isinstance(last_message, SystemMessage): + system_message = last_message + + prev_last_message = sample.messages[-1] + + # if system_message is not None, append again + if system_message is not None: + sample.messages.append(system_message) + chunk_data.append(sample.dict()) + + return chunk_data + + +def interleave_iterators(iterators: List[Iterator], probabilities, rng): + while True: + it_id = rng.choice(range(len(iterators)), p=probabilities) + yield next(iterators[it_id]) diff --git a/finetune/data/exceptions.py b/finetune/data/exceptions.py new file mode 100644 index 0000000..190c07d --- /dev/null +++ b/finetune/data/exceptions.py @@ -0,0 +1,56 @@ +class MessageFormatError(Exception): + def __init__(self, message, data): + self._message = message + self._begin_data = data[:20] + super().__init__() + + def __str__(self): + return f"A message starting with {self._begin_data} is incorrectly formated." + self._message + + +class ToolCallFormatError(Exception): + def __init__(self, message, data): + self._message = message + self._begin_data = data[:20] + super().__init__() + + def __str__(self): + return f"A tool call assistant message starting with {self._begin_data} of the conversation is incorrectly formated. " + self._message + + +class FunctionFormatError(Exception): + def __init__(self, message, data): + self._message = message + self._begin_data = data[:20] + super().__init__() + + def __str__(self): + return ( + f"A function of the conversation starting with {self._begin_data} is incorrectly formated. " + + self._message + ) + + +class ConversationFormatError(Exception): + def __init__(self, message, data): + self._message = message + self._begin_data = data[:20] + super().__init__() + + def __str__(self): + return ( + f"A conversation starting with {self._begin_data} is incorrectly formated. " + self._message + ) + + +class UnrecognizedRoleError(Exception): + def __init__(self, role, allowed_roles): + self._role = role + self._allowed_roles = allowed_roles + super().__init__() + + def __str__(self): + return ( + f"The following role: {self._role} is not recognized in line: {self.line} of the dataset {self.dataset}. Make sure that each role is one of {self._allowed_roles}" + + self._message + ) diff --git a/finetune/data/tokenize.py b/finetune/data/tokenize.py new file mode 100644 index 0000000..ee724fa --- /dev/null +++ b/finetune/data/tokenize.py @@ -0,0 +1,355 @@ +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from mistral_common.protocol.instruct.messages import ( + FinetuningAssistantMessage, + Roles, + SystemMessage, + ToolMessage, + UserMessage, +) +from mistral_common.protocol.instruct.tool_calls import ( + Function, + FunctionCall, + Tool, + ToolCall, +) +from mistral_common.protocol.instruct.validator import ( + MistralRequestValidatorV3, + ValidationMode, +) +from mistral_common.tokens.instruct.request import InstructRequest +from mistral_common.tokens.tokenizers.base import Tokenizer +from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase + +from .exceptions import ( + ConversationFormatError, + FunctionFormatError, + MessageFormatError, + ToolCallFormatError, + UnrecognizedRoleError, +) + +logger = logging.getLogger("tokenize") + +Sequence = List[int] +Mask = List[bool] + + +class TrainingInstructSample(InstructRequest): + available_tools: Optional[List[Tool]] = None + only_last: bool = False + + +@dataclass() +class TokenSample: + tokens: Sequence + masks: Mask + + +class SampleType(str, Enum): + PRETRAIN = "pretrain" + INSTRUCT = "instruct" + + +def encode( + data: Dict[str, Any], + instruct_tokenizer: InstructTokenizerBase, + as_type: SampleType, +) -> TokenSample: + sample: Union[str, TrainingInstructSample] + if as_type == SampleType.PRETRAIN: + sample = get_pretrain_sample(data) + elif as_type == SampleType.INSTRUCT: + sample = build_instruct_sample(data) + + return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer) + + +def get_pretrain_sample(data: Dict[str, Any]) -> str: + content_keys = ["text", "content"] + assert not all( + k in data for k in content_keys + ), "Make sure to have either 'text' or 'content' in your data. Not both." + assert any( + data.get(k) is not None for k in content_keys + ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}" + + # get first non-None value + sample = None + for key in content_keys: + sample = data[key] if key in data else sample + + assert isinstance(sample, str), sample + + return sample + + +def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample: + messages: List[ + SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage + ] = [] + # optional data fields that might be set + available_tools: Optional[List[Tool]] = data.get("available_tools") + system_prompt = data.get("system_prompt") + + messages_keys = ["messages", "interactions"] + content_keys = ["content", "text"] # both are accepted + allowed_roles = [role.value for role in Roles] + + if not any(messages_key in data for messages_key in messages_keys): + err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'." + raise ConversationFormatError(err, str(data)) + + if all(messages_key in data for messages_key in messages_keys): + err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two." + raise ConversationFormatError(err, str(data)) + + # get first non-None value + data_messages: Optional[List[Dict[str, Any]]] = None + for key in messages_keys: + data_messages = data[key] if key in data else data_messages + + assert data_messages is not None, "data_messages can't be None" + + if "available_tools" in data and "tools" in data: + err = "The conversation contains both an `available_tools` and `tools` key. You can only have one." + raise ConversationFormatError(err, str(data)) + + if data.get("tools", None) is not None and len(data["tools"]) > 0: + available_tools = _parse_available_tools(data["tools"]) + elif ( + data.get("available_tools", None) is not None + and len(data["available_tools"]) > 0 + ): + available_tools = _parse_available_tools(data["available_tools"]) + + for data_message in data_messages: + is_tool_call = data_message.get("tool_calls") is not None + + if "role" not in data_message: + err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'." + raise MessageFormatError(err, str(data)) + + role = data_message["role"] + + if all(key in data_message for key in content_keys): + err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two." + raise MessageFormatError(err, str(data)) + + content: Optional[str] = None + for key in content_keys: + content = content if content is not None else data_message.get(key) + + # non-function call message must have content + if not is_tool_call and content is None: + err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys." + raise MessageFormatError(err, str(data)) + + if role not in allowed_roles: + raise UnrecognizedRoleError(role, allowed_roles) + + if data_message["role"] == "user": + assert content is not None + messages.append(UserMessage(content=content)) + elif data_message["role"] == "assistant": + tool_calls: Optional[List[ToolCall]] = None + + if is_tool_call: + tool_calls = _parse_tool_calls(data_message["tool_calls"]) + + weight = data_message.get("weight") + messages.append( + FinetuningAssistantMessage( + content=content, tool_calls=tool_calls, weight=weight + ) + ) + elif data_message["role"] == "system": + if system_prompt is not None: + err = "Multiple messages with role 'system' encountered. Only one is allowed." + raise MessageFormatError(err, str(data)) + + system_prompt = content + elif data_message["role"] == "tool": + assert content is not None + tool_message = _parse_tool_message(content, data_message) + messages.append(tool_message) + + # validate created messages + validator = MistralRequestValidatorV3(ValidationMode.finetuning) + validator.validate_messages(messages) + validator._validate_tools(available_tools or []) + + # whether to train only on last assistant message + only_last = data.get("only_last", False) or available_tools is not None + + return TrainingInstructSample( + messages=messages, + system_prompt=system_prompt, + available_tools=available_tools, + only_last=only_last, + ) + + +def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]: + available_tools = [] + for tool in tools: + if "function" not in tool: + raise FunctionFormatError( + "A tool dict does not have a 'function' key.", str(tool) + ) + + func_data = tool["function"] + + for key in ["name", "description", "parameters"]: + if key not in func_data: + raise FunctionFormatError( + f"A function dict does not have a {key} key.", str(func_data) + ) + + if not isinstance(func_data["parameters"], dict): + raise FunctionFormatError( + f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empyt dict ", str(func_data) + ) + + description = func_data["description"] + function = Function( + name=func_data["name"], + description=description, + parameters=func_data["parameters"], + ) + + available_tools.append(Tool(function=function)) + return available_tools + + +def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]: + for key in ["id", "function"]: + if not all(key in call for call in calls): + err = f"A tool call of an assistant message does not have a {key} key" + raise ToolCallFormatError(err, str(calls)) + + for key in ["name", "arguments"]: + if not all(key in call["function"] for call in calls): + err = ( + f"A tool call function of an assistant message does not have a {key} key" + ) + raise ToolCallFormatError(err, str(calls)) + + if not all(isinstance(call["function"]["arguments"], str) for call in calls): + err = "A tool call function of an assistant message does not have a 'arguments' key of type str" + raise ToolCallFormatError(err, str(calls)) + + tool_calls = [ + ToolCall( + id=call["id"], + function=FunctionCall( + name=call["function"]["name"], + arguments=call["function"]["arguments"], + ), + ) + for call in calls + ] + return tool_calls + + +def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage: + if "tool_call_id" not in data_message: + err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'." + raise MessageFormatError(err, str(data_message)) + + call_id = data_message["tool_call_id"] + # name is deprecated in v3, but we'll add it nevertheless for now + name = data_message.get("name") + + return ToolMessage(content=content, tool_call_id=call_id, name=name) + + +def tokenize( + sample: Union[str, TrainingInstructSample], + instruct_tokenizer: InstructTokenizerBase, +) -> TokenSample: + if isinstance(sample, str): + tokenizer: Tokenizer = instruct_tokenizer.tokenizer + return tokenize_pretrain(sample, tokenizer) + elif isinstance(sample, TrainingInstructSample): + return tokenize_instruct(sample, instruct_tokenizer) + + raise ValueError( + f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}." + ) + + +def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample: + tokens = tokenizer.encode(sample, bos=True, eos=True) + masks = [True] * len(tokens) + return TokenSample(tokens, masks) + + +def tokenize_instruct( + sample: TrainingInstructSample, + instruct_tokenizer: InstructTokenizerBase, +) -> TokenSample: + tokens: List[int] = instruct_tokenizer.start() + masks: List[bool] = [False] + + mask_all_but_last = sample.only_last + + # find first and last user message + user_messages = [ + i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage) + ] + first_user_idx = user_messages[0] if user_messages else -1 + last_user_idx = user_messages[-1] if user_messages else -1 + + for msg_idx, message in enumerate(sample.messages): + if isinstance(message, UserMessage): + curr_tokens = instruct_tokenizer.encode_user_message( + message, + available_tools=sample.available_tools, + is_last=msg_idx == last_user_idx, + is_first=msg_idx == first_user_idx, + system_prompt=sample.system_prompt, + ) + curr_masks = [False] * len(curr_tokens) # only predict bot answers + elif isinstance(message, ToolMessage): + curr_tokens = instruct_tokenizer.encode_tool_message( + message, is_before_last_user_message=msg_idx < last_user_idx + ) + curr_masks = [False] * len(curr_tokens) # only predict bot answers + elif isinstance(message, FinetuningAssistantMessage): + is_last_message = msg_idx == (len(sample.messages) - 1) + + # we don't want to predict a random call id + message = maybe_remove_call_id(message, is_last_message=is_last_message) + + curr_tokens = instruct_tokenizer.encode_assistant_message( + message, is_before_last_user_message=False + ) + + is_weighted = message.weight is None or message.weight == 1 + is_relevant = (not mask_all_but_last) or is_last_message + if is_weighted and is_relevant: + curr_masks = [True] * len(curr_tokens) # only predict bot answers + else: + # in function calling we only backprop through last message + curr_masks = [False] * len(curr_tokens) + + tokens.extend(curr_tokens) + masks.extend(curr_masks) + + return TokenSample(tokens, masks) + + +def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool): + if message.tool_calls is None or not is_last_message: + return message + + # remove call id + message.tool_calls = [ + ToolCall(function=call.function) for call in message.tool_calls + ] + + return message diff --git a/finetune/distributed.py b/finetune/distributed.py new file mode 100644 index 0000000..5e9d2de --- /dev/null +++ b/finetune/distributed.py @@ -0,0 +1,59 @@ +import logging +import os +from functools import lru_cache +from typing import List, Union + +import torch +import torch.distributed as dist + +logger = logging.getLogger("distributed") + +BACKEND = "nccl" + + +@lru_cache() +def get_rank() -> int: + return dist.get_rank() + + +@lru_cache() +def get_world_size() -> int: + return dist.get_world_size() + + +def visible_devices() -> List[int]: + return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")] + + +def set_device(): + logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}") + logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") + logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}") + + assert torch.cuda.is_available() + + assert len(visible_devices()) == torch.cuda.device_count() + + if torch.cuda.device_count() == 1: + # gpus-per-task set to 1 + torch.cuda.set_device(0) + return + + local_rank = int(os.environ["LOCAL_RANK"]) + logger.info(f"Set cuda device to {local_rank}") + + assert 0 <= local_rank < torch.cuda.device_count(), ( + local_rank, + torch.cuda.device_count(), + ) + torch.cuda.set_device(local_rank) + + +def avg_aggregate(metric: Union[float, int]) -> Union[float, int]: + buffer = torch.tensor([metric], dtype=torch.float32, device="cuda") + dist.all_reduce(buffer, op=dist.ReduceOp.SUM) + return buffer[0].item() / get_world_size() + + +def is_torchrun() -> bool: + return "TORCHELASTIC_RESTART_COUNT" in os.environ diff --git a/finetune/eval.py b/finetune/eval.py new file mode 100644 index 0000000..61ced41 --- /dev/null +++ b/finetune/eval.py @@ -0,0 +1,77 @@ +import logging +from typing import List + +import numpy as np +import torch.cuda +import torch.distributed as dist +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + +from .data.data_loader import Batch +from .distributed import get_rank, get_world_size +from .loss import compute_loss_with_mask +from .utils import TrainState + +logger = logging.getLogger("eval") + + +def main_logger_info(message: str) -> None: + if get_rank() == 0: + logger.info(message) + + +def evaluate( + model: FullyShardedDataParallel, + batches: List[Batch], + state: TrainState, +): + # Create fake samples to make FSDP happy for unbalanced data + num_samples = torch.tensor([len(batches)], device="cuda", dtype=torch.long) + all_num_samples = [torch.zeros_like(num_samples) for _ in range(get_world_size())] + + torch.distributed.all_gather(all_num_samples, num_samples) + + total_num_samples = int(torch.tensor(all_num_samples).sum().item()) + max_num_samples = int(torch.tensor(all_num_samples).max().item()) + + for _ in range(max_num_samples - int(num_samples.item())): + pad_x = np.zeros_like(batches[-1].x) + pad_y = np.zeros_like(batches[-1].y) + pad_sizes = batches[-1].sizes.copy() + + pad_batch = Batch(pad_x, pad_y, pad_sizes, is_pad_only=True) + batches.append(pad_batch) + + # eval mode! + model.eval() + + eval_loss = torch.tensor(0.0).cuda() + main_logger_info("Start eval...") + for batch in batches: + x = torch.from_numpy(batch.x).cuda() + y = torch.from_numpy(batch.y).cuda() + y_mask = ( + torch.from_numpy(batch.y_mask).cuda() if batch.y_mask is not None else None + ) + + with torch.no_grad(): + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + if y_mask.sum() > 0: + eval_loss += compute_loss_with_mask(output, y, y_mask) + + assert batch.is_pad_only or y.abs().sum() != 0, "Pad sample is used to compute loss." + + # sum loss + main_logger_info("Eval finished!") + + dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + eval_loss /= total_num_samples + + state.this_eval_loss = eval_loss.item() + state.this_eval_perplexity = (2**eval_loss).item() + + # train mode! + model.train() diff --git a/finetune/loss.py b/finetune/loss.py new file mode 100644 index 0000000..d5c69b8 --- /dev/null +++ b/finetune/loss.py @@ -0,0 +1,16 @@ +from typing import Optional + +import torch +from torch.nn import functional as F + + +def compute_loss_with_mask( + logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor] +): + if target_mask is None: + return F.cross_entropy(logits, target, reduction="mean") + + mb_loss = F.cross_entropy(logits, target, reduction="none") + mb_loss = torch.sum(mb_loss * target_mask) / torch.sum(target_mask) + + return mb_loss diff --git a/finetune/mixed_precision.py b/finetune/mixed_precision.py new file mode 100644 index 0000000..d017325 --- /dev/null +++ b/finetune/mixed_precision.py @@ -0,0 +1,47 @@ +from typing import Iterable + +import torch + + +def prepare_mixed_precision( + params: Iterable[torch.nn.Parameter], + param_dtype: torch.dtype, + optim_dtype: torch.dtype, +): + """Appends a freshly allocated fp32 tensor copy of all params to parameters that can be updated.""" + with torch.no_grad(): + for p in params: + if p.requires_grad: + # Mixed precision: let's save a fp32 param tensor to each params that require a grad + p._mp_param = torch.empty_like(p, dtype=optim_dtype) # type: ignore + p._mp_param.copy_(p.to(optim_dtype)) # type: ignore + + p.data = p.data.to(param_dtype) + + +def upcast_mixed_precision( + params: Iterable[torch.nn.Parameter], optim_dtype: torch.dtype +): + """Make sure to run this function BEFORE optimizer.step() so that all weights and optimizer states are updated in fp32 in .step()""" + with torch.no_grad(): + for p in params: + if p.requires_grad and p.grad is not None: + # store original tensor in p._temp + p._temp = p.data # type: ignore + # upcast data for the optimizer step + p.data = p._mp_param # type: ignore + p.grad = p.grad.to(optim_dtype) + + +def downcast_mixed_precision( + params: Iterable[torch.nn.Parameter], param_dtype: torch.dtype +): + """Make sure to run this function AFTER optimizer.step() as optimizer.step() will update data underlying p.data and p._mp_param pointers""" + with torch.no_grad(): + for p in params: + if p.requires_grad and p.grad is not None: + # copy fp32 weights into bfloat16 tensor + p._temp.copy_(p.data) # type: ignore + # set _temp again to the data tensor + p.data = p._temp # type: ignore + p.grad = p.grad.to(param_dtype) diff --git a/finetune/monitoring/__init__.py b/finetune/monitoring/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/monitoring/metrics_logger.py b/finetune/monitoring/metrics_logger.py new file mode 100644 index 0000000..97db6b6 --- /dev/null +++ b/finetune/monitoring/metrics_logger.py @@ -0,0 +1,226 @@ +import json +import logging +import os +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from torch.utils.tensorboard import SummaryWriter + +from finetune.args import MLFlowArgs, TrainArgs, WandbArgs +from finetune.utils import TrainState + +logger = logging.getLogger("metrics_logger") + +GB = 1024**3 + + +def get_train_logs( + state: TrainState, + loss: float, + lr: float, + peak_allocated_mem: float, + allocated_mem: float, + train_args: TrainArgs, +) -> Dict[str, Union[float, int]]: + metrics = { + "lr": lr, + "step": state.step, + "loss": loss, + "percent_done": 100 * state.step / train_args.max_steps, + "peak_allocated_mem": peak_allocated_mem / GB, + "allocated_mem": allocated_mem / GB, + "wps": state.wps, + "avg_wps": state.avg_wps, + "eta_in_seconds": state.eta, + } + + return metrics + + +def get_eval_logs( + step: int, + train_loss: float, + perplexity: Optional[float], + eval_loss: Optional[float], +) -> Dict[str, Union[float, int]]: + eval_dict = {"step": step, "train_loss": train_loss} + + if perplexity is not None: + eval_dict["perplexity"] = perplexity + + if eval_loss is not None: + eval_dict["eval_loss"] = eval_loss + return eval_dict + + +def train_log_msg( + state: TrainState, logs: Dict[str, Union[float, int]], loss: float +) -> str: + metrics: Dict[str, Union[float, int, datetime]] = dict(logs) # shallow copy + metrics.pop("eta_in_seconds") + + metrics["eta"] = datetime.now() + timedelta(seconds=state.eta) + metrics["step"] = state.step + metrics["loss"] = loss + + parts = [] + for key, fmt, new_name in [ + ("step", "06", None), + ("percent_done", "03.1f", "done (%)"), + ("loss", ".3f", None), + ("lr", ".1e", None), + ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"), + ("allocated_mem", ".1f", "alloc_mem (GB)"), + ("wps", ".1f", "words_per_second"), + ("avg_wps", ".1f", "avg_words_per_second"), + ("eta", "%Y-%m-%d %H:%M:%S", "ETA"), + ]: + name = key if new_name is None else new_name + try: + parts.append(f"{name}: {metrics[key]:>{fmt}}") + except KeyError: + logger.error(f"{key} not found in {sorted(metrics.keys())}") + raise + + return " - ".join(parts) + + +def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str: + parts = [] + for key, fmt, new_name in [ + ("step", "06", None), + ("perplexity", ".3f", "eval_perplexity"), + ("eval_loss", ".3f", None), + ("train_loss", ".3f", None), + ]: + name = key if new_name is None else new_name + if key in logs: + parts.append(f"{name}: {logs[key]:>{fmt}}") + + return " - ".join(parts) + + +class MetricsLogger: + def __init__( + self, + dst_dir: Path, + tag: str, + is_master: bool, + wandb_args: WandbArgs, + mlflow_args: MLFlowArgs, + config: Optional[Dict[str, Any]] = None, + ): + self.dst_dir = dst_dir + self.tag = tag + self.is_master = is_master + self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl" + self.tb_dir = dst_dir / "tb" + self.summary_writer: Optional[SummaryWriter] = None + + if not self.is_master: + return + + filename_suffix = f".{tag}" + self.tb_dir.mkdir(exist_ok=True) + self.summary_writer = SummaryWriter( + log_dir=str(self.tb_dir), + max_queue=1000, + filename_suffix=filename_suffix, + ) + self.is_wandb = wandb_args.project is not None + self.is_mlflow = mlflow_args.tracking_uri is not None + + if self.is_wandb: + import wandb + + if wandb_args.key is not None: + wandb.login(key=wandb.key) + if wandb_args.offline: + os.environ["WANDB_MODE"] = "offline" + if wandb.run is None: + logger.info("initializing wandb") + wandb.init( + config=config, + dir=dst_dir, + project=wandb_args.project, + job_type="training", + name=wandb_args.run_name or dst_dir.name, + resume=False, + ) + + self.wandb_log = wandb.log + + if self.is_mlflow: + import mlflow + + mlflow.set_tracking_uri(mlflow_args.tracking_uri) + mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name) + + if tag == "train": + mlflow.start_run() + + self.mlflow_log = mlflow.log_metric + + def log(self, metrics: Dict[str, Union[float, int]], step: int): + if not self.is_master: + return + + metrics_to_ignore = {"step"} + assert self.summary_writer is not None + for key, value in metrics.items(): + if key in metrics_to_ignore: + continue + assert isinstance(value, (int, float)), (key, value) + self.summary_writer.add_scalar( + tag=f"{self.tag}.{key}", scalar_value=value, global_step=step + ) + + if self.is_mlflow: + self.mlflow_log(f"{self.tag}.{key}", value, step=step) + + if self.is_wandb: + # grouping in wandb is done with / + self.wandb_log( + { + f"{self.tag}/{key}": value + for key, value in metrics.items() + if key not in metrics_to_ignore + }, + step=step, + ) + + metrics_: Dict[str, Any] = dict(metrics) # shallow copy + if "step" in metrics_: + assert step == metrics_["step"] + else: + metrics_["step"] = step + metrics_["at"] = datetime.utcnow().isoformat() + with self.jsonl_path.open("a") as fp: + fp.write(f"{json.dumps(metrics_)}\n") + + def close(self): + if not self.is_master: + return + + if self.summary_writer is not None: + self.summary_writer.close() + self.summary_writer = None + + if self.is_wandb: + import wandb + + # to be sure we are not hanging while finishing + wandb.finish() + + if self.is_mlflow: + import mlflow + + mlflow.end_run() + + def __del__(self): + if self.summary_writer is not None: + raise RuntimeError( + "MetricsLogger not closed properly! You should " + "make sure the close() method is called!" + ) diff --git a/finetune/monitoring/utils.py b/finetune/monitoring/utils.py new file mode 100644 index 0000000..48161a8 --- /dev/null +++ b/finetune/monitoring/utils.py @@ -0,0 +1,34 @@ +import datetime +import logging +import sys +import time + + +class DeltaTimeFormatter(logging.Formatter): + def format(self, record): + delta = datetime.timedelta( + seconds=int(record.relativeCreated / 1000) + ) # no milliseconds + record.delta = delta + return super().format(record) + + +def set_logger(level: int = logging.INFO): + root = logging.getLogger() + root.handlers.clear() + root.setLevel(level) + tz, *_ = time.tzname + + LOGFORMAT = "%(asctime)s - %(delta)s - %(name)s - %(levelname)s - %(message)s" + TIMEFORMAT = f"%Y-%m-%d %H:%M:%S ({tz})" + formatter = DeltaTimeFormatter(LOGFORMAT, TIMEFORMAT) + + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(level) + handler.setFormatter(formatter) + root.addHandler(handler) + + handler = logging.StreamHandler(sys.stderr) + handler.setLevel(logging.WARNING) + handler.setFormatter(formatter) + root.addHandler(handler) diff --git a/finetune/utils.py b/finetune/utils.py new file mode 100644 index 0000000..95119a7 --- /dev/null +++ b/finetune/utils.py @@ -0,0 +1,83 @@ +import contextlib +import dataclasses +import datetime +import logging +import time +from typing import Optional, Protocol + +import torch + +logger = logging.getLogger("utils") + + +@dataclasses.dataclass +class TrainState: + max_steps: int + step: int = 0 + elapsed_time: float = 0.0 + n_seen_tokens: int = 0 + this_step_time: float = 0.0 + begin_step_time: float = 0.0 + this_eval_perplexity: Optional[float] = None + this_eval_loss: Optional[float] = None + + def start_step(self): + self.step += 1 + self.begin_step_time = time.time() + + def end_step(self, n_batch_tokens: int): + self.this_step_time = time.time() - self.begin_step_time + self.this_step_tokens = n_batch_tokens + + self.elapsed_time += self.this_step_time + self.n_seen_tokens += self.this_step_tokens + + self.begin_step_time = time.time() + + @property + def wps(self): + return self.this_step_tokens / self.this_step_time + + @property + def avg_wps(self): + return self.n_seen_tokens / self.elapsed_time + + @property + def eta(self): + steps_left = self.max_steps - self.step + avg_time_per_step = self.elapsed_time / self.step + + return steps_left * avg_time_per_step + + +def set_random_seed(seed: int) -> None: + """Set random seed for reproducibility.""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +class Closable(Protocol): + def close(self): + pass + + +@contextlib.contextmanager +def logged_closing(thing: Closable, name: str): + """ + Logging the closing to be sure something is not hanging at exit time + """ + try: + setattr(thing, "wrapped_by_closing", True) + yield + finally: + logger.info(f"Closing: {name}") + try: + thing.close() + except Exception: + logger.error(f"Error while closing {name}!") + raise + logger.info(f"Closed: {name}") + + +def now_as_str() -> str: + return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") diff --git a/finetune/wrapped_model.py b/finetune/wrapped_model.py new file mode 100644 index 0000000..6034b4c --- /dev/null +++ b/finetune/wrapped_model.py @@ -0,0 +1,227 @@ +import functools +import json +import logging +import math +from pathlib import Path +from typing import Callable, Union + +import safetensors +import torch +import torch.distributed.fsdp.wrap as torch_wrap +from torch.distributed.fsdp import BackwardPrefetch +from torch.distributed.fsdp.api import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + +from model.args import ModelArgs, MoeArgs +from model.transformer import Transformer, TransformerBlock + +from .args import LoraArgs +from .checkpointing import Checkpointer +from .distributed import ( + get_rank, + get_world_size, +) + +logger = logging.getLogger(__name__) + + +def main_logger_info(message: str) -> None: + if get_rank() == 0: + logger.info(message) + + +def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]: + """ + This function instantiates the FSDP wrap policy. + - Each Transformers block becomes it's own FSDP group so that only a single Transformer block is sharded at a time + - If LoRA is enabled, we additionally create seperate FSDP sub-groups for every trainable and non-trainable parameter group + since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html + """ + + # Each transformer block becomes a FSDP group, each being sharded seperately + transformer_block_wrap_policy = functools.partial( + torch_wrap.transformer_auto_wrap_policy, + transformer_layer_cls=(TransformerBlock,), + ) + + if not is_lora: + return transformer_block_wrap_policy + + def fsdp_lora_policy_fn(module): + return all(p.requires_grad for p in module.parameters()) + + # For LoRA training, trainable and non-trainable parameters need to be put into + # different FSDP groups + fsdp_lora_policy = functools.partial( + torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn + ) + + policies = [fsdp_lora_policy, transformer_block_wrap_policy] + + return functools.partial(torch_wrap._or_policy, policies=policies) + + +def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]): + world_size = get_world_size() + + num_params = world_size * sum(p.numel() for p in model.parameters()) + num_train_params = world_size * sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + main_logger_info( + f"{num_train_params:,.0f} out of {num_params:,.0f} parameter are finetuned ({num_train_params / num_params * 100:.2f}%)." + ) + + +def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype): + """ + Initialize LoRA layers with Kaiming uniform and zeros. + See original paper for more info: https://arxiv.org/abs/2106.09685 and + original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122 + """ + for m_name, module in model.named_modules(): + if all(p.is_meta for p in module.parameters()): + for p_name, param in module.named_parameters(): + module._parameters[p_name] = torch.nn.Parameter( + torch.empty_like(param, device="cpu", dtype=param_dtype) + ) + param = module._parameters[p_name] + + if m_name.split(".")[-1] == "lora_A": + torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5)) + elif m_name.split(".")[-1] == "lora_B": + torch.nn.init.zeros_(param) + else: + raise ValueError( + "Only Lora layers should be randomely initialized." + ) + + +def load_model( + folder: Path, + lora: LoraArgs, + checkpoint: bool, + param_dtype: torch.dtype, +) -> FullyShardedDataParallel: + with open(folder / "params.json", "r") as f: + args = json.loads(f.read()) + + model_args = ModelArgs( + lora=lora, + dim=args["dim"], + n_layers=args["n_layers"], + head_dim=args["head_dim"], + hidden_dim=args["hidden_dim"], + n_heads=args["n_heads"], + n_kv_heads=args["n_kv_heads"], + norm_eps=args["norm_eps"], + vocab_size=args["vocab_size"], + ) + + if model_args.vocab_size == 32000: + raise ValueError( + f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`." + ) + + assert ( + model_args.vocab_size >= 32768 + ), "Make sure to use a model with a vocab size of at least 32768" + + if args.get("rope_theta") is not None: + model_args.rope_theta = args["rope_theta"] + + if args.get("moe") is not None: + model_args.moe = MoeArgs(**args["moe"]) + + with torch.device("meta"): + model = Transformer(args=model_args, checkpoint=checkpoint) + + if get_rank() == 0: + state_dict = load_state_dict(folder, dtype=param_dtype) + + model.load_state_dict(state_dict, assign=True) # type: ignore + logger.info("Loaded model on cpu!") + + if lora.enable: + logger.info("Initializing lora layers ...") + # initialize LoRA layers + initialize_lora_parameters(model, param_dtype) + + assert not any( + p.is_meta for p in model.parameters() + ), "All parameters should be intialized by now" + assert all( + p.dtype == param_dtype for p in model.parameters() + ), f"All parameters should be on {param_dtype}" + + logger.info("Finished initialization!") + param_init_fn = None + else: + + def param_init_fn(m): + m.to_empty(device=torch.cuda.current_device(), recurse=False) + m.to(param_dtype) + + assert all( + p.is_meta for p in model.parameters() + ), "All parameters should be on meta" + + torch.distributed.barrier() + + # only finetune LoRA parameters and freeze before wrapping + if lora.enable: + for name, param in model.named_parameters(): + if "lora" in name: + param.requires_grad = True + else: + param.requires_grad = False + + auto_wrap_policy = get_fsdp_policy(model_args.lora.enable) + + main_logger_info(f"Sharding model over {get_world_size()} GPUs ...") + + wrapped_model = FullyShardedDataParallel( + model, + sharding_strategy=ShardingStrategy.FULL_SHARD, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + device_id=torch.cuda.current_device(), + sync_module_states=True, + param_init_fn=param_init_fn, + ) + main_logger_info("Model sharded!") + + log_train_params(wrapped_model) + + return wrapped_model + + +@torch.no_grad() +def load_state_dict(path: Path, dtype: torch.dtype): + assert path.is_dir(), path + + this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True) + this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False) + + assert ( + this_safetensors_path.exists() or this_torch_path.exists() + ), f"Either {this_safetensors_path} or {this_torch_path} must exist." + assert not ( + this_safetensors_path.exists() and this_torch_path.exists() + ), f"Only one of {this_safetensors_path} or {this_torch_path} should exist." + + if this_safetensors_path.exists(): + logger.info(f"Reloading model from {this_safetensors_path} ...") + model_state_dict = safetensors.torch.load_file(this_safetensors_path) + else: + logger.info(f"Reloading model from {this_torch_path} ...") + model_state_dict = torch.load(this_torch_path) + + logger.info(f"Converting model to dtype {dtype} ...") + + for k, v in model_state_dict.items(): + model_state_dict[k] = v.to(dtype) + + return model_state_dict diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/args.py b/model/args.py new file mode 100644 index 0000000..4b7fbbb --- /dev/null +++ b/model/args.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field +from typing import Optional + +from simple_parsing.helpers import Serializable + + +@dataclass +class LoraArgs(Serializable): + enable: bool = True + rank: int = 16 + dropout: float = 0.0 + scaling: float = 2.0 + + def __post_init__(self): + if self.enable: + assert self.rank > 0 + assert self.scaling > 0.0 + + +@dataclass +class MoeArgs(Serializable): + num_experts: int = 8 + num_experts_per_tok: int = 2 + + +@dataclass +class ModelArgs(Serializable): + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + rope_theta: float = 10000.0 + + lora: LoraArgs = field(default_factory=LoraArgs) + moe: Optional[MoeArgs] = None diff --git a/model/lora.py b/model/lora.py new file mode 100644 index 0000000..d376df6 --- /dev/null +++ b/model/lora.py @@ -0,0 +1,96 @@ +from typing import NamedTuple + +import torch +import torch.nn as nn + + +class LoRALinear(nn.Module): + """ + Implementation of: + - LoRA: https://arxiv.org/abs/2106.09685 + + Notes: + - Freezing is handled at network level, not layer level. + - Scaling factor controls relative importance of LoRA skip + connection versus original frozen weight. General guidance is + to keep it to 2.0 and sweep over learning rate when changing + the rank. + """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int, + scaling: float, + dropout: float, + bias: bool = False, + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + assert not bias + self.bias = bias + self.rank = rank + self.scaling = scaling + + self.dropout = nn.Dropout(p=dropout) + + self.lora_A = nn.Linear( + self.in_features, + self.rank, + bias=self.bias, + ) + self.lora_B = nn.Linear( + self.rank, + self.out_features, + bias=self.bias, + ) + + self.frozen_W = nn.Linear(self.in_features, self.out_features, bias=self.bias) + + # make sure no LoRA weights are marked as "missing" in load_state_dict + def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple): + # empty missing keys in place + incompatible_keys.missing_keys[:] = [] # type: ignore + + self.register_load_state_dict_post_hook(ignore_missing_keys) + + def merge_weight(self): + with torch.no_grad(): + down_weight = self.lora_A.weight + up_weight = self.lora_B.weight + + weight = up_weight.mm(down_weight) * self.scaling + + weight += self.frozen_W.weight + return weight + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + key_name = prefix + "weight" + + # full checkpoint + if key_name in state_dict: + w_ref = state_dict[key_name] + + # load frozen weights + self.frozen_W.load_state_dict({"weight": w_ref}, assign=True) + + def forward(self, x: torch.Tensor): + lora = self.lora_B(self.lora_A(self.dropout(x))) + return self.frozen_W(x) + lora * self.scaling + + def __repr__(self) -> str: + return "{}Linear(in_features={}, out_features={}, r={}, dropout={})".format( + "LoRA", self.in_features, self.out_features, self.rank, self.dropout.p + ) diff --git a/model/moe.py b/model/moe.py new file mode 100644 index 0000000..457933d --- /dev/null +++ b/model/moe.py @@ -0,0 +1,31 @@ +from typing import List + +import torch +import torch.nn as nn + +from .args import MoeArgs + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.args = moe_args + + def forward(self, inputs: torch.Tensor): + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk( + gate_logits, self.args.num_experts_per_tok + ) + weights = torch.nn.functional.softmax(weights, dim=1, dtype=torch.float).to( + inputs.dtype + ) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( + inputs[batch_idx] + ) + return results diff --git a/model/rope.py b/model/rope.py new file mode 100644 index 0000000..2d693e2 --- /dev/null +++ b/model/rope.py @@ -0,0 +1,27 @@ +from typing import Optional, Tuple + +import torch + + +def precompute_freqs_cis( + dim: int, end: int, theta: float, device: Optional[torch.device] = None +) -> torch.Tensor: + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) + ) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[:, None, :] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2) + return xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/model/transformer.py b/model/transformer.py new file mode 100644 index 0000000..82d2f1c --- /dev/null +++ b/model/transformer.py @@ -0,0 +1,230 @@ +import operator +from functools import partial, reduce +from typing import Iterable, List, Optional + +import torch +import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt +import torch.nn as nn +from xformers.ops.fmha import memory_efficient_attention +from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask + +from .args import ModelArgs +from .lora import LoRALinear +from .moe import MoeLayer +from .rope import apply_rotary_emb, precompute_freqs_cis + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int): + keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) + values = torch.repeat_interleave(values, repeats=repeats, dim=dim) + return keys, values + + +def maybe_lora_layer( + args: ModelArgs, rank: Optional[int] = None +) -> partial[LoRALinear] | type[nn.Linear]: + MaybeLora: partial[LoRALinear] | type[nn.Linear] + if not args.lora.enable: + return nn.Linear + + rank = rank or args.lora.rank + scaling = args.lora.scaling + dropout = args.lora.dropout + + MaybeLora = partial( + LoRALinear, + rank=rank, + scaling=scaling, + dropout=dropout, + ) + + return MaybeLora + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + self.head_dim: int = args.head_dim + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + MaybeLora = maybe_lora_layer(args) + + self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) + + self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: AttentionBias, + ) -> torch.Tensor: + seqlen_sum, _ = x.shape + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim) + xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) + xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + key, val = xk, xv + + # Repeat keys and values to match number of query heads + key, val = repeat_kv(key, val, self.repeats, dim=1) + + # xformers requires (B=1, S, H, D) + xq, key, val = xq[None, ...], key[None, ...], val[None, ...] + output = memory_efficient_attention(xq, key, val, mask) + + return self.wo(output.view(seqlen_sum, -1)) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + MaybeLora = maybe_lora_layer(args) + self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False) + self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False) + self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + + self.feed_forward: MoeLayer | FeedForward + if args.moe is not None: + self.feed_forward = MoeLayer( + experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], + gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), + moe_args=args.moe, + ) + else: + self.feed_forward = FeedForward(args=args) + + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + att_mask: AttentionBias, + ) -> torch.Tensor: + r = self.attention(self.attention_norm(x), freqs_cis, att_mask) + h = x + r + + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + + return out + + +class Transformer(nn.Module): + def __init__(self, args: ModelArgs, checkpoint: bool = False): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + self.layers = torch.nn.ModuleList() + for _ in range(args.n_layers): + block: torch.nn.Module = TransformerBlock(args=args) + if checkpoint: + # activate gradient checkpointing as, see: https://pytorch.org/docs/stable/checkpoint.html + non_reentrant_wrapper = partial( + torch_ckpt.checkpoint_wrapper, + checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT, + ) + block = non_reentrant_wrapper(block) + + self.layers.append(block) + + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + self.output = torch.nn.Linear( + args.dim, + args.vocab_size, + bias=False, + ) + + # set lazily + self._freqs_cis = None + + @property + def dtype(self) -> torch.dtype: + return self.tok_embeddings.weight.dtype + + @property + def device(self) -> torch.device: + return self.tok_embeddings.weight.device + + @property + def freqs_cis(self): + # lazy init + device = next(iter(self.parameters())).device + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis( + self.args.head_dim, 128_000, theta=self.args.rope_theta, device=device + ) + + return self._freqs_cis + + def forward( + self, + input_ids: torch.Tensor, + seqlens: List[int], + ) -> torch.Tensor: + assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0]) + + h = self.tok_embeddings(input_ids) + positions = positions_from_sizes(seqlens, self.freqs_cis.device) + att_mask = BlockDiagonalCausalMask.from_seqlens(seqlens) + + freqs_cis = self.freqs_cis[positions].to(device=h.device) + + for layer in self.layers: + h = layer(h, freqs_cis, att_mask) + + return self.output(self.norm(h)).float() + + +def positions_from_sizes(sizes: Iterable[int], device): + return torch.tensor( + reduce(operator.iadd, [list(range(s)) for s in sizes], []), + dtype=torch.long, + device=device, + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c6f223d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,51 @@ +[tool.black] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' +line-length = 88 + +[tool.isort] +known_third_party = [] +line_length = 88 + +[tool.mypy] +python_version = "3.10" +ignore_missing_imports = true +follow_imports = "normal" + +[tool.ruff] +lint.select = ["E", "F", "W", "C90", "I", "C4", "PT", "RSE", "TID", "TCH", "FLY", "NPY", "RUF", "T10"] +line-length = 88 +lint.ignore = ["E501"] +target-version = "py38" +exclude = [".git","__pycache__",".vscode",".mypy_cache"] + +[tool.ruff.lint.mccabe] +max-complexity = 20 + +[tool.pytest.ini_options] +python_classes = [ + "!TestConfig", +] +filterwarnings = [ + "error", + "ignore::DeprecationWarning", + "ignore::ImportWarning", + "ignore::DeprecationWarning:pkg_resources", + "ignore::DeprecationWarning:wandb", + "ignore::DeprecationWarning:torch.utils.tensorboard", + "ignore::UserWarning", +] +testpaths = [] +addopts = "--ignore=third_party --ignore=tokens" diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000..d1bcbd3 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,5 @@ +mypy==1.4.0 +mypy-extensions==1.0.0 +pre-commit +ruff==0.0.282 +pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a5cdf88 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +fire +simple-parsing +pyyaml +mistral-common>=1.1.0 +safetensors +tensorboard +tqdm + +torch==2.2 +triton==2.2 +xformers==0.0.24 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fixtures b/tests/fixtures new file mode 120000 index 0000000..8b13789 --- /dev/null +++ b/tests/fixtures @@ -0,0 +1 @@ + diff --git a/tests/test_checkpointer.py b/tests/test_checkpointer.py new file mode 100644 index 0000000..f8eef34 --- /dev/null +++ b/tests/test_checkpointer.py @@ -0,0 +1,212 @@ +from pathlib import Path + +import pytest +import torch + +from finetune.args import LoraArgs +from finetune.checkpointing import Checkpointer +from finetune.utils import TrainState +from finetune.wrapped_model import load_model +from tests.test_utils import MODEL_PATH, is_float_equal, setup_mp_test_dist +from utils.merge_lora import merge_checkpoints + +from .test_utils import spawn_for_all_world_sizes + +# fmt: off +EXPECTED_NON_LORA_KEYS = sorted(['layers.0.attention.wk.weight', 'layers.0.attention.wo.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wv.weight', 'layers.0.attention_norm.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wo.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wv.weight', 'layers.1.attention_norm.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.ffn_norm.weight', 'norm.weight', 'output.weight', 'tok_embeddings.weight']) +EXPECTED_LORA_KEYS = sorted(['layers.0.attention.wq.lora_A.weight', 'layers.0.attention.wq.lora_B.weight', 'layers.0.attention.wk.lora_A.weight', 'layers.0.attention.wk.lora_B.weight', 'layers.0.attention.wv.lora_A.weight', 'layers.0.attention.wv.lora_B.weight', 'layers.0.attention.wo.lora_A.weight', 'layers.0.attention.wo.lora_B.weight', 'layers.0.feed_forward.w1.lora_A.weight', 'layers.0.feed_forward.w1.lora_B.weight', 'layers.0.feed_forward.w2.lora_A.weight', 'layers.0.feed_forward.w2.lora_B.weight', 'layers.0.feed_forward.w3.lora_A.weight', 'layers.0.feed_forward.w3.lora_B.weight', 'layers.1.attention.wq.lora_A.weight', 'layers.1.attention.wq.lora_B.weight', 'layers.1.attention.wk.lora_A.weight', 'layers.1.attention.wk.lora_B.weight', 'layers.1.attention.wv.lora_A.weight', 'layers.1.attention.wv.lora_B.weight', 'layers.1.attention.wo.lora_A.weight', 'layers.1.attention.wo.lora_B.weight', 'layers.1.feed_forward.w1.lora_A.weight', 'layers.1.feed_forward.w1.lora_B.weight', 'layers.1.feed_forward.w2.lora_A.weight', 'layers.1.feed_forward.w2.lora_B.weight', 'layers.1.feed_forward.w3.lora_A.weight', 'layers.1.feed_forward.w3.lora_B.weight']) +# fmt: on + + +@pytest.mark.parametrize( + ("world_size", "save_only_lora", "enable_lora"), + [ + (1, False, False), + (2, False, False), + (1, False, True), + (2, False, True), + (1, True, True), + (2, True, True), # this is the most important test! - FSDP only LORA + ], +) +def test_states_retrieval(world_size, enable_lora, save_only_lora): + spawn_for_all_world_sizes( + _check_states_retrieval, + world_sizes=[world_size], + args=[enable_lora, save_only_lora], + deterministic=True, + ) + + +def _check_states_retrieval( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, + enable_lora: bool, + save_only_lora: bool, +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora), + checkpoint=True, + param_dtype=torch.bfloat16, + ) + + # mock a train state that has done three steps + step = 3 + state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here + + # mock run_dir as we won't save anything in this test + run_dir = Path("dir") + use_sf = True + + checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) + prefix = "lora" if enable_lora else "consolidated" + + assert checkpointer.dst_dir == Path( + f"dir/checkpoints/checkpoint_00000{step}/consolidated" + ), checkpointer.dst_dir + assert checkpointer.consolidated_path( + checkpointer.dst_dir, use_sf, save_only_lora=enable_lora + ) == Path( + f"dir/checkpoints/checkpoint_00000{step}/consolidated/{prefix}.safetensors" + ), checkpointer.consolidated_path( + checkpointer.dst_dir, use_sf, save_only_lora=enable_lora + ) + + # increase step by one + state.start_step() + + assert checkpointer.dst_dir == Path( + f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated" + ), checkpointer.dst_dir + assert checkpointer.consolidated_path( + checkpointer.dst_dir, use_sf, save_only_lora=enable_lora + ) == Path( + f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated/{prefix}.safetensors" + ), checkpointer.consolidated_path( + checkpointer.dst_dir, use_sf, save_only_lora=enable_lora + ) + + assert all("lora" in k for k in EXPECTED_LORA_KEYS), EXPECTED_LORA_KEYS + + for save_dtype in [torch.float16, torch.bfloat16, torch.float32]: + + save_dict = checkpointer.retrieve_save_states( + save_only_lora=save_only_lora, save_dtype=save_dtype + ) + + for k, v in save_dict.items(): + assert v.dtype == save_dtype, f"{k}: v.dtype" + + if save_only_lora: + assert sorted(save_dict.keys()) == EXPECTED_LORA_KEYS, save_dict.keys() + else: + assert sorted(save_dict.keys()) == EXPECTED_NON_LORA_KEYS, save_dict.keys() + + EXPECTED_NON_LORA_VALUES = 34909.7500 + + EXPECTED_LORA_VALUES = 984.4179840087891 + + values_sum = sum(v.abs().float().sum().item() for v in save_dict.values()) + + if save_only_lora: + assert is_float_equal( + values_sum, EXPECTED_LORA_VALUES, 5e-1 + ), f"{values_sum} for {save_dtype}" + else: + assert is_float_equal( + values_sum, EXPECTED_NON_LORA_VALUES, 1e-1 + ), f"{values_sum} for {save_dtype}" + + +@pytest.mark.parametrize("world_size", [1, 2]) +def test_lora_merge_equal(world_size): + spawn_for_all_world_sizes( + _check_lora_merge_equal, + world_sizes=[world_size], + deterministic=True, + ) + + +def _check_lora_merge_equal( + rank: int, world_size: int, filename: str, filename_rpc: str +): + model_parallel = 1 + enable_lora = True + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + + world_size // model_parallel + + folder = Path(MODEL_PATH) + + step = 3 + state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here + run_dir = Path("dir") + + non_lora_model = load_model( + folder=folder, + lora=LoraArgs(enable=False), + checkpoint=True, + param_dtype=torch.bfloat16, + ) + + non_lora_checkpointer = Checkpointer( + non_lora_model, state, run_dir=run_dir, num_ckpt_keep=None + ) + orig_model = non_lora_checkpointer.retrieve_save_states( + save_only_lora=False, save_dtype=torch.float32 + ) + + scaling = 2.0 + + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora, scaling=scaling), + checkpoint=True, + param_dtype=torch.bfloat16, + ) + + state_dict = model.state_dict() + state_dict = {k: v + 0.01 if "lora" in k else v for k, v in state_dict.items()} + model.load_state_dict(state_dict) + + # mock a train state that has done three steps + checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) + + for save_dtype in [torch.float16, torch.bfloat16, torch.float32]: + model_dict = { + k: torch.empty_like(v).copy_(v).to(save_dtype) + for k, v in orig_model.items() + } + merged_save_dict = checkpointer.retrieve_save_states( + save_only_lora=False, save_dtype=save_dtype + ) + + lora_save_dict = checkpointer.retrieve_save_states( + save_only_lora=True, save_dtype=save_dtype + ) + + merge_checkpoints( + model_dict, lora_save_dict, scaling=scaling, save_dtype=save_dtype + ) + + for k in model_dict.keys(): + torch.allclose( + model_dict[k].cpu(), merged_save_dict[k].cpu(), atol=1e-3, rtol=1e-3 + ) + + for k in model_dict.keys(): + # make sure that merged model differs from orig model + if "attention" in k or "feed_forward" in k: + not torch.allclose( + orig_model[k].to(save_dtype).cpu(), + merged_save_dict[k].cpu(), + atol=1e-3, + rtol=1e-3, + ) diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..e31b872 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,729 @@ +import copy +import json +from pathlib import Path + +import numpy as np +import pytest +from mistral_common.protocol.instruct.messages import FinetuningAssistantMessage +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from finetune.data.args import DataArgs, InstructArgs +from finetune.data.data_loader import build_data_loader +from finetune.data.dataset import ( + DataFile, + SampleType, + get_dataset_iterator, + get_rng, + lazy_load_and_yield, + maybe_chunk_lines, + parse_data_sources, + preload_and_yield, +) +from finetune.data.tokenize import build_instruct_sample, encode + +from .test_utils import spawn_for_all_world_sizes + +# fmt: off +EXPECTED_X = [ + [ + # for pretrain + [ + # for DP=1 + [2051851, 1961139, 2000184, 2081307, 2341123, 1225437, 1739008, 724695, 570810, 632094] + ], + [ + # for DP=2 + [2020745, 1938377, 2244286, 2042079, 1824023], + [2103241, 2032118, 1868430, 1093072, 770996], + ] + ], + [ + # for instruct + [ + # for DP=1 + [1379941, 1438894, 965536, 1019713, 889921, 999322, 1647173, 941080, 1281597, 1584884] + ], + [ + # for DP=2 + [1379941, 1438894, 889899, 1005451, 876854], + [1034325, 999322, 982295, 941080, 725946], + ] + ] +] +EXPECTED_Y = [ + [ + # for pretrain + [ + # for DP=1 + [2081367, 1961098, 1970714, 2110856, 2334822, 1251057, 1745267, 699854, 571600, 660015] + ], + [ + # for DP=2 + [2021840, 1966833, 2223275, 2063077, 1824011], + [2132793, 2002569, 1870876, 1122569, 757126], + ] + ], + [ + # for instruct + [ + # for DP=1 + [1409448, 1430886, 937609, 1019339, 889921, 970976, 1660330, 942631, 1308399, 1583658] + ], + [ + # for DP=2 + [1409448, 1430886, 895531, 990091, 863522], + [1041462, 970976, 991091, 942631, 737311] + ] + ] +] +EXPECTED_MASKS = [ + [ + # for pretrain + [ + # for DP=1 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ], + [ + # for DP=2 + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ], + [ + # for instruct + [ + # for DP=1 + [47, 0, 34, 0, 0, 82, 0, 0, 0, 0] + ], + [ + # for DP=2 + [47, 0, 4, 0, 0], + [19, 82, 0, 0, 23], + ] + ] +] + +EXPECTED_X_FUNC = [ + [ + # for DP=1 + [1005531, 1551735, 1261711, 1531024, 1280259, 1069883, 858107, 1021583, 1203265, 1242999], + ], + [ + # for DP=2 + [985281, 1217766, 1442139, 1533790, 1253607], + [1005531, 1551735, 1261711, 1531024, 1280259], + ] +] + +EXPECTED_Y_FUNC = [ + [ + # for DP=1 + [977126, 1580120, 1233326, 1559463, 1280241, 1042456, 879031, 994127, 1196263, 1270581], + ], + [ + # for DP=2 + [957934, 1218899, 1441783, 1533011, 1224541], + [977126, 1580120, 1233326, 1559463, 1280241], + ] +] + +EXPECTED_MASKS_FUNC = [ + [ + # for DP=1 + [91, 0, 0, 0, 0, 77, 0, 0, 53, 0], + ], + [ + # for DP=2 + [16, 47, 0, 86, 98], + [91, 0, 0, 0, 0], + ] +] +# fmt: on + + +class MockTokenizer: + def encode(self, content: str, *args, **kwargs) -> str: + return content + + +class MockInstructTokenizerBaseBase: + def __init__(self): + self.tokenizer = MockTokenizer() + + def encode_user_message(self, message, *args, **kwargs): + return message.content + + def encode_assistant_message(self, message, *args, **kwargs): + return message.content + + def start(self): + return [] + + +def stringify(samples): + lines = [] + for sample in samples: + string_list = sample.tokens + lines.append("".join(string_list)) + + return lines + + +@pytest.mark.parametrize( + ("world_size", "model_parallel", "is_instruct"), + [ + (1, 1, False), + (2, 1, False), + (2, 2, False), + (1, 1, True), + (2, 1, True), + (2, 2, True), + ], +) +def test_data_loader_dist(world_size, model_parallel, is_instruct): + spawn_for_all_world_sizes( + _check_data_loader_dist, + world_sizes=[world_size], + args=[model_parallel, is_instruct], + deterministic=True, + ) + + +def _check_data_loader_dist( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, + model_parallel: int, + is_instruct: bool, +): + dp_world_size = world_size // model_parallel + dp_rank = rank // model_parallel + + seed = 0 + seq_len = 100 + batch_size = 1 + + instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) + + if is_instruct: + # at the moment we have to pass some instuction finetuning + data_args = DataArgs( + data="", + instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1", + instruct=instruct, + ) + else: + data_args = DataArgs( + data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0", + instruct_data="tests/fixtures/sample_instruct.jsonl:.01", + instruct=instruct, + ) + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + data_loader = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=seed, + rank=dp_rank, + world_size=dp_world_size, + is_eval=False, + ) + + x_sums = [] + y_sums = [] + masks = [] + + num_samples = 10 // dp_world_size + + for _ in range(num_samples): + batch = next(data_loader) + x_sums.append(batch.x.sum()) + y_sums.append(batch.y.sum()) + mask_sum = batch.y_mask.sum() if batch.y_mask is not None else 0 + masks.append(mask_sum) + + expected_x_sums = EXPECTED_X[is_instruct][dp_world_size - 1][dp_rank] + expected_y_sums = EXPECTED_Y[is_instruct][dp_world_size - 1][dp_rank] + expected_masks = EXPECTED_MASKS[is_instruct][dp_world_size - 1][dp_rank] + + print(f"rank: {rank}, world_size: {world_size}, x: {x_sums}") + print(f"rank: {rank}, world_size: {world_size}, y: {y_sums}") + print(f"rank: {rank}, world_size: {world_size}, x shape: {masks}") + + assert x_sums == expected_x_sums, x_sums + assert y_sums == expected_y_sums, y_sums + assert masks == expected_masks, masks + + +@pytest.mark.parametrize("world_size", [1, 2]) +def test_data_loader_dist_fn_call(world_size): + spawn_for_all_world_sizes( + _check_data_loader_dist_fn_call, + world_sizes=[world_size], + deterministic=True, + ) + + +def _check_data_loader_dist_fn_call( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, +): + dp_world_size = world_size + dp_rank = rank + + seed = 0 + seq_len = 100 + batch_size = 1 + + data_args = DataArgs( + data="", + instruct_data="tests/fixtures/sample_instruct_fn_call_short.jsonl:.3", + instruct=InstructArgs(shuffle=True, dynamic_chunk_fn_call=True), + ) + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + data_loader = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=seed, + rank=dp_rank, + world_size=dp_world_size, + is_eval=False, + ) + + x_sums = [] + y_sums = [] + masks = [] + + num_samples = 10 // dp_world_size + + for _ in range(num_samples): + batch = next(data_loader) + x_sums.append(batch.x.sum()) + y_sums.append(batch.y.sum()) + mask_sum = batch.y_mask.sum() if batch.y_mask is not None else 0 + masks.append(mask_sum) + + expected_x_sums = EXPECTED_X_FUNC[dp_world_size - 1][dp_rank] + expected_y_sums = EXPECTED_Y_FUNC[dp_world_size - 1][dp_rank] + expected_masks = EXPECTED_MASKS_FUNC[dp_world_size - 1][dp_rank] + + assert x_sums == expected_x_sums, x_sums + assert y_sums == expected_y_sums, y_sums + assert masks == expected_masks, masks + + +def test_data_loader_equal_fsdp(): + spawn_for_all_world_sizes( + _check_data_loader_equal_fsdp, + world_sizes=[2], + deterministic=True, + ) + + +def _check_data_loader_equal_fsdp( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, +): + model_parallel = 2 + world_size // model_parallel + rank // model_parallel + + seed = 0 + seq_len = 100 + batch_size = 1 + + instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) + + data_args = DataArgs( + data="", + instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1", + instruct=instruct, + ) + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + data_loader_0 = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=seed, + rank=0, + world_size=world_size, + is_eval=False, + ) + + data_loader_1 = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=seed, + rank=1, + world_size=world_size, + is_eval=False, + ) + + x_sums = [] + y_sums = [] + + num_samples = 10 // 2 # run 5 * 2 training steps + + for _ in range(num_samples): + batch = next(data_loader_0) + x_sums.append(batch.x.sum()) + y_sums.append(batch.y.sum()) + + batch = next(data_loader_1) + x_sums.append(batch.x.sum()) + y_sums.append(batch.y.sum()) + + # check that mp can match ddp for both ranks + expected_x_sums = [ + y for x in zip(EXPECTED_X[1][1][0], EXPECTED_X[1][1][1]) for y in x + ] + expected_y_sums = [ + y for x in zip(EXPECTED_Y[1][1][0], EXPECTED_Y[1][1][1]) for y in x + ] + + assert x_sums == expected_x_sums, x_sums + assert y_sums == expected_y_sums, y_sums + + +def test_dynamic_fn_call_chunk(): + jsonl_file = Path("tests/fixtures/sample_instruct_fn_call_short.jsonl") + + non_chunked_samples = [] + with jsonl_file.open() as file_handle: + for line in file_handle: + non_chunked_samples.append(build_instruct_sample(json.loads(line))) + + num_expected_chunks = 0 + for sample in non_chunked_samples: + if sample.only_last: + num_expected_chunks += ( + sum(isinstance(m, FinetuningAssistantMessage) for m in sample.messages) + - 1 + ) + + chunked_samples = [] + with jsonl_file.open() as file_handle: + lines = file_handle.readlines() + extra_lines = maybe_chunk_lines(lines) + + for line in extra_lines: + chunked_samples.append(build_instruct_sample(json.loads(line))) + + assert num_expected_chunks == len(chunked_samples) + + +def test_dynamic_fn_call_chunk_integration(): + jsonl_file = Path("tests/fixtures/sample_instruct_fn_call_multi.jsonl") + + multi_samples = [] + with jsonl_file.open() as file_handle: + for line in file_handle: + multi_samples.append(build_instruct_sample(json.loads(line))) + + jsonl_file = Path("tests/fixtures/sample_instruct_fn_call_single.jsonl") + + chunked_samples = [] + with jsonl_file.open() as file_handle: + for line in file_handle: + chunked_samples.append(build_instruct_sample(json.loads(line))) + + with jsonl_file.open() as file_handle: + lines = file_handle.readlines() + extra_lines = maybe_chunk_lines(lines) + + for line in extra_lines: + chunked_samples.append(build_instruct_sample(json.loads(line))) + + assert list(reversed(multi_samples)) == chunked_samples + + +def test_fn_call(): + batch_size = 1 + data_args = DataArgs( + data="", + instruct_data="", + eval_instruct_data="tests/fixtures/sample_instruct_fn_call.jsonl", + ) + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + seq_len = 10000 + + data_loader = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=None, + rank=0, + world_size=1, + is_eval=True, + ) + + all_loss_strings = [] + for batch in data_loader: + y_mask = ( + np.asarray(batch.y_mask, int) + if batch.y_mask is not None + else np.ones_like(batch.x) + ) + + start_index = end_index = 0 + for size in batch.sizes: + end_index += size + tokens = batch.y[start_index:end_index] + mask = y_mask[start_index:end_index] + + tokens_for_loss = [int(y) for i, y in enumerate(tokens) if mask[i] == 1] + start_index += size + + decoded = instruct_tokenizer.tokenizer.decode(tokens_for_loss) + if len(decoded) > 0: + all_loss_strings.append(decoded) + + # Verify that the loss is always only computed over the + expected_loss_strings = [] + with open(data_args.eval_instruct_data, "r") as f: + for line in f: + data = json.loads(line) + last_message = data["interactions"][-1] + if "content" in last_message: + expected_loss_strings.append(last_message["content"]) + elif "tool_calls" in last_message: + tool_calls = last_message["tool_calls"] + arguments = tool_calls[0]["function"]["arguments"] + string = [ + { + "name": call["function"]["name"], + "arguments": json.loads(arguments), + } + for call in tool_calls + ] + expected_loss_strings.append(json.dumps(string)) + + assert expected_loss_strings == all_loss_strings + + +def test_data_weighting(): + data_args = DataArgs( + data="", + instruct_data="", + eval_instruct_data="tests/fixtures/sample_instruct.jsonl", + ) + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + jsonl_file = Path(data_args.eval_instruct_data) + + with jsonl_file.open() as file_handle: + data = json.loads(next(file_handle)) + + token_sample = encode(data, instruct_tokenizer, SampleType.INSTRUCT) + + data_weight_0 = copy.deepcopy(data) + data_weight_0["interactions"][-1]["weight"] = 0 + + token_sample_weight_0 = encode( + data_weight_0, instruct_tokenizer, SampleType.INSTRUCT + ) + + data_weight_1 = copy.deepcopy(data) + data_weight_1["interactions"][-1]["weight"] = 1 + + token_sample_weight_1 = encode( + data_weight_1, instruct_tokenizer, SampleType.INSTRUCT + ) + + assert ( + token_sample.tokens + == token_sample_weight_0.tokens + == token_sample_weight_1.tokens + ) + assert token_sample.masks == token_sample_weight_1.masks + assert token_sample.masks != token_sample_weight_0.masks + assert not any(token_sample_weight_0.masks) + + +def test_eval_dataloader(): + batch_size = 1 + + data_args = DataArgs( + data="", + instruct_data="", + eval_instruct_data="tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl", + ) + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + # make sure that for every seq len the same data is seen + for world_size in [1, 2, 8]: + for seq_len in [10, 100, 1000, 10000]: + x_sums = [] + y_sums = [] + y_masks = [] + + data_loaders = [] + for rank in range(world_size): + data_loaders.append( + build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=None, + rank=rank, + world_size=world_size, + is_eval=True, + ) + ) + + for data_loader in data_loaders: + for batch in data_loader: + mask = ( + np.asarray(batch.y_mask, int) + if batch.y_mask is not None + else np.ones_like(batch.x) + ) + x_sums.append((batch.x * mask).sum()) + y_sums.append((batch.y * mask).sum()) + y_masks.append(mask.sum()) + + assert len(batch.x) == len(mask) == len(batch.y) == seq_len + + assert sum(x_sums) == 71404835 + assert sum(y_sums) == 71404795 + assert sum(y_masks) == 5538 + + +def test_shuffle_data(): + instruct_tokenizer = MockInstructTokenizerBaseBase() + + data_args = DataArgs(data="", instruct_data="", eval_instruct_data="") + + data_file = Path("tests/fixtures/sample_instruct_long_1.jsonl") + + dataset_iterator = get_dataset_iterator( + source=DataFile(path=data_file, sample_type=SampleType.INSTRUCT), + instruct_args=data_args.instruct, + instruct_tokenizer=instruct_tokenizer, + rank=0, + world_size=1, + is_finite=False, + seed=0, + shuffle_at_epoch=True, + ) + + with data_file.open() as f: + lines = f.readlines() + lines = [ + encode( + json.loads(line), + instruct_tokenizer=instruct_tokenizer, + as_type=SampleType.INSTRUCT, + ) + for line in lines + ] + prev_lines = stringify(lines) + + num_lines = len(prev_lines) + + samples = [] + # run 4 epochs + for i in range(4 * num_lines): + samples.append(next(dataset_iterator)) + + if (i + 1) % num_lines == 0: + # epoch finished! + # check that order is different but all lines have the same hash + lines = stringify(samples) + assert lines != prev_lines, "No shuffling - make sure dataset is shuffled!" + assert sorted(lines) == sorted( + prev_lines + ), "datasets need to match at every epoch" + + prev_lines = lines + samples = [] + + +@pytest.mark.parametrize("world_size", [1, 2]) +def test_shuffle_data_same_as_no_shuffle(world_size): + spawn_for_all_world_sizes( + _check_shuffle_data_same_as_no_shuffle, + world_sizes=[world_size], + deterministic=True, + ) + + +def _check_shuffle_data_same_as_no_shuffle( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, +): + instruct_tokenizer = MockInstructTokenizerBaseBase() + + instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) + + data_args = DataArgs( + data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0", + instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1", + instruct=instruct, + ) + + sources, _ = parse_data_sources( + pretrain_data=data_args.data, instruct_data=data_args.instruct_data + ) + + seed = 0 + rng = get_rng(seed, rank) + + for source in sources: + jsonl_files = source.jsonl_files + chunk_dataset = source.sample_type == SampleType.INSTRUCT + + for jsonl_file in jsonl_files: + samples = [[], []] + for shuffle in [True, False]: + print(jsonl_file) + if shuffle: + iterator = preload_and_yield( + jsonl_file, + chunk_dataset=chunk_dataset, + rank=rank, + world_size=world_size, + rng=rng, + instruct_tokenizer=instruct_tokenizer, + sample_type=source.sample_type, + ) + else: + iterator = lazy_load_and_yield( + jsonl_file, + rank=rank, + world_size=world_size, + instruct_tokenizer=instruct_tokenizer, + sample_type=source.sample_type, + ) + + for tokens in iterator: + samples[shuffle].append(tokens) + + strings_0 = sorted(stringify(samples[0])) + strings_1 = sorted(stringify(samples[1])) + + assert strings_0 == strings_1 diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py new file mode 100644 index 0000000..c00c335 --- /dev/null +++ b/tests/test_mixed_precision.py @@ -0,0 +1,136 @@ +from pathlib import Path + +import pytest +import torch + +from finetune.args import LoraArgs +from finetune.loss import compute_loss_with_mask +from finetune.mixed_precision import ( + downcast_mixed_precision, + prepare_mixed_precision, + upcast_mixed_precision, +) +from finetune.wrapped_model import load_model +from tests.test_utils import MODEL_PATH, get_dataloader, setup_mp_test_dist + +from .test_utils import spawn_for_all_world_sizes + + +@pytest.mark.parametrize( + ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)] +) +def test_mixed_precision(world_size, enable_lora): + spawn_for_all_world_sizes( + _check_mixed_precision, + world_sizes=[world_size], + args=[enable_lora], + deterministic=True, + ) + + +def _check_mixed_precision( + rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 100 + + folder = Path(MODEL_PATH) + # mixed precision + param_dtype = torch.bfloat16 + optim_dtype = torch.float32 + + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora), + checkpoint=True, + param_dtype=param_dtype, + ) + + optimizer = torch.optim.AdamW(model.parameters()) + + # initialize mixed precision training for TP + prepare_mixed_precision( + model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype + ) + + data_loader = get_dataloader(seq_len=seq_len) + + # ensure every parameter that requires a grad has a _mp_param of optim_dtype precision + for param in model.parameters(): + assert param.dtype == param_dtype + if param.requires_grad: + assert param._mp_param.dtype == optim_dtype + assert ( + param._mp_param.tolist() == param.data.to(optim_dtype).tolist() + ), "mp param has to match param in optim dtype precision" + else: + assert not hasattr(param, "_mp_param") + + # test three train steps + for _ in range(3): + + optimizer.zero_grad() + + # micro-batching + for _ in range(2): + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = ( + torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + if batch.y_mask is not None + else None + ) + + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + mb_loss = compute_loss_with_mask(output, y, y_mask) + mb_loss.backward() + + upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) + + # ensure all params are upcasted correctly and mp param equals param + param_sum = 0 + for param in model.parameters(): + if param.requires_grad: + assert param.dtype == optim_dtype, param.dtype + assert ( + param._mp_param.tolist() == param.data.tolist() + ), "mp param and param should point to the same data" + assert param.grad.dtype == optim_dtype + assert param._temp.dtype == param_dtype + param_sum += param.data.float().abs().sum() + else: + assert param.dtype == param_dtype + + optimizer.step() + + # ensure that after optimizer step params are still in optim dtype precision + new_param_sum = 0 + for param in model.parameters(): + if param.requires_grad: + assert param.dtype == optim_dtype + assert param._mp_param.dtype == optim_dtype + assert param.grad.dtype == optim_dtype + new_param_sum += param.data.float().abs().sum() + else: + assert param.dtype == param_dtype + + assert new_param_sum != param_sum, "Make sure parameters are updated" + + downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) + + # ensure that before new forward pass params are downcasted to param dtype + for param in model.parameters(): + assert param.dtype == param_dtype + if param.requires_grad: + assert param._mp_param.dtype == optim_dtype + assert param.grad.dtype == param_dtype + assert ( + param._mp_param.to(param_dtype).tolist() == param.data.tolist() + ), "mp param has to match param in optim dtype precision" diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..ee6af5b --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,550 @@ +import tempfile +from pathlib import Path +from typing import Dict + +import pytest +import torch + +from finetune.args import LoraArgs +from finetune.checkpointing import Checkpointer +from finetune.loss import compute_loss_with_mask +from finetune.mixed_precision import ( + downcast_mixed_precision, + prepare_mixed_precision, + upcast_mixed_precision, +) +from finetune.utils import TrainState +from finetune.wrapped_model import load_model +from model.transformer import ( + LoRALinear, +) +from tests.test_utils import ( + MODEL_PATH, + get_dataloader, + is_float_equal, + setup_mp_test_dist, +) + +from .test_utils import spawn_for_all_world_sizes + +torch.backends.cudnn.deterministic = True # use deterministic algorithms +torch.backends.cudnn.benchmark = False # disable cuDNN benchmark + + +@pytest.mark.parametrize( + ("world_size", "enable_lora", "dtype"), + [ + (1, False, torch.float32), + (1, True, torch.float32), + (2, False, torch.float32), + (2, True, torch.float32), + (1, False, torch.bfloat16), + (1, True, torch.bfloat16), + (2, False, torch.bfloat16), + (2, True, torch.bfloat16), + ], +) +def test_weights_loading(world_size, enable_lora, dtype): + spawn_for_all_world_sizes( + _check_weights_loading, + world_sizes=[world_size], + args=[enable_lora, dtype], + deterministic=True, + ) + + +def _check_weights_loading( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, + enable_lora: bool, + dtype: torch.dtype, +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora), + checkpoint=True, + param_dtype=dtype, + ) + + # add hook so that LoRA weights are automatically merged: + def register_merge_lora_hook(m: torch.nn.Module): + def merge_lora( + m: torch.nn.Module, destination: Dict[str, torch.Tensor], prefix: str, *args + ): + weight = m.merge_weight() + destination[prefix + "weight"] = weight + + if isinstance(m, LoRALinear): + m._merge_lora_handle = m._register_state_dict_hook(merge_lora) + + model.apply(register_merge_lora_hook) + + if world_size > 1: + with model.summon_full_params(model, writeback=True): + states = { + k: v + for k, v in model.state_dict().items() + if "lora" not in k and "frozen" not in k + } + else: + states = { + k: v + for k, v in model.state_dict().items() + if "lora" not in k and "frozen" not in k + } + + EXP_PARAM_SUM = 308.9932 if dtype == torch.float32 else 308.0 + params = sum([v.sum() for v in states.values()]).item() + + # LoRA is equal to no LoRA as LoRA weights should be init to 0 + assert is_float_equal(params, EXP_PARAM_SUM), params + + if enable_lora: + lora_B_params = [ + v.float().abs().sum() for k, v in model.named_parameters() if "lora_B" in k + ] + + assert len(lora_B_params) > 0 + assert sum(lora_B_params) == 0, "Lora_B should always be zero init" + + lora_A_params = [ + v.float().abs().sum() for k, v in model.named_parameters() if "lora_A" in k + ] + + assert len(lora_A_params) > 0 + assert sum(lora_A_params) > 0, "Lora_A should init to non-zero values" + + +@pytest.mark.parametrize( + ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)] +) +def test_fsdp_logits_and_loss(world_size, enable_lora): + spawn_for_all_world_sizes( + _check_fsdp_logits_and_loss, + world_sizes=[world_size], + args=[enable_lora], + deterministic=True, + ) + + +def _check_fsdp_logits_and_loss( + rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 100 + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora), + checkpoint=True, + param_dtype=torch.bfloat16, + ) + # By seting equal rank and world_size we can assure that both ranks see the same data and hence the average + data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) + + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + # check logits + # logits should be the same for LoRA and non-LoRA + assert output.shape == (seq_len, model.args.vocab_size) + output_sum = output.abs().float().sum().item() + + EXP_OUTPUT_WORLD_1 = 162617.625 + + assert is_float_equal(output_sum, EXP_OUTPUT_WORLD_1, precision=1e1), output_sum + + # check loss is the same for all + # loss should be the same for LoRA and non-LoRA + mb_loss = compute_loss_with_mask(output, y, y_mask) + + EXPECTED_LOSS = 10.408413887023926 + + assert is_float_equal(mb_loss.item(), EXPECTED_LOSS), mb_loss.item() + + +@pytest.mark.parametrize( + ("world_size", "dtype"), + [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], +) +def test_fsdp_grads_non_lora(world_size, dtype): + spawn_for_all_world_sizes( + _check_fsdp_grads_non_lora, + world_sizes=[world_size], + deterministic=True, + args=[dtype], + ) + + +def _check_fsdp_grads_non_lora( + rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 2048 + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=False), + checkpoint=True, + param_dtype=dtype, + ) + # same world_size to check for equality + data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) + + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + mb_loss = compute_loss_with_mask(output, y, y_mask) + mb_loss.backward() + + num_grad_params = sum([p.grad.numel() for p in model.parameters()]) + + assert (4301120 // world_size) == num_grad_params, num_grad_params + + torch.distributed.barrier() + + sharded_flat_grads = sum( + [p.grad.float().abs().sum().item() for p in model.parameters()] + ) + + print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}") + + EXP_GRAD_WORLD_2_RANK_0 = 95.45827150344849 + EXP_GRAD_WORLD_2_RANK_1 = 86.09188461303711 + EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1 + + if world_size == 1: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1 + ), sharded_flat_grads + elif world_size == 2 and rank == 0: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1 + ), sharded_flat_grads + elif world_size == 2 and rank == 1: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1 + ), sharded_flat_grads + + +@pytest.mark.parametrize( + ("world_size", "dtype"), + [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], +) +def test_fsdp_grads_lora(world_size, dtype): + spawn_for_all_world_sizes( + _check_fsdp_grads_lora, + world_sizes=[world_size], + deterministic=True, + args=[dtype], + ) + + +def _check_fsdp_grads_lora( + rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 2048 + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=True), + checkpoint=True, + param_dtype=dtype, + ) + # same world_size to check for equality + data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) + + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + mb_loss = compute_loss_with_mask(output, y, y_mask) + mb_loss.backward() + + num_grad_params = sum( + [p.grad.numel() for p in model.parameters() if p.grad is not None] + ) + + assert (40960 // world_size) == num_grad_params, num_grad_params + + torch.distributed.barrier() + + sharded_flat_grads = sum( + [ + p.grad.float().abs().sum().item() + for p in model.parameters() + if p.grad is not None + ] + ) + + print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}") + + EXP_GRAD_WORLD_2_RANK_0 = 3.0742580661177635 + EXP_GRAD_WORLD_2_RANK_1 = 3.074301045779139 + EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1 + + if world_size == 1: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1 + ), sharded_flat_grads + elif world_size == 2 and rank == 0: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1 + ), sharded_flat_grads + elif world_size == 2 and rank == 1: + assert is_float_equal( + sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1 + ), sharded_flat_grads + + +@pytest.mark.parametrize( + ("world_size", "dtype"), + [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], +) +def test_grad_update_lora(world_size, dtype): + spawn_for_all_world_sizes( + _check_grad_update_lora, + world_sizes=[world_size], + args=[dtype], + deterministic=True, + ) + + +def _check_grad_update_lora( + rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 1000 + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=True), + checkpoint=True, + param_dtype=dtype, + ) + optimizer = torch.optim.AdamW(model.parameters()) + + data_loader = get_dataloader(seq_len=seq_len) + + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = ( + torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + if batch.y_mask is not None + else None + ) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + mb_loss = compute_loss_with_mask(output, y, y_mask) + mb_loss.backward() + + lora_weight_sum = 0 + non_lora_weight_sum = 0 + for name, param in model.named_parameters(): + if "lora" in name or "norm" in name: + assert param.grad is not None, name + lora_weight_sum += param.data.float().abs().sum() + else: + assert param.grad is None, name + non_lora_weight_sum += param.data.float().abs().sum() + + # update weights + optimizer.step() + + new_lora_weight_sum = 0 + new_non_lora_weight_sum = 0 + for name, param in model.named_parameters(): + if "lora" in name or "norm" in name: + assert param.grad is not None, name + new_lora_weight_sum += param.data.float().abs().sum() + else: + assert param.grad is None, name + new_non_lora_weight_sum += param.data.float().abs().sum() + + # make sure that LoRA weights changed, but non-LoRA weights stayed the same + assert not is_float_equal( + new_lora_weight_sum, lora_weight_sum, 1e-4 + ), f"New: {new_lora_weight_sum}, Old: {lora_weight_sum}" + assert is_float_equal( + new_non_lora_weight_sum, non_lora_weight_sum, 1e-4 + ), f"New: {new_non_lora_weight_sum}, Old: {non_lora_weight_sum}" + + +@pytest.mark.parametrize( + ("enable_lora", "param_dtype"), + [ + (False, torch.float32), + (True, torch.float32), + (False, torch.bfloat16), + (True, torch.bfloat16), + ], +) +def test_grads_fsdp_mp(enable_lora, param_dtype): + with tempfile.TemporaryDirectory() as tmpdirname: + for world_size in [1, 2]: + spawn_for_all_world_sizes( + _check_grads_fsdp_mp, + world_sizes=[world_size], + deterministic=True, + args=[tmpdirname, enable_lora, param_dtype], + ) + + w1_sd = torch.load(Path(tmpdirname) / Path("params_w1.pt"), map_location="cpu") + w2_sd = torch.load(Path(tmpdirname) / Path("params_w2.pt"), map_location="cpu") + + for k in w1_sd.keys(): + assert w1_sd[k].shape == w2_sd[k].shape, k + atol = 10 if param_dtype == torch.float32 else 100 + assert (w1_sd[k] - w2_sd[k]).sum().abs().item() < atol + + +def _check_grads_fsdp_mp( + rank: int, + world_size: int, + filename: str, + filename_rpc: str, + tmpdirname: str, + enable_lora: bool, + param_dtype: torch.dtype, +): + model_parallel = 1 + setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) + seq_len = 4096 + + optim_dtype = torch.float32 + + folder = Path(MODEL_PATH) + model = load_model( + folder=folder, + lora=LoraArgs(enable=enable_lora), + checkpoint=True, + param_dtype=param_dtype, + ) + + # high learning rate to show differences + optimizer = torch.optim.AdamW(model.parameters(), lr=0.1) + + # mock a train state that has done three steps + steps = 4 + state = TrainState(max_steps=steps) + + # mock run_dir as we won't save anything in this test + run_dir = Path(tmpdirname) + + checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) + + # make sure the same data is seen + dataloaders = [ + get_dataloader(seq_len=seq_len, rank=rank + i, world_size=2) + for i in range(2 - world_size + 1) + ] + + prepare_mixed_precision( + model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype + ) + + for _ in range(steps): + state.start_step() + optimizer.zero_grad() + + for data_loader in dataloaders: + torch.manual_seed(0) + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda() + y = torch.from_numpy(batch.y).cuda() + y_mask = ( + torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + if batch.y_mask is not None + else None + ) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + + mb_loss = compute_loss_with_mask(output, y, y_mask) + mb_loss.backward() + + assert model.params[0].dtype == param_dtype + + print(f"rank: {rank}, world_size: {world_size}, x: {x.abs().sum()}") + print(f"rank: {rank}, world_size: {world_size}, y: {y.abs().sum()}") + print(f"rank: {rank}, world_size: {world_size}, x shape: {x.shape}") + + if y_mask is not None: + print( + f"rank: {rank}, world_size: {world_size}, y_mask: {y_mask.abs().sum()}" + ) + print(f"rank: {rank}, world_size: {world_size}, loss: {mb_loss}") + + for p in model.parameters(): + if p.requires_grad: + assert p.grad is not None + p.grad.div_(len(dataloaders)) + + max_norm = 1.0 + model.clip_grad_norm_(max_norm=max_norm) + + upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) + + optimizer.step() + + downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) + + save_dict = checkpointer.retrieve_save_states( + save_only_lora=enable_lora, save_dtype=torch.float32 + ) + + path = "params_w1.pt" if world_size == 1 else "params_w2.pt" + torch.save(save_dict, Path(tmpdirname) / Path(path)) diff --git a/tests/test_train_loop.py b/tests/test_train_loop.py new file mode 100644 index 0000000..5505b4f --- /dev/null +++ b/tests/test_train_loop.py @@ -0,0 +1,93 @@ +import os +import tempfile +from contextlib import ExitStack +from pathlib import Path + +import pytest +import safetensors +import torch + +from finetune.args import LoraArgs, OptimArgs, TrainArgs +from finetune.data.args import DataArgs, InstructArgs +from tests.test_utils import DATA_PATH, EVAL_DATA_PATH, MODEL_PATH, setup_mp_test_dist +from train import _train + +from .test_utils import spawn_for_all_world_sizes + + +def file_size_and_md5(file_path): + # Check if the file exists + if not os.path.isfile(file_path): + return "Error: File not found" + + # Get the size of the file + file_size = os.path.getsize(file_path) + + # Open the file in binary mode + state_dict = safetensors.torch.load_file(file_path) + md5_sum = sum(v.abs().sum().item() for v in state_dict.values()) + + return file_size, md5_sum + + +@pytest.mark.parametrize("enable_lora", [False, True]) +def test_integration(enable_lora): + torch.backends.cudnn.deterministic = True # use deterministic algorithms + torch.backends.cudnn.benchmark = False # disable cuDNN benchmark + + instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) + + data_args = DataArgs( + data="", + instruct_data=DATA_PATH, + eval_instruct_data=EVAL_DATA_PATH, + instruct=instruct, + ) + + model_path = MODEL_PATH + optim_args = OptimArgs(lr=0.01, weight_decay=0.1, pct_start=0.0) + + with tempfile.TemporaryDirectory() as tmpdirname: + args = TrainArgs( + data=data_args, + model_id_or_path=model_path, + run_dir=tmpdirname, + seed=0, + optim=optim_args, + max_steps=4, + num_microbatches=1, + lora=LoraArgs(enable=enable_lora), + ckpt_only_lora=enable_lora, + checkpoint=True, + no_eval=False, + ) + spawn_for_all_world_sizes( + _run_dummy_train, + world_sizes=[2], + deterministic=True, + args=[args], + ) + + prefix = "lora" if enable_lora else "consolidated" + ckpt_path = Path(tmpdirname) / Path( + f"checkpoints/checkpoint_00000{args.max_steps}/consolidated/{prefix}.safetensors" + ) + assert ckpt_path.exists() + + file_size, hash = file_size_and_md5(ckpt_path) + + EXPECTED_FILE_SIZE = [8604200, 84760] + + EXPECTED_HASH = [50515.5, 1296.875] + + assert file_size == EXPECTED_FILE_SIZE[enable_lora], file_size + assert abs(hash - EXPECTED_HASH[enable_lora]) < 1e-2, hash + + +def _run_dummy_train( + rank: int, world_size: int, filename: str, filename_rpc: str, args: TrainArgs +): + setup_mp_test_dist(rank, world_size, filename, 1, seed=0) + + with ExitStack() as exit_stack: + _train(args, exit_stack) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b434757 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,227 @@ +import os +import tempfile +from datetime import timedelta +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from finetune.data.args import DataArgs, InstructArgs +from finetune.data.data_loader import build_data_loader +from finetune.distributed import get_rank, get_world_size +from finetune.utils import set_random_seed + + +def is_float_equal(a, b, precision=5e-3): + return abs(a - b) < precision + + +MODEL_PATH = os.getenv("DUMMY_MODEL") +assert MODEL_PATH != "", "Provide a path to a dummy model" +DATA_PATH = "tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1" +EVAL_DATA_PATH = "tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl" + + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +# Pipeline parallel group that the current rank belongs to. +_PIPELINE_PARALLEL_GROUP = None + +_PIPELINE_PARALLEL_RANKS = None + + +def rmf(filename: str) -> None: + """Remove a file like rm -f.""" + try: + os.remove(filename) + except FileNotFoundError: + pass + + +def test_runner( + rank: int, test_func: Callable, deterministic: bool = False, *args: List[Any], **kwargs: Dict[str, Any] +) -> None: + # At this point we're in a new process, torch options need to be set again + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(1357) + + test_func(rank, *args, **kwargs) + + +def spawn_for_all_world_sizes( + test_func: Callable, world_sizes: List[int] = [], args: Any = [], deterministic: bool = False +) -> None: + for world_size in world_sizes: + _, filename = tempfile.mkstemp() + _, filename_rpc = tempfile.mkstemp() + + try: + torch.multiprocessing.spawn( + test_runner, + args=(test_func, deterministic, world_size, filename, filename_rpc, *args), + nprocs=world_size, + join=True, + ) + finally: + rmf(filename) + rmf(filename_rpc) + +def initialize_model_parallel( + model_parallel_size_: int, + pipeline_length: int = 1, + *, + model_parallel_backend: Optional[str] = None, + pipeline_backend: Optional[str] = None, + ddp_backend: Optional[str] = None +) -> None: + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel groups as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = int(min(model_parallel_size_, world_size)) + rank = torch.distributed.get_rank() + + data_parallel_size = int(world_size / (model_parallel_size * pipeline_length)) + + if torch.distributed.get_rank() == 0: + print("> initializing model parallel with size {}".format(model_parallel_size_)) + print("> initializing ddp with size {}".format(data_parallel_size)) + print("> initializing pipeline with size {}".format(pipeline_length)) + + groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size) + + found = torch.where(groups == rank) + assert all(len(x) == 1 for x in found) + found = [x[0] for x in found] + + # Build the data parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" + for j in range(pipeline_length): + for k in range(model_parallel_size): + group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend) + if j == found[1] and k == found[2]: + _DATA_PARALLEL_GROUP = group + + # Build the model parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" + for i in range(data_parallel_size): + for j in range(pipeline_length): + group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend) + if i == found[0] and j == found[1]: + _MODEL_PARALLEL_GROUP = group + + global _PIPELINE_PARALLEL_GROUP + assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized" + global _PIPELINE_PARALLEL_RANKS + assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized" + for i in range(data_parallel_size): + for k in range(model_parallel_size): + ranks = groups[i, :, k].tolist() + group = torch.distributed.new_group(ranks, backend=pipeline_backend) + if i == found[0] and k == found[2]: + _PIPELINE_PARALLEL_GROUP = group + _PIPELINE_PARALLEL_RANKS = ranks + + +def setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0): + dist_init_for_testing(rank, world_size, filename) + torch.cuda.set_device(rank) + + # Init NCCL + backend = "nccl" + initialize_model_parallel( + model_parallel, + model_parallel_backend=backend, + pipeline_backend=backend, + ddp_backend=backend, + ) + + set_random_seed(seed) + + if torch.cuda.is_available(): + torch.set_default_tensor_type(torch.cuda.FloatTensor) # type: ignore + + +def dist_init_for_testing( + rank: int, world_size: int, filename: str, filename_rpc: str = "", timeout: int = 30 +): + """ + Same than fairscale testing.dist_init but without rpc + + filename_rpc is here to keep same signature than fairscale init + """ + + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + url = "file://" + filename + + backend = "nccl" if torch.cuda.is_available() else "gloo" + + if backend == "nccl" and torch.cuda.device_count() < world_size: + raise RuntimeError( + f"Requested world size {world_size} cannot be reached on this machine, not enough GPUs {torch.cuda.device_count()}" + ) + + dist.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + init_method=url, + timeout=timedelta(seconds=timeout), + ) + + +def get_dataloader( + seed: int = 0, + seq_len: int = 10000, + rank: Optional[int] = None, + world_size: Optional[int] = None, +): + batch_size = 1 + rank = rank if rank is not None else get_rank() + world_size = world_size if world_size is not None else get_world_size() + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) + + data_args = DataArgs( + data="", + instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1", + instruct=instruct, + ) + data_loader = build_data_loader( + instruct_tokenizer, + data_args, + batch_size, + seq_len, + seed=seed, + rank=rank, + world_size=world_size, + is_eval=False, + ) + return data_loader diff --git a/train.py b/train.py new file mode 100644 index 0000000..ed74ef9 --- /dev/null +++ b/train.py @@ -0,0 +1,323 @@ +import dataclasses +import logging +import os +import pprint +from contextlib import ExitStack +from pathlib import Path +from typing import TYPE_CHECKING + +import fire +import torch.cuda +import torch.distributed as dist +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from torch.optim import AdamW, lr_scheduler + +from finetune.args import TrainArgs +from finetune.checkpointing import Checkpointer +from finetune.data.data_loader import build_data_loader +from finetune.distributed import ( + BACKEND, + avg_aggregate, + get_rank, + get_world_size, + is_torchrun, + set_device, +) +from finetune.eval import evaluate +from finetune.loss import compute_loss_with_mask +from finetune.mixed_precision import ( + downcast_mixed_precision, + prepare_mixed_precision, + upcast_mixed_precision, +) +from finetune.monitoring.metrics_logger import ( + MetricsLogger, + eval_log_msg, + get_eval_logs, + get_train_logs, + train_log_msg, +) +from finetune.monitoring.utils import set_logger +from finetune.utils import ( + TrainState, + logged_closing, + set_random_seed, +) +from finetune.wrapped_model import load_model + +if TYPE_CHECKING: + from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase + +logger = logging.getLogger("train") + + +def main_logger_info(message: str) -> None: + if get_rank() == 0: + logger.info(message) + + +def train(config: str): + args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False) + print(f"args: {args}") + set_logger(logging.INFO) + + with ExitStack() as exit_stack: + _train(args, exit_stack) + logger.info("Closed everything!") + + +def _train( + args: TrainArgs, + exit_stack: ExitStack, +): + # 1. Initial setup and checks + set_random_seed(args.seed) + + # Init NCCL + if "LOCAL_RANK" in os.environ: + set_device() + logger.info("Going to init comms...") + + dist.init_process_group(backend=BACKEND) + else: + logger.error( + "PyTorch environment is not correctly initialized. This message should only be displayed when testing." + ) + + # 2. Init run dir + main_logger_info(f"Run dir: {args.run_dir}") + run_dir = Path(args.run_dir) + + if is_torchrun(): + if run_dir.exists(): + raise RuntimeError( + f"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}." + ) + + dist.barrier() + run_dir.mkdir(exist_ok=True, parents=True) + + args_path = run_dir / "args.yaml" + if not args_path.exists(): + args.save(args_path) + + main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}") + + # 3. Get loggers + metrics_logger: MetricsLogger = MetricsLogger( + run_dir, + tag="train", + is_master=get_rank() == 0, + wandb_args=args.wandb, + mlflow_args=args.mlflow, + config=dataclasses.asdict(args), + ) + exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger")) + + eval_logger: MetricsLogger = MetricsLogger( + run_dir, + tag="eval", + is_master=get_rank() == 0, + wandb_args=args.wandb, + mlflow_args=args.mlflow, + config=dataclasses.asdict(args), + ) + exit_stack.enter_context(logged_closing(eval_logger, "eval_logger")) + + # 5. Potentially download model + if Path(args.model_id_or_path).is_dir(): + model_folder = Path(args.model_id_or_path) + else: + raise ValueError( + "Invalid folder path. Please set `args.initial_model` to a valid folder path." + ) + + # 6. Load function calling instruct tokenizer + instruct_tokenizer: InstructTokenizerBase = MistralTokenizer.v3().instruct_tokenizer # type: ignore + + # 7. Load data loaders + data_loader = build_data_loader( + instruct_tokenizer=instruct_tokenizer, + args=args.data, + seq_len=args.seq_len, + batch_size=args.batch_size, + seed=args.seed, + rank=get_rank(), # DDP rank + world_size=get_world_size(), # DDP world_size + is_eval=False, + ) + + if not args.no_eval: + assert ( + args.data.eval_instruct_data != "" + ), "Either set `no_eval` to True or provide evaluation samples under `data.eval_instruct_data`" + + eval_data_loader = build_data_loader( + instruct_tokenizer=instruct_tokenizer, + args=args.data, + seq_len=args.seq_len, + batch_size=args.batch_size, + seed=None, + rank=get_rank(), # DDP rank + world_size=get_world_size(), # DDP world_size + is_eval=True, + ) + # pre-load all eval tokens + eval_batches = list(eval_data_loader) + + # 8. Load model + # Define mixed precision + param_dtype = torch.bfloat16 + optim_dtype = torch.float32 + + assert args.lora is not None, "`args.lora` should be set to a valid value." + + model = load_model( + folder=model_folder, + lora=args.lora, + checkpoint=args.checkpoint, + param_dtype=param_dtype, + ) + + # 9. Load optimizer + optimizer = AdamW( + model.parameters(), + lr=args.optim.lr, + betas=(0.9, 0.95), + eps=1e-08, + weight_decay=args.optim.weight_decay, + ) + + scheduler = lr_scheduler.OneCycleLR( + optimizer, + max_lr=args.optim.lr, + total_steps=args.max_steps, + pct_start=args.optim.pct_start, + ) + + state = TrainState(args.max_steps) + + # 10. Initialize checkpointer + checkpointer = Checkpointer( + model=model, + state=state, + run_dir=run_dir, + optimizer=optimizer, + num_ckpt_keep=args.num_ckpt_keep, + ) + # 11. Prepare mixed precision + prepare_mixed_precision( + model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype + ) + + # 12. train! + model.train() + torch.cuda.empty_cache() + + while state.step < args.max_steps: + state.start_step() + is_last_step = state.step == args.max_steps + + optimizer.zero_grad() + + loss = torch.tensor([0.0], device="cuda") + n_batch_tokens: int = 0 + + for i in range(args.num_microbatches): + # batch + batch = next(data_loader) + + x = torch.from_numpy(batch.x).cuda(non_blocking=True) + y = torch.from_numpy(batch.y).cuda(non_blocking=True) + y_mask = ( + torch.from_numpy(batch.y_mask).cuda(non_blocking=True) + if batch.y_mask is not None + else None + ) + + # forward / backward + output = model( + input_ids=x, + seqlens=batch.sizes, + ) + mb_loss = compute_loss_with_mask(output, y, y_mask) + + mb_loss.backward() + + loss += mb_loss.detach() + n_batch_tokens += x.numel() + + if i < args.num_microbatches - 1: + # synchronize CUDA to re-run backward + assert args.num_microbatches > 1 # should not happen + torch.cuda.synchronize() + + if args.num_microbatches > 1: + loss /= args.num_microbatches + for p in model.parameters(): + if p.requires_grad: + assert p.grad is not None + p.grad.div_(args.num_microbatches) + + # upcast params for optimizer update + upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) + + # clip gra d norm + model.clip_grad_norm_(max_norm=args.max_norm) + + # optimizer step + optimizer.step() + + # downcast params for forward & backward + downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) + + last_lr = scheduler.get_last_lr()[0] + scheduler.step() + + # Host sync + loss_item = loss.item() + avg_loss = avg_aggregate(loss_item) + + if not args.no_eval and ( + (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step + ): + # write perplexity to state + evaluate(model, eval_batches, state) + + eval_logs = get_eval_logs( + state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss + ) + + main_logger_info(eval_log_msg(eval_logs)) + eval_logger.log(eval_logs, step=state.step) + + # Timing + state.end_step(n_batch_tokens) + + if state.step % args.log_freq == 0: + train_logs = get_train_logs( + state, + avg_loss, + last_lr, + torch.cuda.max_memory_allocated(), + torch.cuda.memory_allocated(), + args, + ) + main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss)) + metrics_logger.log(train_logs, step=state.step) + + if not args.no_ckpt and ( + (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step + ): + checkpointer.save_checkpoint( + save_only_lora=args.ckpt_only_lora, + dtype=param_dtype, + instruct_tokenizer=instruct_tokenizer, + ) + + main_logger_info("done!") + + +if __name__ == "__main__": + """See README.md for usage.""" + fire.Fire(train) diff --git a/tutorials/mistral_finetune_7b.ipynb b/tutorials/mistral_finetune_7b.ipynb new file mode 100644 index 0000000..1b30920 --- /dev/null +++ b/tutorials/mistral_finetune_7b.ipynb @@ -0,0 +1,3135 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "machine_shape": "hm", + "gpuType": "A100" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "c06bbba05e10462d993f3e7e6f932cf1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6f734c35284341d891a44694ddc55b2c", + "IPY_MODEL_193be53200ab436a967f1ea4807053e2", + "IPY_MODEL_85a2415f14284237875b349b4c414e21" + ], + "layout": "IPY_MODEL_3ab931b2fcc0493ca71923ebc37127c7" + } + }, + "6f734c35284341d891a44694ddc55b2c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f49472d0536f4011be17902f9e827807", + "placeholder": "​", + "style": "IPY_MODEL_2f341cb76f254d0da913faec6a82f762", + "value": "Fetching 3 files: 100%" + } + }, + "193be53200ab436a967f1ea4807053e2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_24ebf383723f4de494f9808b41222751", + "max": 3, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0c17aa90672046c9bd2f293b1a998b46", + "value": 3 + } + }, + "85a2415f14284237875b349b4c414e21": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_af89033247b34da2a5cded73b0beade2", + "placeholder": "​", + "style": "IPY_MODEL_0ea5f5e6ab26484ab22dcf5576f796d1", + "value": " 3/3 [01:49<00:00, 109.92s/it]" + } + }, + "3ab931b2fcc0493ca71923ebc37127c7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f49472d0536f4011be17902f9e827807": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2f341cb76f254d0da913faec6a82f762": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "24ebf383723f4de494f9808b41222751": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0c17aa90672046c9bd2f293b1a998b46": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "af89033247b34da2a5cded73b0beade2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0ea5f5e6ab26484ab22dcf5576f796d1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "946ce9afeddb4da5a36e81e5ada9d957": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_95782b84af1c4014ae04c9e6c9131cbe", + "IPY_MODEL_2332101edec848219c3b0c6026c2a722", + "IPY_MODEL_a2638b3f10a24de99bb940dcd150ab53" + ], + "layout": "IPY_MODEL_1c78955b41ba4845931a250f16b753b5" + } + }, + "95782b84af1c4014ae04c9e6c9131cbe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_148a8c3e4fad4cefa16a478a9758fdc5", + "placeholder": "​", + "style": "IPY_MODEL_92ef027f2cc940b5b09521328de550b0", + "value": "tokenizer.model.v3: 100%" + } + }, + "2332101edec848219c3b0c6026c2a722": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1a3d2764c7fc41dcb97489e84c28093e", + "max": 587404, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c7b6155f0f844c67b3a2b805570fd6f9", + "value": 587404 + } + }, + "a2638b3f10a24de99bb940dcd150ab53": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1eec374aa3414838a9b41e5db1fefd50", + "placeholder": "​", + "style": "IPY_MODEL_6fc8cf7aa81c4043878ac854b289dbe3", + "value": " 587k/587k [00:00<00:00, 5.39MB/s]" + } + }, + "1c78955b41ba4845931a250f16b753b5": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "148a8c3e4fad4cefa16a478a9758fdc5": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "92ef027f2cc940b5b09521328de550b0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1a3d2764c7fc41dcb97489e84c28093e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c7b6155f0f844c67b3a2b805570fd6f9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1eec374aa3414838a9b41e5db1fefd50": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6fc8cf7aa81c4043878ac854b289dbe3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b1aab1a3b5914048962a6d7d63401425": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_10fe81122e6442f28608766d90749790", + "IPY_MODEL_f7e72f0a87bc421b82a59ae9ad33a4cb", + "IPY_MODEL_2188b8f9491b4d3e8861e40e7c4f6a46" + ], + "layout": "IPY_MODEL_b3bf1880a5844f8c89096ced830fc954" + } + }, + "10fe81122e6442f28608766d90749790": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7e0495ffdeb74675847e5e4c2104cc34", + "placeholder": "​", + "style": "IPY_MODEL_3e0e828a21f24944b68a66eafb52f62b", + "value": "params.json: 100%" + } + }, + "f7e72f0a87bc421b82a59ae9ad33a4cb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ee07a8e2427c4fc9bd09b27ad11e968a", + "max": 202, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d35818e4f9454d26aa475826e08ea4f0", + "value": 202 + } + }, + "2188b8f9491b4d3e8861e40e7c4f6a46": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e3f7c4fea8494af2a473cf61adccf270", + "placeholder": "​", + "style": "IPY_MODEL_7e18402efdb34d708b7917964ac791de", + "value": " 202/202 [00:00<00:00, 12.7kB/s]" + } + }, + "b3bf1880a5844f8c89096ced830fc954": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7e0495ffdeb74675847e5e4c2104cc34": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3e0e828a21f24944b68a66eafb52f62b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ee07a8e2427c4fc9bd09b27ad11e968a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d35818e4f9454d26aa475826e08ea4f0": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e3f7c4fea8494af2a473cf61adccf270": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7e18402efdb34d708b7917964ac791de": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1442445cdf89487784d4a39919fec6bf": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_2ff8ebe8d132411585a05b852362c406", + "IPY_MODEL_ef719bb991714d91a365226c5a2ca9df", + "IPY_MODEL_1727f9b019e9477282d010e96b7dd4c3" + ], + "layout": "IPY_MODEL_f82b841d7e5b45229119bd3195e5b12f" + } + }, + "2ff8ebe8d132411585a05b852362c406": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_feb470b16b4249daaa19c1344d036f0a", + "placeholder": "​", + "style": "IPY_MODEL_2cc30eef6d7b46d283fcfd0e7abca6ea", + "value": "consolidated.safetensors: 100%" + } + }, + "ef719bb991714d91a365226c5a2ca9df": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_83b554bec0fd40dd9bd9e4601f2f98a3", + "max": 14496078512, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1c19998de61c4e2dad6647fdc4ca4358", + "value": 14496078512 + } + }, + "1727f9b019e9477282d010e96b7dd4c3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_25df0dd9481a4e0ba21d2f4f4ffdba2e", + "placeholder": "​", + "style": "IPY_MODEL_cbf620ae5196446c84528feaed64ae6a", + "value": " 14.5G/14.5G [01:49<00:00, 87.0MB/s]" + } + }, + "f82b841d7e5b45229119bd3195e5b12f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "feb470b16b4249daaa19c1344d036f0a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2cc30eef6d7b46d283fcfd0e7abca6ea": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "83b554bec0fd40dd9bd9e4601f2f98a3": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c19998de61c4e2dad6647fdc4ca4358": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "25df0dd9481a4e0ba21d2f4f4ffdba2e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cbf620ae5196446c84528feaed64ae6a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Getting starting fine-tuning Mistral 7B\n", + "\n", + "This notebook shows you a simple example of how to LoRA finetune Mistral 7B. You can can run this notebook in Google Colab with Pro + account with A100 and 40GB RAM.\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "\n", + "Check out `mistral-finetune` Github repo to learn more: https://github.com/mistralai/mistral-finetune/" + ], + "metadata": { + "id": "RyuOCYM92LJb" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Installation\n", + "\n", + "Clone the `mistral-finetune` repo:\n" + ], + "metadata": { + "id": "yxr8mv-17GfB" + } + }, + { + "cell_type": "code", + "source": [ + "!git clone git@github.com:mistralai/mistral-finetune.git" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TIj3IlIeVDIb", + "outputId": "7bfd4224-dac4-4570-cd68-c5a50cc41da7" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'mistral-finetune'...\n", + "remote: Enumerating objects: 2030, done.\u001b[K\n", + "remote: Counting objects: 100% (1747/1747), done.\u001b[K\n", + "remote: Compressing objects: 100% (513/513), done.\u001b[K\n", + "remote: Total 2030 (delta 1236), reused 1635 (delta 1171), pack-reused 283\u001b[K\n", + "Receiving objects: 100% (2030/2030), 697.88 KiB | 12.46 MiB/s, done.\n", + "Resolving deltas: 100% (1391/1391), done.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Install all required dependencies:" + ], + "metadata": { + "id": "mQPd_pGT7WiY" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -r requirements.txt" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KuTOGipl7BS7", + "outputId": "2ea1b6b9-f51d-4f44-839f-eaa665c20580" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/mistral-finetune\n", + "Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.6.0)\n", + "Requirement already satisfied: simple-parsing in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 2)) (0.1.5)\n", + "Requirement already satisfied: mistral-common in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 3)) (1.1.0)\n", + "Requirement already satisfied: torch==2.2 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 4)) (2.2.0)\n", + "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 5)) (0.17.0)\n", + "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 6)) (2.15.2)\n", + "Requirement already satisfied: triton==2.2 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 7)) (2.2.0)\n", + "Requirement already satisfied: xformers==0.0.24 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 8)) (0.0.24)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (3.14.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (4.11.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (2.19.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.2->-r requirements.txt (line 4)) (12.1.105)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from xformers==0.0.24->-r requirements.txt (line 8)) (1.25.2)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2->-r requirements.txt (line 4)) (12.5.40)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from fire->-r requirements.txt (line 1)) (1.16.0)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from fire->-r requirements.txt (line 1)) (2.4.0)\n", + "Requirement already satisfied: docstring-parser~=0.15 in /usr/local/lib/python3.10/dist-packages (from simple-parsing->-r requirements.txt (line 2)) (0.16)\n", + "Requirement already satisfied: jsonschema==4.21.1 in /usr/local/lib/python3.10/dist-packages (from mistral-common->-r requirements.txt (line 3)) (4.21.1)\n", + "Requirement already satisfied: pydantic==2.6.1 in /usr/local/lib/python3.10/dist-packages (from mistral-common->-r requirements.txt (line 3)) (2.6.1)\n", + "Requirement already satisfied: sentencepiece==0.1.99 in /usr/local/lib/python3.10/dist-packages (from mistral-common->-r requirements.txt (line 3)) (0.1.99)\n", + "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema==4.21.1->mistral-common->-r requirements.txt (line 3)) (23.2.0)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema==4.21.1->mistral-common->-r requirements.txt (line 3)) (2023.12.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema==4.21.1->mistral-common->-r requirements.txt (line 3)) (0.35.1)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema==4.21.1->mistral-common->-r requirements.txt (line 3)) (0.18.1)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic==2.6.1->mistral-common->-r requirements.txt (line 3)) (0.7.0)\n", + "Requirement already satisfied: pydantic-core==2.16.2 in /usr/local/lib/python3.10/dist-packages (from pydantic==2.6.1->mistral-common->-r requirements.txt (line 3)) (2.16.2)\n", + "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (8.1.7)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (3.20.3)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (6.0.1)\n", + "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (2.31.0)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (2.3.1)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (1.3.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb->-r requirements.txt (line 5)) (67.7.2)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.4.0)\n", + "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.64.0)\n", + "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (2.27.0)\n", + "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (1.2.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (3.6)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->-r requirements.txt (line 6)) (3.0.3)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb->-r requirements.txt (line 5)) (4.0.11)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (5.3.3)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (0.4.0)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (4.9)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<2,>=0.5->tensorboard->-r requirements.txt (line 6)) (1.3.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 5)) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 5)) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 5)) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb->-r requirements.txt (line 5)) (2024.2.2)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard->-r requirements.txt (line 6)) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.2->-r requirements.txt (line 4)) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb->-r requirements.txt (line 5)) (5.0.1)\n", + "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard->-r requirements.txt (line 6)) (0.6.0)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard->-r requirements.txt (line 6)) (3.2.2)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Model download" + ], + "metadata": { + "id": "LgdIAi257jLo" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cdl_R5baUyha", + "outputId": "3a41226e-4ddf-44b0-b8ac-58d23d1a3199" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2024-05-24 15:56:01-- https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar\n", + "Resolving models.mistralcdn.com (models.mistralcdn.com)... 172.67.70.68, 104.26.7.117, 104.26.6.117, ...\n", + "Connecting to models.mistralcdn.com (models.mistralcdn.com)|172.67.70.68|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 14496675840 (14G) [application/x-tar]\n", + "Saving to: ‘mistral-7B-v0.3.tar.1’\n", + "\n", + "mistral-7B-v0.3.tar 100%[===================>] 13.50G 45.7MB/s in 6m 24s \n", + "\n", + "2024-05-24 16:02:25 (36.0 MB/s) - ‘mistral-7B-v0.3.tar.1’ saved [14496675840/14496675840]\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!DIR=mistral_models && mkdir -p $DIR && tar -xf mistral-7B-v0.3.tar -C $DIR" + ], + "metadata": { + "id": "IgJWR-fReilz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Alternatively, you can download the model from Hugging Face\n", + "\n", + "# !pip install huggingface_hub\n", + "# from huggingface_hub import snapshot_download\n", + "# from pathlib import Path\n", + "\n", + "# mistral_models_path = Path.home().joinpath('mistral_models', '7B-v0.3')\n", + "# mistral_models_path.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# snapshot_download(repo_id=\"mistralai/Mistral-7B-v0.3\", allow_patterns=[\"params.json\", \"consolidated.safetensors\", \"tokenizer.model.v3\"], local_dir=mistral_models_path)\n", + "\n", + "#! cp -r /root/mistral_models/7B-v0.3 /content/mistral_models\n", + "#! rm -r /root/mistral_models/7B-v0.3" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 163, + "referenced_widgets": [ + "c06bbba05e10462d993f3e7e6f932cf1", + "6f734c35284341d891a44694ddc55b2c", + "193be53200ab436a967f1ea4807053e2", + "85a2415f14284237875b349b4c414e21", + "3ab931b2fcc0493ca71923ebc37127c7", + "f49472d0536f4011be17902f9e827807", + "2f341cb76f254d0da913faec6a82f762", + "24ebf383723f4de494f9808b41222751", + "0c17aa90672046c9bd2f293b1a998b46", + "af89033247b34da2a5cded73b0beade2", + "0ea5f5e6ab26484ab22dcf5576f796d1", + "946ce9afeddb4da5a36e81e5ada9d957", + "95782b84af1c4014ae04c9e6c9131cbe", + "2332101edec848219c3b0c6026c2a722", + "a2638b3f10a24de99bb940dcd150ab53", + "1c78955b41ba4845931a250f16b753b5", + "148a8c3e4fad4cefa16a478a9758fdc5", + "92ef027f2cc940b5b09521328de550b0", + "1a3d2764c7fc41dcb97489e84c28093e", + "c7b6155f0f844c67b3a2b805570fd6f9", + "1eec374aa3414838a9b41e5db1fefd50", + "6fc8cf7aa81c4043878ac854b289dbe3", + "b1aab1a3b5914048962a6d7d63401425", + "10fe81122e6442f28608766d90749790", + "f7e72f0a87bc421b82a59ae9ad33a4cb", + "2188b8f9491b4d3e8861e40e7c4f6a46", + "b3bf1880a5844f8c89096ced830fc954", + "7e0495ffdeb74675847e5e4c2104cc34", + "3e0e828a21f24944b68a66eafb52f62b", + "ee07a8e2427c4fc9bd09b27ad11e968a", + "d35818e4f9454d26aa475826e08ea4f0", + "e3f7c4fea8494af2a473cf61adccf270", + "7e18402efdb34d708b7917964ac791de", + "1442445cdf89487784d4a39919fec6bf", + "2ff8ebe8d132411585a05b852362c406", + "ef719bb991714d91a365226c5a2ca9df", + "1727f9b019e9477282d010e96b7dd4c3", + "f82b841d7e5b45229119bd3195e5b12f", + "feb470b16b4249daaa19c1344d036f0a", + "2cc30eef6d7b46d283fcfd0e7abca6ea", + "83b554bec0fd40dd9bd9e4601f2f98a3", + "1c19998de61c4e2dad6647fdc4ca4358", + "25df0dd9481a4e0ba21d2f4f4ffdba2e", + "cbf620ae5196446c84528feaed64ae6a" + ] + }, + "id": "qgjAADBFHB0S", + "outputId": "6dd98910-36fd-4dc1-c5b8-77bb4c104e05" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Fetching 3 files: 0%| | 0/3 [00:002024-05-24 15:58:57\n", + "2024-05-24 15:24:33 (UTC) - 0:00:56 - train - INFO - step: 000002 - done (%): 2.0 - loss: 0.905 - lr: 1.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3676.2 - avg_words_per_second: 3372.9 - ETA: >2024-05-24 15:56:17\n", + "2024-05-24 15:24:51 (UTC) - 0:01:14 - train - INFO - step: 000003 - done (%): 3.0 - loss: 0.912 - lr: 5.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3671.2 - avg_words_per_second: 3466.8 - ETA: >2024-05-24 15:55:24\n", + "2024-05-24 15:25:09 (UTC) - 0:01:32 - train - INFO - step: 000004 - done (%): 4.0 - loss: 0.884 - lr: 8.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3642.0 - avg_words_per_second: 3509.0 - ETA: >2024-05-24 15:55:02\n", + "2024-05-24 15:25:27 (UTC) - 0:01:50 - train - INFO - step: 000005 - done (%): 5.0 - loss: 0.836 - lr: 1.0e-04 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3667.2 - avg_words_per_second: 3539.5 - ETA: >2024-05-24 15:54:46\n", + "2024-05-24 15:25:45 (UTC) - 0:02:07 - train - INFO - step: 000006 - done (%): 6.0 - loss: 0.858 - lr: 1.0e-04 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3652.4 - avg_words_per_second: 3557.9 - ETA: >2024-05-24 15:54:36\n", + "2024-05-24 15:26:02 (UTC) - 0:02:25 - train - INFO - step: 000007 - done (%): 7.0 - loss: 0.868 - lr: 1.0e-04 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3664.7 - avg_words_per_second: 3572.7 - ETA: >2024-05-24 15:54:28\n", + "2024-05-24 15:26:20 (UTC) - 0:02:43 - train - INFO - step: 000008 - done (%): 8.0 - loss: 0.869 - lr: 1.0e-04 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3665.3 - avg_words_per_second: 3584.1 - ETA: >2024-05-24 15:54:23\n", + "2024-05-24 15:26:38 (UTC) - 0:03:01 - train - INFO - step: 000009 - done (%): 9.0 - loss: 0.805 - lr: 1.0e-04 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.4 - avg_words_per_second: 3591.9 - ETA: >2024-05-24 15:54:19\n", + "2024-05-24 15:26:56 (UTC) - 0:03:19 - train - INFO - step: 000010 - done (%): 10.0 - loss: 0.898 - lr: 9.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3631.3 - avg_words_per_second: 3595.8 - ETA: >2024-05-24 15:54:17\n", + "2024-05-24 15:27:14 (UTC) - 0:03:37 - train - INFO - step: 000011 - done (%): 11.0 - loss: 0.744 - lr: 9.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3624.0 - avg_words_per_second: 3598.4 - ETA: >2024-05-24 15:54:15\n", + "2024-05-24 15:27:32 (UTC) - 0:03:55 - train - INFO - step: 000012 - done (%): 12.0 - loss: 0.837 - lr: 9.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3636.9 - avg_words_per_second: 3601.6 - ETA: >2024-05-24 15:54:14\n", + "2024-05-24 15:27:50 (UTC) - 0:04:13 - train - INFO - step: 000013 - done (%): 13.0 - loss: 0.846 - lr: 9.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3655.9 - avg_words_per_second: 3605.7 - ETA: >2024-05-24 15:54:12\n", + "2024-05-24 15:28:08 (UTC) - 0:04:31 - train - INFO - step: 000014 - done (%): 14.0 - loss: 0.853 - lr: 9.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3638.7 - avg_words_per_second: 3608.0 - ETA: >2024-05-24 15:54:10\n", + "2024-05-24 15:28:26 (UTC) - 0:04:49 - train - INFO - step: 000015 - done (%): 15.0 - loss: 0.912 - lr: 9.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3645.2 - avg_words_per_second: 3610.5 - ETA: >2024-05-24 15:54:09\n", + "2024-05-24 15:28:44 (UTC) - 0:05:07 - train - INFO - step: 000016 - done (%): 16.0 - loss: 0.815 - lr: 9.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3664.3 - avg_words_per_second: 3613.8 - ETA: >2024-05-24 15:54:08\n", + "2024-05-24 15:29:02 (UTC) - 0:05:25 - train - INFO - step: 000017 - done (%): 17.0 - loss: 0.822 - lr: 9.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3668.9 - avg_words_per_second: 3617.0 - ETA: >2024-05-24 15:54:06\n", + "2024-05-24 15:29:20 (UTC) - 0:05:43 - train - INFO - step: 000018 - done (%): 18.0 - loss: 0.842 - lr: 9.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3649.9 - avg_words_per_second: 3618.8 - ETA: >2024-05-24 15:54:05\n", + "2024-05-24 15:29:38 (UTC) - 0:06:01 - train - INFO - step: 000019 - done (%): 19.0 - loss: 0.802 - lr: 9.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3665.8 - avg_words_per_second: 3621.2 - ETA: >2024-05-24 15:54:04\n", + "2024-05-24 15:29:56 (UTC) - 0:06:19 - train - INFO - step: 000020 - done (%): 20.0 - loss: 0.867 - lr: 9.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3653.4 - avg_words_per_second: 3622.8 - ETA: >2024-05-24 15:54:03\n", + "2024-05-24 15:30:14 (UTC) - 0:06:37 - train - INFO - step: 000021 - done (%): 21.0 - loss: 0.826 - lr: 9.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3655.3 - avg_words_per_second: 3624.4 - ETA: >2024-05-24 15:54:02\n", + "2024-05-24 15:30:32 (UTC) - 0:06:55 - train - INFO - step: 000022 - done (%): 22.0 - loss: 0.809 - lr: 9.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3621.2 - avg_words_per_second: 3624.2 - ETA: >2024-05-24 15:54:02\n", + "2024-05-24 15:30:50 (UTC) - 0:07:13 - train - INFO - step: 000023 - done (%): 23.0 - loss: 0.837 - lr: 9.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3646.8 - avg_words_per_second: 3625.2 - ETA: >2024-05-24 15:54:02\n", + "2024-05-24 15:31:08 (UTC) - 0:07:31 - train - INFO - step: 000024 - done (%): 24.0 - loss: 0.851 - lr: 9.0e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3666.2 - avg_words_per_second: 3626.9 - ETA: >2024-05-24 15:54:01\n", + "2024-05-24 15:31:26 (UTC) - 0:07:49 - train - INFO - step: 000025 - done (%): 25.0 - loss: 0.856 - lr: 8.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3649.7 - avg_words_per_second: 3627.8 - ETA: >2024-05-24 15:54:01\n", + "2024-05-24 15:31:44 (UTC) - 0:08:07 - train - INFO - step: 000026 - done (%): 26.0 - loss: 0.781 - lr: 8.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3641.7 - avg_words_per_second: 3628.3 - ETA: >2024-05-24 15:54:00\n", + "2024-05-24 15:32:02 (UTC) - 0:08:24 - train - INFO - step: 000027 - done (%): 27.0 - loss: 0.845 - lr: 8.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3667.1 - avg_words_per_second: 3629.7 - ETA: >2024-05-24 15:54:00\n", + "2024-05-24 15:32:19 (UTC) - 0:08:42 - train - INFO - step: 000028 - done (%): 28.0 - loss: 0.831 - lr: 8.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3659.9 - avg_words_per_second: 3630.8 - ETA: >2024-05-24 15:53:59\n", + "2024-05-24 15:32:37 (UTC) - 0:09:00 - train - INFO - step: 000029 - done (%): 29.0 - loss: 0.806 - lr: 8.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3661.9 - avg_words_per_second: 3631.9 - ETA: >2024-05-24 15:53:59\n", + "2024-05-24 15:32:55 (UTC) - 0:09:18 - train - INFO - step: 000030 - done (%): 30.0 - loss: 0.898 - lr: 8.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.0 - avg_words_per_second: 3632.6 - ETA: >2024-05-24 15:53:58\n", + "2024-05-24 15:33:13 (UTC) - 0:09:36 - train - INFO - step: 000031 - done (%): 31.0 - loss: 0.817 - lr: 8.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3672.1 - avg_words_per_second: 3633.9 - ETA: >2024-05-24 15:53:58\n", + "2024-05-24 15:33:31 (UTC) - 0:09:54 - train - INFO - step: 000032 - done (%): 32.0 - loss: 0.825 - lr: 8.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3650.5 - avg_words_per_second: 3634.4 - ETA: >2024-05-24 15:53:57\n", + "2024-05-24 15:33:49 (UTC) - 0:10:12 - train - INFO - step: 000033 - done (%): 33.0 - loss: 0.845 - lr: 8.0e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3665.2 - avg_words_per_second: 3635.3 - ETA: >2024-05-24 15:53:57\n", + "2024-05-24 15:34:07 (UTC) - 0:10:30 - train - INFO - step: 000034 - done (%): 34.0 - loss: 0.808 - lr: 7.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.7 - avg_words_per_second: 3635.9 - ETA: >2024-05-24 15:53:57\n", + "2024-05-24 15:34:25 (UTC) - 0:10:48 - train - INFO - step: 000035 - done (%): 35.0 - loss: 0.853 - lr: 7.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3635.9 - avg_words_per_second: 3635.9 - ETA: >2024-05-24 15:53:57\n", + "2024-05-24 15:34:43 (UTC) - 0:11:06 - train - INFO - step: 000036 - done (%): 36.0 - loss: 0.813 - lr: 7.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3640.1 - avg_words_per_second: 3636.0 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:35:01 (UTC) - 0:11:24 - train - INFO - step: 000037 - done (%): 37.0 - loss: 0.801 - lr: 7.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3665.5 - avg_words_per_second: 3636.8 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:35:19 (UTC) - 0:11:42 - train - INFO - step: 000038 - done (%): 38.0 - loss: 0.744 - lr: 7.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3657.0 - avg_words_per_second: 3637.3 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:35:37 (UTC) - 0:12:00 - train - INFO - step: 000039 - done (%): 39.0 - loss: 0.816 - lr: 7.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3637.2 - avg_words_per_second: 3637.3 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:35:55 (UTC) - 0:12:18 - train - INFO - step: 000040 - done (%): 40.0 - loss: 0.786 - lr: 7.0e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3636.5 - avg_words_per_second: 3637.3 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:36:13 (UTC) - 0:12:36 - train - INFO - step: 000041 - done (%): 41.0 - loss: 0.805 - lr: 6.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.9 - avg_words_per_second: 3637.8 - ETA: >2024-05-24 15:53:56\n", + "2024-05-24 15:36:31 (UTC) - 0:12:54 - train - INFO - step: 000042 - done (%): 42.0 - loss: 0.845 - lr: 6.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3660.8 - avg_words_per_second: 3638.3 - ETA: >2024-05-24 15:53:55\n", + "2024-05-24 15:36:49 (UTC) - 0:13:11 - train - INFO - step: 000043 - done (%): 43.0 - loss: 0.864 - lr: 6.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.3 - avg_words_per_second: 3638.7 - ETA: >2024-05-24 15:53:55\n", + "2024-05-24 15:37:06 (UTC) - 0:13:29 - train - INFO - step: 000044 - done (%): 44.0 - loss: 0.862 - lr: 6.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3660.0 - avg_words_per_second: 3639.2 - ETA: >2024-05-24 15:53:55\n", + "2024-05-24 15:37:24 (UTC) - 0:13:47 - train - INFO - step: 000045 - done (%): 45.0 - loss: 0.862 - lr: 6.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3664.6 - avg_words_per_second: 3639.7 - ETA: >2024-05-24 15:53:55\n", + "2024-05-24 15:37:42 (UTC) - 0:14:05 - train - INFO - step: 000046 - done (%): 46.0 - loss: 0.828 - lr: 6.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3639.4 - avg_words_per_second: 3639.7 - ETA: >2024-05-24 15:53:55\n", + "2024-05-24 15:38:00 (UTC) - 0:14:23 - train - INFO - step: 000047 - done (%): 47.0 - loss: 0.812 - lr: 5.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3671.6 - avg_words_per_second: 3640.4 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:38:18 (UTC) - 0:14:41 - train - INFO - step: 000048 - done (%): 48.0 - loss: 0.818 - lr: 5.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3662.6 - avg_words_per_second: 3640.9 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:38:36 (UTC) - 0:14:59 - train - INFO - step: 000049 - done (%): 49.0 - loss: 0.818 - lr: 5.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.1 - avg_words_per_second: 3641.1 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:38:54 (UTC) - 0:15:17 - train - INFO - step: 000050 - done (%): 50.0 - loss: 0.888 - lr: 5.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3638.6 - avg_words_per_second: 3641.1 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:39:12 (UTC) - 0:15:35 - train - INFO - step: 000051 - done (%): 51.0 - loss: 0.777 - lr: 5.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3637.1 - avg_words_per_second: 3641.0 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:39:30 (UTC) - 0:15:53 - train - INFO - step: 000052 - done (%): 52.0 - loss: 0.804 - lr: 5.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.1 - avg_words_per_second: 3641.3 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:39:48 (UTC) - 0:16:11 - train - INFO - step: 000053 - done (%): 53.0 - loss: 0.799 - lr: 4.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.0 - avg_words_per_second: 3641.6 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:40:06 (UTC) - 0:16:29 - train - INFO - step: 000054 - done (%): 54.0 - loss: 0.804 - lr: 4.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3661.3 - avg_words_per_second: 3641.9 - ETA: >2024-05-24 15:53:54\n", + "2024-05-24 15:40:24 (UTC) - 0:16:47 - train - INFO - step: 000055 - done (%): 55.0 - loss: 0.854 - lr: 4.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3671.5 - avg_words_per_second: 3642.5 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:40:42 (UTC) - 0:17:04 - train - INFO - step: 000056 - done (%): 56.0 - loss: 0.819 - lr: 4.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3666.0 - avg_words_per_second: 3642.9 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:41:00 (UTC) - 0:17:22 - train - INFO - step: 000057 - done (%): 57.0 - loss: 0.844 - lr: 4.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3647.3 - avg_words_per_second: 3643.0 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:41:17 (UTC) - 0:17:40 - train - INFO - step: 000058 - done (%): 58.0 - loss: 0.840 - lr: 4.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3658.4 - avg_words_per_second: 3643.2 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:41:35 (UTC) - 0:17:58 - train - INFO - step: 000059 - done (%): 59.0 - loss: 0.836 - lr: 3.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3670.1 - avg_words_per_second: 3643.7 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:41:53 (UTC) - 0:18:16 - train - INFO - step: 000060 - done (%): 60.0 - loss: 0.852 - lr: 3.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3669.1 - avg_words_per_second: 3644.1 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:42:11 (UTC) - 0:18:34 - train - INFO - step: 000061 - done (%): 61.0 - loss: 0.837 - lr: 3.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3642.9 - avg_words_per_second: 3644.1 - ETA: >2024-05-24 15:53:53\n", + "2024-05-24 15:42:29 (UTC) - 0:18:52 - train - INFO - step: 000062 - done (%): 62.0 - loss: 0.839 - lr: 3.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3652.3 - avg_words_per_second: 3644.2 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:42:47 (UTC) - 0:19:10 - train - INFO - step: 000063 - done (%): 63.0 - loss: 0.813 - lr: 3.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3662.4 - avg_words_per_second: 3644.5 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:43:05 (UTC) - 0:19:28 - train - INFO - step: 000064 - done (%): 64.0 - loss: 0.784 - lr: 3.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3650.1 - avg_words_per_second: 3644.6 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:43:23 (UTC) - 0:19:46 - train - INFO - step: 000065 - done (%): 65.0 - loss: 0.796 - lr: 3.0e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.6 - avg_words_per_second: 3644.7 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:43:41 (UTC) - 0:20:04 - train - INFO - step: 000066 - done (%): 66.0 - loss: 0.787 - lr: 2.8e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.1 - avg_words_per_second: 3644.9 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:43:59 (UTC) - 0:20:22 - train - INFO - step: 000067 - done (%): 67.0 - loss: 0.903 - lr: 2.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3670.1 - avg_words_per_second: 3645.3 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:44:17 (UTC) - 0:20:40 - train - INFO - step: 000068 - done (%): 68.0 - loss: 0.783 - lr: 2.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3632.2 - avg_words_per_second: 3645.1 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:44:35 (UTC) - 0:20:57 - train - INFO - step: 000069 - done (%): 69.0 - loss: 0.889 - lr: 2.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3665.0 - avg_words_per_second: 3645.4 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:44:53 (UTC) - 0:21:15 - train - INFO - step: 000070 - done (%): 70.0 - loss: 0.810 - lr: 2.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3644.1 - avg_words_per_second: 3645.4 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:45:10 (UTC) - 0:21:33 - train - INFO - step: 000071 - done (%): 71.0 - loss: 0.797 - lr: 2.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3672.3 - avg_words_per_second: 3645.7 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:45:28 (UTC) - 0:21:51 - train - INFO - step: 000072 - done (%): 72.0 - loss: 0.928 - lr: 2.0e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3631.2 - avg_words_per_second: 3645.5 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:45:46 (UTC) - 0:22:09 - train - INFO - step: 000073 - done (%): 73.0 - loss: 0.811 - lr: 1.9e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3654.7 - avg_words_per_second: 3645.7 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:46:04 (UTC) - 0:22:27 - train - INFO - step: 000074 - done (%): 74.0 - loss: 0.828 - lr: 1.7e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3645.9 - avg_words_per_second: 3645.7 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:46:22 (UTC) - 0:22:45 - train - INFO - step: 000075 - done (%): 75.0 - loss: 0.908 - lr: 1.6e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3663.3 - avg_words_per_second: 3645.9 - ETA: >2024-05-24 15:53:52\n", + "2024-05-24 15:46:40 (UTC) - 0:23:03 - train - INFO - step: 000076 - done (%): 76.0 - loss: 0.880 - lr: 1.5e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3676.9 - avg_words_per_second: 3646.3 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:46:58 (UTC) - 0:23:21 - train - INFO - step: 000077 - done (%): 77.0 - loss: 0.819 - lr: 1.4e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3638.8 - avg_words_per_second: 3646.2 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:47:16 (UTC) - 0:23:39 - train - INFO - step: 000078 - done (%): 78.0 - loss: 0.867 - lr: 1.3e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3661.8 - avg_words_per_second: 3646.4 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:47:34 (UTC) - 0:23:57 - train - INFO - step: 000079 - done (%): 79.0 - loss: 0.913 - lr: 1.2e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3661.1 - avg_words_per_second: 3646.6 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:47:52 (UTC) - 0:24:15 - train - INFO - step: 000080 - done (%): 80.0 - loss: 0.826 - lr: 1.1e-05 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3660.3 - avg_words_per_second: 3646.8 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:48:10 (UTC) - 0:24:33 - train - INFO - step: 000081 - done (%): 81.0 - loss: 0.835 - lr: 9.5e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3671.3 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:48:28 (UTC) - 0:24:51 - train - INFO - step: 000082 - done (%): 82.0 - loss: 0.853 - lr: 8.6e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3652.9 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:48:46 (UTC) - 0:25:09 - train - INFO - step: 000083 - done (%): 83.0 - loss: 0.772 - lr: 7.7e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3638.3 - avg_words_per_second: 3647.0 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:49:04 (UTC) - 0:25:27 - train - INFO - step: 000084 - done (%): 84.0 - loss: 0.813 - lr: 6.8e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3640.2 - avg_words_per_second: 3646.9 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:49:22 (UTC) - 0:25:45 - train - INFO - step: 000085 - done (%): 85.0 - loss: 0.836 - lr: 6.0e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3622.0 - avg_words_per_second: 3646.6 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:49:40 (UTC) - 0:26:03 - train - INFO - step: 000086 - done (%): 86.0 - loss: 0.866 - lr: 5.3e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3656.0 - avg_words_per_second: 3646.8 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:49:57 (UTC) - 0:26:20 - train - INFO - step: 000087 - done (%): 87.0 - loss: 0.844 - lr: 4.6e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3675.7 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:50:15 (UTC) - 0:26:38 - train - INFO - step: 000088 - done (%): 88.0 - loss: 0.777 - lr: 3.9e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3648.1 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:50:33 (UTC) - 0:26:56 - train - INFO - step: 000089 - done (%): 89.0 - loss: 0.813 - lr: 3.3e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3640.2 - avg_words_per_second: 3647.0 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:50:51 (UTC) - 0:27:14 - train - INFO - step: 000090 - done (%): 90.0 - loss: 0.841 - lr: 2.7e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3666.4 - avg_words_per_second: 3647.2 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:51:09 (UTC) - 0:27:32 - train - INFO - step: 000091 - done (%): 91.0 - loss: 0.807 - lr: 2.2e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3633.6 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:51:27 (UTC) - 0:27:50 - train - INFO - step: 000092 - done (%): 92.0 - loss: 0.807 - lr: 1.7e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3647.9 - avg_words_per_second: 3647.1 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:51:45 (UTC) - 0:28:08 - train - INFO - step: 000093 - done (%): 93.0 - loss: 0.827 - lr: 1.3e-06 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3666.5 - avg_words_per_second: 3647.3 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:52:03 (UTC) - 0:28:26 - train - INFO - step: 000094 - done (%): 94.0 - loss: 0.816 - lr: 9.8e-07 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3666.8 - avg_words_per_second: 3647.5 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:52:21 (UTC) - 0:28:44 - train - INFO - step: 000095 - done (%): 95.0 - loss: 0.824 - lr: 6.8e-07 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3672.7 - avg_words_per_second: 3647.8 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:52:39 (UTC) - 0:29:02 - train - INFO - step: 000096 - done (%): 96.0 - loss: 0.794 - lr: 4.4e-07 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3649.7 - avg_words_per_second: 3647.8 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:52:57 (UTC) - 0:29:20 - train - INFO - step: 000097 - done (%): 97.0 - loss: 0.884 - lr: 2.5e-07 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3672.3 - avg_words_per_second: 3648.0 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:53:15 (UTC) - 0:29:38 - train - INFO - step: 000098 - done (%): 98.0 - loss: 0.826 - lr: 1.1e-07 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3660.1 - avg_words_per_second: 3648.2 - ETA: >2024-05-24 15:53:51\n", + "2024-05-24 15:53:33 (UTC) - 0:29:55 - train - INFO - step: 000099 - done (%): 99.0 - loss: 0.771 - lr: 2.8e-08 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 3664.2 - avg_words_per_second: 3648.3 - ETA: >2024-05-24 15:53:50\n", + "2024-05-24 15:53:50 (UTC) - 0:30:13 - eval - INFO - Start eval...\n", + "2024-05-24 15:55:45 (UTC) - 0:32:08 - eval - INFO - Eval finished!\n", + "2024-05-24 15:55:45 (UTC) - 0:32:08 - train - INFO - step: 000100 - eval_perplexity: 1.779 - eval_loss: 0.831 - train_loss: 0.761\n", + "2024-05-24 15:55:45 (UTC) - 0:32:08 - train - INFO - step: 000100 - done (%): 100.0 - loss: 0.761 - lr: 4.0e-10 - peak_alloc_mem (GB): 22.2 - alloc_mem (GB): 17.1 - words_per_second: 495.3 - avg_words_per_second: 3430.0 - ETA: >2024-05-24 15:55:45\n", + "2024-05-24 15:55:45 (UTC) - 0:32:08 - checkpointing - INFO - Dumping checkpoint in /content/test_ultra/checkpoints/checkpoint_000100/consolidated using tmp name: tmp.consolidated\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - checkpointing - INFO - Done dumping checkpoint in /content/test_ultra/checkpoints/checkpoint_000100/consolidated for step: 100\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - checkpointing - INFO - Done deleting checkpoints \n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - checkpointing - INFO - Done!\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - train - INFO - done!\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - utils - INFO - Closing: eval_logger\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - utils - INFO - Closed: eval_logger\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - utils - INFO - Closing: metrics_logger\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - utils - INFO - Closed: metrics_logger\n", + "2024-05-24 15:55:46 (UTC) - 0:32:09 - train - INFO - Closed everything!\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "sLAiR9-032ft" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/extend_model_vocab.py b/utils/extend_model_vocab.py new file mode 100644 index 0000000..b22caf9 --- /dev/null +++ b/utils/extend_model_vocab.py @@ -0,0 +1,144 @@ +import argparse +import json +import math +import os +from pathlib import Path + +import torch +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer + +from model.args import ModelArgs + +FIRST_PIECE_ID = 3 +OLD_VOCAB_SIZE = 32000 +NEW_VOCAB_SIZE = 32768 + + +def extend_model(original_model: Path, extended_model: Path): + original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True) + model_args = ModelArgs.load(str(original_model / "params.json")) + + original_vocab_size = model_args.vocab_size + assert ( + original_vocab_size == OLD_VOCAB_SIZE + ), f"Original vocab size {original_vocab_size} is not equal to 32000. Can only extend models with vocab_size of 32000" + + if not extended_model.exists(): + os.makedirs(extended_model, exist_ok=True) + print(f"Created empty directory {extended_model}.") + + assert not list( + extended_model.iterdir() + ), f"Make sure {extended_model} is empty" + + # Load and check tokenizers + mistral_tokenizer = MistralTokenizer.v3() + tokenizer: SentencePieceTokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer + + new_vocab_size = tokenizer.n_words + assert ( + new_vocab_size == 32768 + ), f"New Tokenzier has vocab_size: {new_vocab_size} but has to be equal to 32768. Make sure to pass a v2 or v3 tokenizer file" + + vocabulary_delta = new_vocab_size - original_vocab_size + + # Check that 0...FIRST_PIECE_ID-1 are UNK + control characters and FIRST_PIECE_ID is the first piece + assert tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>" + assert tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "" + + assert isinstance(tokenizer, SentencePieceTokenizer) + + original_embeddings = original_ckpt["tok_embeddings.weight"] + + assert ( + original_vocab_size == original_embeddings.shape[0] + ), f"Original vocab size {original_vocab_size} is not equal to original embeddings shape {original_embeddings.shape[0]}." + + dim = original_embeddings.shape[1] + + # Extend embeddings + extended_embeddings = torch.zeros( + tokenizer.n_words, dim, dtype=original_embeddings.dtype + ) + extended_embeddings[:original_vocab_size] = original_embeddings + extended_embeddings[:FIRST_PIECE_ID] = original_embeddings[:FIRST_PIECE_ID] + extended_embeddings[FIRST_PIECE_ID + vocabulary_delta :] = original_embeddings[ + FIRST_PIECE_ID: + ] + + # randomly initialize new tokens + extended_tokens = torch.empty( + vocabulary_delta, dim, dtype=original_embeddings.dtype + ) + torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim)) + + extended_embeddings[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = ( + extended_tokens + ) + + # Extend output + original_output = original_ckpt["output.weight"] + assert ( + original_output.shape[0] == original_vocab_size + ), f"Original output shape {original_output.shape[0]} is not equal to {original_vocab_size}." + assert ( + original_output.shape[1] == dim + ), f"Original output dim {original_output.shape[1]} is not equal to embedding dim {dim}." + + assert ( + original_output.dtype == original_embeddings.dtype + ), f"Original output and embeddings have different dtypes: {original_output.dtype} vs {original_embeddings.dtype}." + + extended_output = torch.zeros(tokenizer.n_words, dim, dtype=original_output.dtype) + extended_output[:FIRST_PIECE_ID] = original_output[:FIRST_PIECE_ID] + extended_output[FIRST_PIECE_ID + vocabulary_delta :] = original_output[ + FIRST_PIECE_ID: + ] + + # randomly initialize new tokens + extended_tokens = torch.empty(vocabulary_delta, dim, dtype=original_output.dtype) + torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim)) + + extended_output[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = ( + extended_tokens + ) + + original_ckpt["tok_embeddings.weight"] = extended_embeddings + original_ckpt["output.weight"] = extended_output + + new_ckpt_path = extended_model / "consolidated.00.pth" + print(f"Exporting extended model to {extended_model} ...") + torch.save(original_ckpt, new_ckpt_path) + + params_path = extended_model / "params.json" + with open(params_path, "w") as f: + model_dict = model_args.to_dict() + del model_dict["lora"] + if model_dict["moe"] is None: + del model_dict["moe"] + model_dict["vocab_size"] = new_vocab_size + + f.write(json.dumps(model_dict, indent=4)) + + +def main(): + parser = argparse.ArgumentParser( + description="Extend a model using the specified original model, extended model, and tokenizer paths." + ) + parser.add_argument( + "--original_model_ckpt", type=Path, help="Path to the original model folder." + ) + parser.add_argument( + "--extended_model_ckpt", type=Path, help="Path to the extended model file." + ) + args = parser.parse_args() + + extend_model( + original_model=args.original_model_ckpt, + extended_model=args.extended_model_ckpt, + ) + + +if __name__ == "__main__": + main() diff --git a/utils/merge_lora.py b/utils/merge_lora.py new file mode 100644 index 0000000..4fce690 --- /dev/null +++ b/utils/merge_lora.py @@ -0,0 +1,90 @@ +import argparse +from typing import Dict, Optional + +import safetensors.torch +import torch +import tqdm + + +def merge_checkpoints( + model_checkpoint: Dict[str, torch.Tensor], + lora_checkpoint: Dict[str, torch.Tensor], + scaling: float, + save_dtype: Optional[torch.dtype] = None, +): + save_dtype = save_dtype or next(iter(lora_checkpoint.values())).dtype + print(f"Merging to {save_dtype} precision...") + + keys_to_update = [ + key for key in lora_checkpoint.keys() if "norm" in key or "lora_A" in key + ] + assert any( + "lora_A" in k or "lora_B" in k for k in keys_to_update + ), "No `lora` keys found in your checkpoint. Check that `lora_ckpt` is correct." + + for key in tqdm.tqdm(keys_to_update): + if "norm" in key: + model_checkpoint[key] = lora_checkpoint[key].to("cpu") + else: + weight_name = key.replace("lora_A.weight", "weight") + + lora_A_weight = lora_checkpoint[key].to("cuda") + lora_B_weight = lora_checkpoint[key.replace("lora_A", "lora_B")].to("cuda") + + weight = lora_B_weight.mm(lora_A_weight) * scaling + weight += model_checkpoint[weight_name].to("cuda") + weight = weight.to(save_dtype) + + model_checkpoint[weight_name] = weight.to("cpu") + + # cast all tensors to save dtype + for key in tqdm.tqdm(model_checkpoint.keys()): + model_checkpoint[key] = model_checkpoint[key].to(save_dtype) + + +def load(filename: str): + if filename.endswith(".safetensors"): + return safetensors.torch.load_file(filename) + else: + return torch.load(filename) + + +def main(args): + model_checkpoint = load(args.initial_model_ckpt) + lora_checkpoint = load(args.lora_ckpt) + + merge_checkpoints(model_checkpoint, lora_checkpoint, args.scaling) + + safetensors.torch.save_file(model_checkpoint, args.dump_ckpt) + + print(f"Merged checkpoint saved to {args.dump_ckpt}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Merge a LoRA checkpoint into a model checkpoint." + ) + parser.add_argument( + "--initial_model_ckpt", + type=str, + required=True, + help="Path to the model checkpoint.", + ) + parser.add_argument( + "--lora_ckpt", type=str, required=True, help="Path to the LoRA checkpoint." + ) + parser.add_argument( + "--dump_ckpt", + type=str, + required=True, + help="Path to save the merged checkpoint.", + ) + parser.add_argument( + "--scaling", + type=float, + default=2.0, + help="Scaling factor for the LoRA checkpoint. Default is 2.0.", + ) + + args = parser.parse_args() + main(args) diff --git a/utils/reformat_data.py b/utils/reformat_data.py new file mode 100755 index 0000000..465abfa --- /dev/null +++ b/utils/reformat_data.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import random +import string + + +def reformat_jsonl(input_file): + output_file = input_file + ".tmp" + + content_keys = ["content", "text"] + with open(input_file, "r") as infile, open(output_file, "w") as outfile: + for idx, line in enumerate(infile): + data = json.loads(line) + if "functions" in data: + data["tools"] = [{"function": func} for func in data["functions"]] + del data["functions"] + + skip_sample = False + if "messages" in data: + for i, msg in enumerate(data["messages"]): + if "function_call" in msg: + if "content" in msg: + assert msg["content"] == "" + del msg["content"] + + arguments = json.loads(msg["function_call"]["arguments"]) + msg["function_call"]["arguments"] = json.dumps(arguments) + + msg["tool_calls"] = [{"function": msg.pop("function_call")}] + + for key in content_keys: + if key in msg and msg[key] == "": + if "tool_calls" in msg: + del msg[key] + print( + f"Delete empty '{key}' field in tool call message in line {idx}" + ) + + if all(msg.get(key) in ["", None] for key in content_keys): + # conversation is incorrect + skip_sample = True + + if msg["role"] in ["function", "tool"]: + msg["role"] = "tool" + if "tool_call_id" not in msg: + msg["tool_call_id"] = "".join( + random.choices( + string.ascii_letters + string.digits, k=9 + ) + ) + + # make sure prev + if data["messages"][i - 1]["role"] == "assistant": + prev_msg = data["messages"][i - 1] + if "tool_calls" in prev_msg: + tool_name = prev_msg["tool_calls"][0]["function"][ + "name" + ] + + assert tool_name == msg["name"] + prev_msg["tool_calls"][0]["id"] = msg["tool_call_id"] + + # make sure last message is an assistant message + while ( + len(data["messages"]) > 0 + and data["messages"][-1]["role"] != "assistant" + ): + data["messages"].pop() + + if len(data["messages"]) == 0: + skip_sample = True + + if not skip_sample: + outfile.write(json.dumps(data) + "\n") + else: + print(f"Skip {idx}th sample") + + os.rename(output_file, input_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Reformat a JSONL file.") + parser.add_argument("file", type=str, help="The input JSONL file") + + args = parser.parse_args() + reformat_jsonl(args.file) diff --git a/utils/reformat_data_glaive.py b/utils/reformat_data_glaive.py new file mode 100644 index 0000000..0f908b6 --- /dev/null +++ b/utils/reformat_data_glaive.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import random +import string + + +def reformat_jsonl(input_file): # noqa: C901 + output_file = os.path.splitext(input_file)[0] + "_reformatted.jsonl" + skipped_samples = [] + + with open(input_file, "r") as infile, open(output_file, "w") as outfile: + for i, line in enumerate(infile): + reformat_data = True + data = json.loads(line) + + # Extract function description + try: + function_desc = json.loads(data["function_description"]) + except json.decoder.JSONDecodeError: + function_desc = ( + data["function_description"].replace("\n", "").replace("}{", "},{").replace("\\t", "") + ) + function_desc = "[{" + function_desc[1:-1] + "}]" + function_desc = json.loads(function_desc) + + function_desc = function_desc if isinstance(function_desc, list) else [function_desc] + + # Reformat tools section + if len(function_desc) == 1 and function_desc[0] == {}: + tools = None + else: + tools = [] + for f in function_desc: + if f["parameters"] is None: + f["parameters"] = {} + tools.append({"type": "function", "function": f}) + + messages = [] + + # Process conversations + for idx, msg in enumerate(data["conversations"]): + role = msg["from"] + content = msg["value"] + + if role == "system": + messages.append( + {"role": "system", "content": content.split(" -")[0]} + ) + elif role == "human": + messages.append({"role": "user", "content": content}) + elif role == "function-call": + try: + function_call = json.loads(content) + except json.decoder.JSONDecodeError: + content = content.replace("'", "").replace("\\", "'") + try: + function_call = json.loads(content) + except: # noqa: E722 + skipped_samples.append(str(i)) + reformat_data = False + break + + if not isinstance(function_call, list): + function_calls = [function_call] + else: + function_calls = function_call + + tool_calls = [] + for function_call in function_calls: + assert not isinstance(function_call, list) + tool_call_id = "".join( + random.choices(string.ascii_letters + string.digits, k=9) + ) + + if "arguments" in function_call and not isinstance(function_call["arguments"], str): + function_call["arguments"] = str(function_call["arguments"]) + elif "arguments" not in function_call: + function_call["arguments"] = "" + + tool_calls.append({"id": tool_call_id, "type": "function", "function": function_call}) + + messages.append( + { + "role": "assistant", + "tool_calls": tool_calls + } + ) + elif role == "function-response": + if "tool_calls" not in messages[-1]: + skipped_samples.append(str(i)) + reformat_data = False + break + + assert len(messages[-1]["tool_calls"]) == 1 + tool_call_id = messages[-1]["tool_calls"][0]["id"] + messages.append( + { + "role": "tool", + "content": content, + "tool_call_id": tool_call_id, + } + ) + elif role == "gpt": + messages.append({"role": "assistant", "content": content}) + + output_data = {"messages": messages} + + if tools is not None: + output_data["tools"] = tools + + if reformat_data: + outfile.write(json.dumps(output_data) + "\n") + + os.rename(output_file, input_file) + print( + f"Skipped {len(skipped_samples)} samples ({len(skipped_samples) / i:.2%}). The following samples are incorrectly formated: \n\n {', '.join(skipped_samples)}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Reformat a JSONL file.") + parser.add_argument("file", type=str, help="The input JSONL file") + + args = parser.parse_args() + reformat_jsonl(args.file) diff --git a/utils/validate_data.py b/utils/validate_data.py new file mode 100644 index 0000000..8d8b686 --- /dev/null +++ b/utils/validate_data.py @@ -0,0 +1,366 @@ +import argparse +import json +from pathlib import Path +from typing import Dict + +from mistral_common.exceptions import ( + InvalidAssistantMessageException, + InvalidFunctionCallException, + InvalidMessageStructureException, + InvalidToolSchemaException, + TokenizerException, +) +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from tqdm import tqdm + +from finetune.args import TrainArgs +from finetune.data.dataset import parse_data_sources +from finetune.data.tokenize import ( + ConversationFormatError, + FunctionFormatError, + MessageFormatError, + SampleType, + ToolCallFormatError, + UnrecognizedRoleError, + build_instruct_sample, + get_pretrain_sample, + tokenize, +) + +NUM_GPUS = 8 + +# EXPECTED WPS for batch_size = 32768 per GPU on H100 +EXPECTED_WPS = { + "open-mistral-7b": 5720, + "open-mixtral-8x7b": 2966, + "open-mixtral-8x22b": 1007, +} + +MIN_NUM_JSONL_LINES = 10 +MAX_NUM_JSONL_LINES = 10_000_000 + +MIN_BYTES = 1_000 +MAX_BYTES = 10_000_000_000 # rougly 10 GB + + +def convert_seconds_to_hms(seconds: float) -> str: + hours = int(seconds // 3600) + seconds %= 3600 + minutes = int(seconds // 60) + seconds %= 60 + + seconds = int(seconds) + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + +def verify_size(jsonl_file: Path): + print(f"Verifying {jsonl_file} ...") + with jsonl_file.open() as f: + num_lines = 0 + num_bytes = 0 + for line in f: + num_lines += 1 + num_bytes += len(line) + + if num_lines < MIN_NUM_JSONL_LINES: + raise ValueError( + f"{jsonl_file} has only {num_lines} conversation which is less than the minimum amount of conversations required per dataset file: {MIN_NUM_JSONL_LINES}. Please make sure that each dataset has at least {MIN_NUM_JSONL_LINES} conversations." + ) + elif num_bytes < MIN_BYTES: + raise ValueError( + f"{jsonl_file} has only {num_bytes} bytes which is less than the minimum amount of bytes required per dataset file: {MIN_BYTES}. Please make sure that each dataset has at least {MIN_BYTES} bytes." + ) + elif num_lines > MAX_NUM_JSONL_LINES: + raise ValueError( + f"{jsonl_file} has {num_lines} conversation which is more than the maximum amount of allowed per dataset file: {MAX_NUM_JSONL_LINES}. Please make sure that each dataset has no more than {MAX_NUM_JSONL_LINES} conversations." + ) + elif num_bytes > MAX_BYTES: + raise ValueError( + f"{jsonl_file} has {num_bytes} bytes which is more than the maximum amount of bytes allowed per dataset file: {MAX_BYTES}. Please make sure that each dataset has no more than {MAX_BYTES} bytes." + ) + + print( + f"Dataset {jsonl_file} is valid. Dataset has {num_lines} conversations amounting to a total of {num_bytes} bytes." + ) + + +def get_train_stats( + num_tokens: Dict[str, int], + datasets_proportion: Dict[str, float], + train_args: TrainArgs, + return_type: str, +): + dataset_tokens = sum(num_tokens.values()) + batch_size = train_args.batch_size * train_args.seq_len * NUM_GPUS + + if Path(train_args.model_id_or_path).is_dir(): + params_config = json.load( + (Path(train_args.model_id_or_path) / "params.json").open() + ) + + if params_config["dim"] == 4096 and params_config.get("moe") is None: + model_id = "open-mistral-7b" + elif params_config["dim"] == 4096 and params_config.get("moe") is not None: + model_id = "open-mixtral-8x7b" + elif params_config["dim"] == 6144: + model_id = "open-mixtral-8x22b" + else: + raise ValueError("Provided model folder seems incorrect.") + else: + model_id = train_args.model_id_or_path + + wps = EXPECTED_WPS[model_id] + + if return_type == "expected": + train_tokens = train_args.max_steps * batch_size + max_steps = train_args.max_steps + num_epochs = train_tokens / dataset_tokens + elif return_type == "recommended": + num_epochs = 3 + max_steps = int(sum(num_tokens.values()) / batch_size * num_epochs) + train_tokens = max_steps * batch_size + else: + raise ValueError( + f"`return_type` is {return_type}, but has to be one of ['expected', 'recommended']" + ) + + expected_time_in_sec = train_tokens / NUM_GPUS / wps + + # Add 5min buffer for loading/init/ckpt/eval + expected_time_in_sec += 300 + + train_tokens_per_dataset = { + k: (train_tokens * v) for k, v in datasets_proportion.items() + } + + return { + "eta": convert_seconds_to_hms(expected_time_in_sec), + "data_tokens": dataset_tokens, + "train_tokens": train_tokens, + "epochs": f"{num_epochs:.2f}", + "max_steps": max_steps, + "data_tokens_per_dataset": {k: f"{v:.1f}" for k, v in num_tokens.items()}, + "train_tokens_per_dataset": { + k: f"{v:.1f}" for k, v in train_tokens_per_dataset.items() + }, + "epochs_per_dataset": { + k: f"{(train_tokens_per_dataset[k] / num_tokens[k]):.1f}" + for k in num_tokens.keys() + }, + } + + +def main(args): + train_args = TrainArgs.load(args.train_yaml) + + yaml_data_errors = [] + conversation_format_errors = [] + message_format_errors = [] + tokenization_errors = [] + + # Check if pretrain can be loaded + # train_pretrain_data = train_args.data.data + data = [("train", train_args.data.data, train_args.data.instruct_data)] + + if train_args.data.eval_instruct_data != "": + data.append(("eval", "", train_args.data.eval_instruct_data)) + + EXPECTED_WPS.keys() + + instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer + + for name, pretrain_file, instruct_file in data: + datasets, weights = parse_data_sources(pretrain_file, instruct_file) + data_types = [d.sample_type for d in datasets] + datasets = [str(d.path) for d in datasets] + + datasets_proportion = dict(zip(datasets, weights)) + num_tokens = {k: 0 for k in datasets_proportion} + + for data_type, dataset in tqdm(zip(data_types, datasets)): + # verify_size(Path(dataset)) + print(f"Validating {dataset} ...") + + corrected_dataset = dataset + ".corrected" + correct_lines = [] + + sub_yaml_data_errors = [] + sub_conversation_format_errors = [] + sub_message_format_errors = [] + sub_tokenization_errors = [] + + # Load the dataset + with open(dataset, "r", encoding="utf-8") as f: + lines = f.readlines() + for idx, line in tqdm(enumerate(lines), total=len(lines)): + try: + data = json.loads(line) + except ValueError as e: + prefix = f"The data in line {idx + 1} of dataset {dataset} is incorrectly formated." + sub_yaml_data_errors.append(prefix + str(e)) + continue + + if data_type == SampleType.PRETRAIN: + # TODO(Patrick) - Get good error message + sample = get_pretrain_sample(data) + else: + try: + sample = build_instruct_sample(data) + + except ( + ConversationFormatError, + UnrecognizedRoleError, + MessageFormatError, + ToolCallFormatError, + FunctionFormatError, + InvalidAssistantMessageException, + InvalidFunctionCallException, + InvalidMessageStructureException, + InvalidToolSchemaException, + ) as e: + prefix = f"The data in line {idx + 1} of dataset {dataset} is incorrectly formated." + if isinstance( + e, (ConversationFormatError, FunctionFormatError) + ): + sub_conversation_format_errors.append(prefix + str(e)) + elif isinstance( + e, + ( + MessageFormatError, + UnrecognizedRoleError, + ToolCallFormatError, + ), + ): + sub_message_format_errors.append(prefix + str(e)) + if isinstance( + e, + ( + InvalidFunctionCallException, + InvalidMessageStructureException, + InvalidAssistantMessageException, + InvalidToolSchemaException, + ), + ): + sub_conversation_format_errors.append(prefix + str(e)) + + continue + try: + tokens = tokenize(sample, instruct_tokenizer).tokens + except TokenizerException as e: + error_message = ( + f"The data in line {idx + 1} of dataset {dataset} could not be tokenized. " + + str(e) + ) + sub_tokenization_errors.append(error_message) + + correct_lines.append(line) + num_tokens[dataset] += len(tokens) + + is_sub_error = ( + len( + sub_yaml_data_errors + + sub_conversation_format_errors + + sub_message_format_errors + + sub_tokenization_errors + ) + > 0 + ) + if is_sub_error and args.create_corrected: + with open(corrected_dataset, "w", encoding="utf-8") as f: + for line in correct_lines: + f.write(line) + + print(f"Saved {corrected_dataset}.") + elif args.create_corrected: + print(f"No error in {dataset} - no need to create a corrected version.") + + yaml_data_errors.extend(sub_yaml_data_errors) + conversation_format_errors.extend(sub_conversation_format_errors) + message_format_errors.extend(sub_message_format_errors) + tokenization_errors.extend(sub_tokenization_errors) + + is_error = ( + len( + yaml_data_errors + + conversation_format_errors + + message_format_errors + + tokenization_errors + ) + > 0 + ) + if is_error: + all_yaml_data_errors = "\n".join(yaml_data_errors) + all_conversation_format_errors = "\n".join(conversation_format_errors) + all_message_format_errors = "\n".join(message_format_errors) + all_tokenization_errors = "\n".join(tokenization_errors) + error_report = f""" + Data error report + ----------------------- \n + The passed datasets contains some errors as listed below. Please make sure to fix these errors in order to start training. + + YAML data load errors: \n\n {all_yaml_data_errors} \n\n + Conversation format errors: \n\n {all_conversation_format_errors} \n\n + Message format errors: \n\n {all_message_format_errors} \n\n + Tokenization errors: \n\n {all_tokenization_errors} \n\n + """ + if args.save_reports: + with open(args.error_report_txt, "w") as f: + f.write(error_report) + + print(error_report) + else: + print("No errors! Data is correctly formated!") + + if name == "train" and not is_error: + expected_stats = get_train_stats( + num_tokens, datasets_proportion, train_args, return_type="expected" + ) + stats = { + "expected": expected_stats, + } + + filenames = ( + f"{instruct_file}" + if pretrain_file == "" + else f"{instruct_file} and {pretrain_file}" + ) + + print( + f"Stats for {filenames} \n {20 * '-'} \n {json.dumps(stats, indent=4)}" + ) + + if args.save_reports: + if name == "train": + with open(args.train_stats_json, "w") as file: + json.dump(stats, file, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Validate jsonl.") + parser.add_argument( + "--train_yaml", + type=str, + help="Path to the data file", + ) + parser.add_argument( + "--error_report_txt", + type=str, + default="data_errors.txt", + help="Path to the error report.", + ) + parser.add_argument( + "--train_stats_json", + type=str, + default="train_stats.json", + help="Path to training statistics json file.", + ) + parser.add_argument( + "--save_reports", action="store_true", help="Save reports to disk" + ) + parser.add_argument( + "--create_corrected", + action="store_true", + help="Skip faulty lines and append all correct lines to `.corrected` datasets.", + ) + args = parser.parse_args() + main(args)