From 4d5771a2970d8cc74cf1ea59ce9e77c318748bd8 Mon Sep 17 00:00:00 2001
From: vaibhavad
Date: Tue, 30 Apr 2024 03:22:39 +0000
Subject: [PATCH] change argparse, add training config

---
 experiments/run_supervised.py         | 457 +++++++++++++++-----------
 llm2vec/experiment_utils.py           |  51 ---
 train_configs/supervised/Mistral.json |  29 ++
 3 files changed, 288 insertions(+), 249 deletions(-)
 create mode 100644 train_configs/supervised/Mistral.json

diff --git a/experiments/run_supervised.py b/experiments/run_supervised.py
index 3eb6675..2211622 100644
--- a/experiments/run_supervised.py
+++ b/experiments/run_supervised.py
@@ -1,7 +1,7 @@
-import argparse
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import os
+import sys
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -13,11 +13,14 @@
 import transformers
 from transformers import (
+    MODEL_FOR_MASKED_LM_MAPPING,
+    HfArgumentParser,
     TrainingArguments,
     Trainer,
     TrainerCallback,
     LlamaConfig,
     MistralConfig,
+    set_seed,
 )
 from transformers.trainer_utils import seed_worker
 
@@ -26,17 +29,199 @@
 from llm2vec import LLM2Vec
 from llm2vec.dataset.utils import load_dataset
 from llm2vec.loss.utils import load_loss
-from llm2vec.experiment_utils import (
-    generate_experiment_id,
-    log_commandline_args,
-    set_seed,
-    str2bool,
-    prepare_model_args,
-)
+from llm2vec.experiment_utils import generate_experiment_id
+
 from tqdm import tqdm
 
 transformers.logging.set_verbosity_error()
 
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+logger = get_logger(__name__, log_level="INFO")
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+def prepare_for_tokenization(model, text, pooling_mode="mean"):
+    if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
+        text = (
+            "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
+        )
+        return text
+    if model.config._name_or_path in [
+        "mistralai/Mistral-7B-Instruct-v0.2",
+        "meta-llama/Llama-2-7b-chat-hf",
+    ]:
+        text = "[INST] " + text.strip() + " [/INST]"
+    if pooling_mode == "eos_token":
+        if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
+            text = text.strip() + "<|end_of_text|>"
+        elif isinstance(model.config, LlamaConfig) or isinstance(
+            model.config, MistralConfig
+        ):
+            text = text.strip() + " </s>"
+
+    return text
+
+
+def initialize_peft(
+    model,
+    lora_r: int = 8,
+    lora_alpha: int = 16,
+    lora_dropout: float = 0.05,
+    lora_modules: Optional[List[str]] = None,
+):
+    if lora_modules is None and model.config.__class__.__name__ in [
+        "LlamaConfig",
+        "MistralConfig",
+    ]:
+        lora_modules = [
+            "q_proj",
+            "v_proj",
+            "k_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]
+    elif lora_modules is None:
+        raise ValueError("lora_modules must be specified for this model.")
+
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type=None,
+    )
+
+    model = get_peft_model(model, config)
+    print(f"Model's Lora trainable parameters:")
+    model.print_trainable_parameters()
+    return model
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+            )
+        },
+    )
+    peft_model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
+    )
+    bidirectional: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
+            )
+        },
+    )
+    max_seq_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The maximum total input sequence length after tokenization. Sequences longer "
+                "than this will be truncated."
+            )
+        },
+    )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
+    attn_implementation: Optional[str] = field(
+        default="sdpa",
+        metadata={
+            "help": ("The attention implementation to use in the model."),
+            "choices": ["eager", "sdpa", "flash_attention_2"],
+        },
+    )
+    pooling_mode: Optional[str] = field(
+        default="mean",
+        metadata={
+            "help": ("The pooling mode to use in the model."),
+            "choices": ["mean", "weighted_mean", "eos_token"],
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the dataset to use. Options: E5"},
+    )
+    dataset_file_path: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file or folder."}
+    )
+    # TODO: implement this
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "For debugging purposes or quicker training, truncate the number of training examples to this "
+                "value if set."
+            )
+        },
+    )
+
+
+@dataclass
+class CustomArguments:
+    """
+    Custom arguments for the script
+    """
+
+    lora_dropout: float = field(
+        default=0.05, metadata={"help": "The dropout rate for lora"}
+    )
+
+    lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
+
+    stop_after_n_steps: int = field(
+        default=10000, metadata={"help": "Stop training after n steps"}
+    )
+
+    experiment_id: Optional[str] = field(
+        default=None, metadata={"help": "The experiment id"}
+    )
+
+    loss_class: Optional[str] = field(
+        default="HardNegativeNLLLoss",
+        metadata={
+            "help": "The loss class to use for training. Options: HardNegativeNLLLoss"
+        },
+    )
+
+    loss_scale: float = field(
+        default=50.0, metadata={"help": "The loss scale for the loss function"}
+    )
+
 
 @dataclass
 class DefaultCollator:
@@ -53,7 +238,9 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
 
         for example in batch:
             for idx, text in enumerate(example.texts):
-                text = prepare_for_tokenization(model, text, pooling_mode=model.pooling_mode)
+                text = prepare_for_tokenization(
+                    self.model, text, pooling_mode=self.model.pooling_mode
+                )
                 texts[idx].append(text)
             labels.append(example.label)
         labels = torch.tensor(labels)
@@ -150,118 +337,24 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
 
 
-def prepare_for_tokenization(model, text, pooling_mode="mean"):
-    if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
-        text = (
-            "<|start_header_id|>user<|end_header_id|>\n\n"
-            + text.strip()
-            + "<|eot_id|>"
-        )
-        return text
-    if model.config._name_or_path in [
-        "mistralai/Mistral-7B-Instruct-v0.2",
-        "meta-llama/Llama-2-7b-chat-hf",
-    ]:
-        text = "[INST] " + text.strip() + " [/INST]"
-    if pooling_mode == "eos_token":
-        if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
-            text = text.strip() + "<|end_of_text|>"
-        elif isinstance(model.config, LlamaConfig) or isinstance(
-            model.config, MistralConfig
-        ):
-            text = text.strip() + " </s>"
-
-    return text
-
-def initialize_peft(
-    model,
-    lora_r: int = 8,
-    lora_alpha: int = 16,
-    lora_dropout: float = 0.05,
-    lora_modules: Optional[List[str]] = None,
-):
-    if lora_modules is None and model.config.__class__.__name__ in [
-        "LlamaConfig",
-        "MistralConfig",
-    ]:
-        lora_modules = [
-            "q_proj",
-            "v_proj",
-            "k_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ]
-    elif lora_modules is None:
-        raise ValueError("lora_modules must be specified for this model.")
-
-    config = LoraConfig(
-        r=lora_r,
-        lora_alpha=lora_alpha,
-        target_modules=lora_modules,
-        lora_dropout=lora_dropout,
-        bias="none",
-        task_type=None,
+def main():
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
     )
-
-    model = get_peft_model(model, config)
-    print(f"Model's Lora trainable parameters:")
-    model.print_trainable_parameters()
-    return model
-
-# TODO: Parse these into JSON, organize same way as MNTP
-parser = argparse.ArgumentParser()
-parser.add_argument("--model_name", type=str, default="t5-base")
-parser.add_argument("--peft_addr", type=str, default=None)
-parser.add_argument("--dataset_name", type=str, default="E5")
-parser.add_argument(
-    "--dataset_file_path",
-    type=str,
-    default="cache/echo-data",
-)
-parser.add_argument("--output_dir", type=str, default="output")
-parser.add_argument("--train_batch_size", default=16, type=int)
-parser.add_argument("--dev_batch_size", default=32, type=int)
-parser.add_argument("--max_seq_length", default=384, type=int)
-parser.add_argument("--epochs", default=40, type=int)
-parser.add_argument("--warmup_ratio", default=0.1, type=float)
-parser.add_argument("--warmup_steps", default=0, type=int)
-parser.add_argument("--checkpoint_save_steps", default=10000, type=int)
-parser.add_argument("--lr", default=2e-5, type=float)
-parser.add_argument("--find_unused_parameters", default=False, type=str2bool)
-parser.add_argument("--seed", default=42, type=int)
-parser.add_argument("--pooling_mode", default="mean", type=str)
-parser.add_argument("--checkpoint_save_total_limit", default=0, type=int)
-parser.add_argument("--experiment_id", default=None, type=str)
-parser.add_argument("--grad_accumulation_steps", default=1, type=int)
-parser.add_argument("--lora_r", default=8, type=int)
-parser.add_argument("--lora_dropout", default=0.05, type=float)
-parser.add_argument("--num_cpu_workers", default=4, type=int)
-parser.add_argument("--bidirectional", default=False, type=str2bool)
-parser.add_argument("--stop_after_n_steps", default=None, type=int)
-parser.add_argument("--fp16", default=False, type=str2bool)
-parser.add_argument("--bf16", default=False, type=str2bool)
-parser.add_argument("--flash_attention_2", default=False, type=str2bool)
-parser.add_argument("--load_in_8bit", default=False, type=str2bool)
-parser.add_argument("--load_in_4bit", default=False, type=str2bool)
-parser.add_argument("--amp", default=False, type=str2bool)
-parser.add_argument("--deepspeed", default=None, type=str)
-parser.add_argument("--gradient_checkpointing", default=False, type=str2bool)
-parser.add_argument("--loss_cls", default="HardNegativeNLLLoss", type=str)
-parser.add_argument("--loss_scale", default=50.0, type=float)
-
-args = parser.parse_args()
-
-if __name__ == "__main__":
-    logging.basicConfig(
-        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        level=logging.INFO,
-    )
-    logger = get_logger(__name__, log_level="INFO")
-
-    if args.find_unused_parameters:
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args, custom_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
+    else:
+        (
+            model_args,
+            data_args,
+            training_args,
+            custom_args,
+        ) = parser.parse_args_into_dataclasses()
+    if training_args.ddp_find_unused_parameters:
         kwargs = [
             DistributedDataParallelKwargs(
                 dim=0,
@@ -276,55 +369,44 @@ def initialize_peft(
         kwargs = []
     accelerator = Accelerator(kwargs_handlers=kwargs)
 
-    set_seed(args)
-
-    # if accelerator.is_main_process:
-    log_commandline_args(args, logger.info)
-
-    if args.deepspeed:
-        assert (
-            False
-        ), "DeepSpeed is not implemented yet. There will be a problem with model saving and loading."
-
-    if args.flash_attention_2:
-        assert args.fp16 or args.bf16 or args.load_in_8bit or args.load_in_4bit
+    set_seed(training_args.seed)
 
-    gradient_checkpointing_kwargs = None
-    if args.gradient_checkpointing:
-        gradient_checkpointing_kwargs = {"use_reentrant": False}
+    if training_args.gradient_checkpointing:
+        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
 
-    if args.experiment_id is not None:
-        experiment_id = args.experiment_id
+    if custom_args.experiment_id is not None:
+        experiment_id = custom_args.experiment_id
     else:
         experiment_id = generate_experiment_id(
-            name=args.dataset_name,
+            name=data_args.dataset_name,
             split="train",
             model_name=(
-                args.model_name
-                if "/" not in args.model_name
-                else args.model_name.split("/")[-1]
+                model_args.model_name_or_path
+                if "/" not in model_args.model_name_or_path
+                else model_args.model_name_or_path.split("/")[-1]
            ),
-            pooling_mode=args.pooling_mode,
-            train_batch_size=args.train_batch_size
+            pooling_mode=model_args.pooling_mode,
+            train_batch_size=training_args.per_device_train_batch_size
             * accelerator.num_processes
-            * args.grad_accumulation_steps,
-            max_seq_length=args.max_seq_length,
-            bidirectional=args.bidirectional,
-            epochs=args.epochs,
-            seed=args.seed,
-            warmup_steps=args.warmup_steps,
-            lr=args.lr,
-            lora_r=args.lora_r,
+            * training_args.gradient_accumulation_steps,
+            max_seq_length=model_args.max_seq_length,
+            bidirectional=model_args.bidirectional,
+            epochs=training_args.num_train_epochs,
+            seed=training_args.seed,
+            warmup_steps=training_args.warmup_steps,
+            lr=training_args.learning_rate,
+            lora_r=custom_args.lora_r,
         )
 
-    model_save_path = f"{args.output_dir}/{experiment_id}"
+    training_args.output_dir = f"{training_args.output_dir}/{experiment_id}"
 
     # TODO: can also pass separator arg here
     train_dataset = load_dataset(
-        args.dataset_name,
+        data_args.dataset_name,
         split="train",
-        file_path=args.dataset_file_path,
-        effective_batch_size=args.train_batch_size * accelerator.num_processes,
+        file_path=data_args.dataset_file_path,
+        effective_batch_size=training_args.per_device_train_batch_size
+        * accelerator.num_processes,
     )
 
     train_examples = [
@@ -336,58 +418,33 @@ def initialize_peft(
         )
     ]
 
-    model_args = prepare_model_args(
-        bf16=args.bf16,
-        fp16=args.fp16,
-        flash_attention_2=args.flash_attention_2,
-        load_in_8bit=args.load_in_8bit,
-        load_in_4bit=args.load_in_4bit,
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
    )
-
     model = LLM2Vec.from_pretrained(
-        base_model_name_or_path=args.model_name,
-        enable_bidirectional=args.bidirectional,
-        peft_model_name_or_path=args.peft_addr,
+        base_model_name_or_path=model_args.model_name_or_path,
+        enable_bidirectional=model_args.bidirectional,
+        peft_model_name_or_path=model_args.peft_model_name_or_path,
         merge_peft=True,
-        pooling_mode=args.pooling_mode,
-        max_length=args.max_seq_length,
-        **model_args,
+        pooling_mode=model_args.pooling_mode,
+        max_length=model_args.max_seq_length,
+        torch_dtype=torch_dtype,
+        attn_implementation=model_args.attn_implementation,
     )
 
     # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model
     model.model = initialize_peft(
         model.model,
-        lora_r=args.lora_r,
-        lora_alpha=2 * args.lora_r,
-        lora_dropout=args.lora_dropout,
+        lora_r=custom_args.lora_r,
+        lora_alpha=2 * custom_args.lora_r,
+        lora_dropout=custom_args.lora_dropout,
     )
 
     tokenizer = model.tokenizer
 
-    train_loss = load_loss(args.loss_cls, scale=args.loss_scale)
-
-    training_args = TrainingArguments(
-        output_dir=model_save_path,
-        num_train_epochs=args.epochs,
-        seed=args.seed,
-        per_device_train_batch_size=args.train_batch_size,
-        gradient_accumulation_steps=args.grad_accumulation_steps,
-        learning_rate=args.lr,
-        warmup_ratio=args.warmup_ratio,
-        warmup_steps=args.warmup_steps,
-        logging_dir=model_save_path + "/logs",
-        logging_steps=50,
-        save_steps=args.checkpoint_save_steps,
-        save_total_limit=args.checkpoint_save_total_limit,
-        remove_unused_columns=False,
-        disable_tqdm=False,
-        save_only_model=True,
-        fp16=args.amp,
-        gradient_checkpointing=args.gradient_checkpointing,
-        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
-        deepspeed=args.deepspeed,
-        ddp_find_unused_parameters=args.find_unused_parameters,
-    )
+    train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
 
     data_collator = DefaultCollator(model)
 
@@ -400,7 +457,11 @@ def initialize_peft(
         loss_function=train_loss,
     )
 
-    if args.stop_after_n_steps is not None:
-        trainer.add_callback(StopTrainingCallback(args.stop_after_n_steps))
+    if custom_args.stop_after_n_steps is not None:
+        trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
 
     trainer.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm2vec/experiment_utils.py b/llm2vec/experiment_utils.py
index 42a022b..c2b97a7 100644
--- a/llm2vec/experiment_utils.py
+++ b/llm2vec/experiment_utils.py
@@ -1,16 +1,4 @@
-import gzip
-import random
 import re
-import shutil
-import urllib.request
-from pathlib import Path
-import numpy as np
-import torch
-import json
-import argparse
-
-from tqdm.auto import tqdm
-from tqdm.utils import CallbackIOWrapper
 
 
 def generate_experiment_id(
@@ -98,42 +86,3 @@ def parse_experiment_id(experiment_id):
     }
 
     return result
-
-def log_commandline_args(args, logger=print):
-    for arg in vars(args):
-        logger(f"  - {arg}: {getattr(args, arg)}")
-
-def set_seed(args):
-    seed = args.seed
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.distributed.get_world_size() > 0:
-        torch.cuda.manual_seed_all(seed)
-
-
-def str2bool(v):
-    if isinstance(v, bool):
-        return v
-    if v.lower() in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected.')
-
-def prepare_model_args(**kwargs):
-    args = {}
-    for k, v in kwargs.items():
-        if k == "flash_attention_2" and v:
-            args["attn_implementation"] = "flash_attention_2"
-        if k == "bf16" and v:
-            args["torch_dtype"] = torch.bfloat16
-        if k == "fp16" and v:
-            args["torch_dtype"] = torch.float16
-        if (k == "load_in_8bit" or k == "load_in_4bit") and v:
-            args[k] = v
-        if k == "bidirectional":
-            args[k] = v
-
-    return args
diff --git a/train_configs/supervised/Mistral.json b/train_configs/supervised/Mistral.json
new file mode 100644
index 0000000..e0859d1
--- /dev/null
+++ b/train_configs/supervised/Mistral.json
@@ -0,0 +1,29 @@
+{
+    "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
+    "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
+    "bidirectional": true,
+    "pooling_mode": "mean",
+    "dataset_name": "E5",
+    "dataset_file_path": "cache/echo-data",
+    "remove_unused_columns": false,
+    "learning_rate": 2e-4,
+    "num_train_epochs": 3,
+    "warmup_steps": 300,
+    "per_device_train_batch_size": 64,
+    "per_device_eval_batch_size": 64,
+    "gradient_accumulation_steps": 1,
+    "do_train": true,
+    "disable_tqdm": false,
+    "max_seq_length": 512,
+    "overwrite_output_dir": true,
+    "output_dir": "output/supervised/Mistral-7B-Instruct-v0.2",
+    "logging_steps": 50,
+    "save_steps": 200,
+    "save_only_model": true,
+    "stop_after_n_steps": 1000,
+    "lora_r": 16,
+    "gradient_checkpointing": true,
+    "torch_dtype": "bfloat16",
+    "attn_implementation": "flash_attention_2",
+    "seed": 42
+}
\ No newline at end of file