From 290e9d004b5b0c3639b423e8e30c0baff3ec3437 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Sun, 28 Apr 2024 00:44:28 +0000 Subject: [PATCH 1/9] first commit for supervised --- experiments/run_supervised.py | 305 ++++++++++++++++++++++++++++ llm2vec/dataset/E5Data.py | 169 +++++++++++++++ llm2vec/dataset/__init__.py | 1 + llm2vec/dataset/dataset.py | 46 +++++ llm2vec/dataset/utils.py | 25 +++ llm2vec/experiment_utils.py | 139 +++++++++++++ llm2vec/llm2vec.py | 3 + llm2vec/loss/HardNegativeNLLLoss.py | 47 +++++ llm2vec/loss/__init__.py | 1 + llm2vec/loss/utils.py | 84 ++++++++ 10 files changed, 820 insertions(+) create mode 100644 experiments/run_supervised.py create mode 100644 llm2vec/dataset/E5Data.py create mode 100644 llm2vec/dataset/__init__.py create mode 100644 llm2vec/dataset/dataset.py create mode 100644 llm2vec/dataset/utils.py create mode 100644 llm2vec/experiment_utils.py create mode 100644 llm2vec/loss/HardNegativeNLLLoss.py create mode 100644 llm2vec/loss/__init__.py create mode 100644 llm2vec/loss/utils.py diff --git a/experiments/run_supervised.py b/experiments/run_supervised.py new file mode 100644 index 0000000..1bddc1e --- /dev/null +++ b/experiments/run_supervised.py @@ -0,0 +1,305 @@ +import argparse +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import DataLoader, SequentialSampler + +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.logging import get_logger + +import transformers +from transformers import ( + TrainingArguments, + Trainer, + TrainerCallback, +) +from transformers.trainer_utils import seed_worker + +from llm2vec import LLM2Vec +from llm2vec.dataset.utils import load_dataset +from llm2vec.loss.utils import load_loss +from llm2vec.experiment_utils import ( + generate_experiment_id, + log_commandline_args, + set_seed, + str2bool, + prepare_model_args, +) +from tqdm import tqdm + +transformers.logging.set_verbosity_error() + + +@dataclass +class DefaultCollator: + model: LLM2Vec + + def __init__(self, model: LLM2Vec) -> None: + self.model = model + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + batch = features + num_texts = len(batch[0].texts) + texts = [[] for _ in range(num_texts)] + labels = [] + + for example in batch: + for idx, text in enumerate(example.texts): + text = self.model.prepare_for_tokenization(text) + texts[idx].append(text) + labels.append(example.label) + labels = torch.tensor(labels) + + sentence_features = [] + for idx in range(num_texts): + tokenized = self.model.tokenize(texts[idx]) + sentence_features.append(tokenized) + + return sentence_features, labels + + +class StopTrainingCallback(TrainerCallback): + def __init__(self, stop_after_n_steps: int): + self.stop_after_n_steps = stop_after_n_steps + + def on_step_end(self, args, state, control, **kwargs): + if state.global_step >= self.stop_after_n_steps: + control.should_training_stop = True + + +class LLM2VecSupervisedTrainer(Trainer): + + def compute_loss( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + features, labels = inputs + loss = model(features, labels) + if return_outputs: + output = torch.cat( + [model(row)["sentence_embedding"][:, None] for row in features], dim=1 + ) + return loss, output + return loss + + def get_train_dataloader(self) -> DataLoader: + # Copying 
most of the code from the parent class, changing the sampler to SequentialSampler + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + + data_collator = self._get_collator_with_removed_columns( + data_collator, description="training" + ) + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + # Changing from random sampler to sequential sampler + dataloader_params["sampler"] = SequentialSampler(train_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name", type=str, default="t5-base") +parser.add_argument("--peft_addr", type=str, default=None) +parser.add_argument("--dataset_name", type=str, default="E5") +parser.add_argument( + "--dataset_file_path", + type=str, + default="cache/echo-data", +) +parser.add_argument("--output_dir", type=str, default="output") +parser.add_argument("--train_batch_size", default=16, type=int) +parser.add_argument("--dev_batch_size", default=32, type=int) +parser.add_argument("--max_seq_length", default=384, type=int) +parser.add_argument("--epochs", default=40, type=int) +parser.add_argument("--warmup_ratio", default=0.1, type=float) +parser.add_argument("--warmup_steps", default=0, type=int) +parser.add_argument("--checkpoint_save_steps", default=10000, type=int) +parser.add_argument("--lr", default=2e-5, type=float) +parser.add_argument("--find_unused_parameters", default=False, type=str2bool) +parser.add_argument("--seed", default=42, type=int) +parser.add_argument("--pooling_mode", default="mean", type=str) +parser.add_argument("--checkpoint_save_total_limit", default=0, type=int) +parser.add_argument("--experiment_id", default=None, type=str) +parser.add_argument("--grad_accumulation_steps", default=1, type=int) +parser.add_argument("--lora_r", default=8, type=int) +parser.add_argument("--lora_alpha", default=16, type=int) +parser.add_argument("--lora_dropout", default=0.05, type=float) +parser.add_argument("--num_cpu_workers", default=4, type=int) +parser.add_argument("--bidirectional", default=False, type=str2bool) +parser.add_argument("--stop_after_n_steps", default=None, type=int) +parser.add_argument("--fp16", default=False, type=str2bool) +parser.add_argument("--bf16", default=False, type=str2bool) +parser.add_argument("--flash_attention_2", default=False, type=str2bool) +parser.add_argument("--load_in_8bit", default=False, type=str2bool) +parser.add_argument("--load_in_4bit", default=False, type=str2bool) +parser.add_argument("--amp", default=False, type=str2bool) +parser.add_argument("--deepspeed", default=None, type=str) +parser.add_argument("--gradient_checkpointing", default=False, type=str2bool) +parser.add_argument("--loss_cls", default="HardNegativeNLLLoss", type=str) +parser.add_argument("--loss_scale", default=50.0, type=float) + +args = parser.parse_args() + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + 
level=logging.INFO, + ) + logger = get_logger(__name__, log_level="INFO") + + if args.find_unused_parameters: + kwargs = [ + DistributedDataParallelKwargs( + dim=0, + broadcast_buffers=True, + bucket_cap_mb=25, + find_unused_parameters=True, + check_reduction=False, + gradient_as_bucket_view=False, + ) + ] + else: + kwargs = [] + accelerator = Accelerator(kwargs_handlers=kwargs) + + set_seed(args) + + # if accelerator.is_main_process: + log_commandline_args(args, logger.info) + + if args.deepspeed: + assert ( + False + ), "DeepSpeed is not implemented yet. There will be a problem with model saving and loading." + + if args.flash_attention_2: + assert args.fp16 or args.bf16 or args.load_in_8bit or args.load_in_4bit + + gradient_checkpointing_kwargs = None + if args.gradient_checkpointing: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + if args.experiment_id is not None: + experiment_id = args.experiment_id + else: + experiment_id = generate_experiment_id( + name=args.dataset_name, + split="train", + model_name=( + args.model_name + if "/" not in args.model_name + else args.model_name.split("/")[-1] + ), + pooling_mode=args.pooling_mode, + train_batch_size=args.train_batch_size + * accelerator.num_processes + * args.grad_accumulation_steps, + max_seq_length=args.max_seq_length, + bidirectional=args.bidirectional, + epochs=args.epochs, + seed=args.seed, + warmup_steps=args.warmup_steps, + lr=args.lr, + lora_r=args.lora_r, + ) + + model_save_path = f"{args.output_dir}/{experiment_id}" + + # TODO: can also pass separator arg here + train_dataset = load_dataset( + args.dataset_name, + split="train", + file_path=args.dataset_file_path, + effective_batch_size=args.train_batch_size * accelerator.num_processes, + shuffle_individual_datasets=args.shuffle_individual_datasets, + ) + + train_examples = [ + train_dataset[i] + for i in tqdm( + range(len(train_dataset)), + desc="Loading train examples...", + disable=not accelerator.is_main_process, + ) + ] + + # Load LLM2Vec Model, TODO: Enable bidirectional + + model_args = prepare_model_args( + bf16=args.bf16, + fp16=args.fp16, + flash_attention_2=args.flash_attention_2, + load_in_8bit=args.load_in_8bit, + load_in_4bit=args.load_in_4bit, + ) + + # TODO: Enable bidirectional, make trainable as an option + model = LLM2Vec.from_pretrained( + base_model_name_or_path=args.model_name, + peft_model_name_or_path=args.peft_addr, + pooling_mode=args.pooling_mode, + max_seq_length=args.max_seq_length, + **model_args, + ) + + tokenizer = model.tokenizer + + train_loss = load_loss(args.loss_cls, scale=args.loss_scale) + + training_args = TrainingArguments( + output_dir=model_save_path, + num_train_epochs=args.epochs, + seed=args.seed, + per_device_train_batch_size=args.train_batch_size, + gradient_accumulation_steps=args.grad_accumulation_steps, + learning_rate=args.lr, + warmup_ratio=args.warmup_ratio, + warmup_steps=args.warmup_steps, + logging_dir=model_save_path + "/logs", + logging_steps=50, + save_steps=args.checkpoint_save_steps, + save_total_limit=args.checkpoint_save_total_limit, + remove_unused_columns=False, + disable_tqdm=False, + save_only_model=True, + fp16=args.amp, + gradient_checkpointing=args.gradient_checkpointing, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + deepspeed=args.deepspeed, + ddp_find_unused_parameters=args.find_unused_parameters, + ) + + data_collator = DefaultCollator(model) + + trainer = LLM2VecSupervisedTrainer( + model=model, + args=training_args, + train_dataset=train_examples, + 
data_collator=data_collator, + tokenizer=tokenizer, + ) + + if args.stop_after_n_steps is not None: + trainer.add_callback(StopTrainingCallback(args.stop_after_n_steps)) + + trainer.train() diff --git a/llm2vec/dataset/E5Data.py b/llm2vec/dataset/E5Data.py new file mode 100644 index 0000000..69e3ca3 --- /dev/null +++ b/llm2vec/dataset/E5Data.py @@ -0,0 +1,169 @@ +import json +import random +import os + +from .dataset import DataSample, TrainSample, Dataset +from accelerate.logging import get_logger + +logger = get_logger(__name__, log_level="INFO") + +datasets_list = [ + "allnli_split1", + "allnli_split2", + "dureader", + "eli5_question_answer", + "fever", + "hotpot_qa", + "miracl", + "mrtydi", + "msmarco_passage", + "msmarco_document", + "nq", + "quora_duplicates_split1", + "quora_duplicates_split2", + "squad", + "t2ranking", + "trivia_qa", +] + +E5_EMBEDDING_PROMPTS = { + "allnli_split1": "Given a premise, retrieve a hypothesis that is entailed by the premise", + "allnli_split2": "Retrieve semantically similar text", + "dureader": "Given a Chinese search query, retrieve web passages that answer the question", + "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum", + "fever": "Given a claim, retrieve documents that support or refute the claim", + "hotpot_qa": "Given a multi-hop question, retrieve documents that can help answer the question", + "miracl": "Given a question, retrieve Wikipedia passages that answer the question", + "mrtydi": "Given a question, retrieve Wikipedia passages that answer the question", + "msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query", + "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query", + "nq": "Given a question, retrieve Wikipedia passages that answer the question", + "quora_duplicates_split1": "Given a question, retrieve questions that are semantically equivalent to the given question", + "quora_duplicates_split2": "Find questions that have the same meaning as the input question", + "squad": "Retrieve Wikipedia passages that answer the question", + "t2ranking": "Given a Chinese search query, retrieve web passages that answer the question", + "trivia_qa": "Retrieve Wikipedia passages that answer the question", +} + + +class E5Data(Dataset): + def __init__( + self, + dataset_name: str = "E5", + split: str = "validation", + file_path: str = "cache/echo-data", + effective_batch_size: int = 32, + shuffle_individual_datasets: bool = True, + separator: str = "!@#$%^&*()", + ): + self.dataset_name = dataset_name + self.split = split + self.effective_batch_size = effective_batch_size + self.shuffle_individual_datasets = shuffle_individual_datasets + self.separator = separator + + self.data = [] + self.load_data(file_path) + + def __len__(self): + return len(self.data) + + def load_data(self, file_path: str = None): + logger.info(f"Loading E5 data from {file_path}...") + # file path is actually a directory + + data_map = {} + all_samples = [] + id_ = 0 + for dataset in datasets_list: + logger.info(f"Loading dataset {dataset}...") + if dataset not in data_map: + data_map[dataset] = [] + with open(os.path.join(file_path, f"{dataset}.jsonl"), "r") as f: + dataset_samples = f.readlines() + + dataset_samples = [json.loads(d) for d in dataset_samples] + + for sample in dataset_samples: + query = ( + f"{E5_EMBEDDING_PROMPTS[dataset]}; " + + self.separator + + sample["query"] + ) + if dataset in [ + "allnli_split2", + 
"quora_duplicates_split1", + "quora_duplicates_split2", + ]: + pos = ( + f"{E5_EMBEDDING_PROMPTS[dataset]}; " + + self.separator + + sample["positive"] + ) + neg = ( + f"{E5_EMBEDDING_PROMPTS[dataset]}; " + + self.separator + + sample["negative"] + ) + else: + pos = self.separator + sample["positive"] + neg = self.separator + sample["negative"] + + data_map[dataset].append(id_) + + all_samples.append( + DataSample( + id_=id_, + query=query, + positive=pos, + negative=neg, + task_name=dataset, + ) + ) + id_ += 1 + + # combine split1 and split2 + new_data_map = {} + for dataset in data_map: + new_dataset = dataset.replace("_split1", "").replace("_split2", "") + if new_dataset not in new_data_map: + new_data_map[new_dataset] = [] + new_data_map[new_dataset] += data_map[dataset] + data_map = new_data_map + + if self.shuffle_individual_datasets: + for task, samples in data_map.items(): + random.shuffle(samples) + + datasets = list(data_map.keys()) + + logger.info( + f"Batching Echo data properly for effective batch size of {self.effective_batch_size}..." + ) + all_batches = [] + for dataset in datasets: + dataset_samples = data_map[dataset] + for i in range(0, len(dataset_samples), self.effective_batch_size): + batch = dataset_samples[i : i + self.effective_batch_size] + if len(batch) == self.effective_batch_size: + all_batches.append(batch) + else: + logger.info(f"Skip 1 batch for dataset {dataset}.") + random.shuffle(all_batches) + + final_idx_order = [] + for batch in all_batches: + for idx in batch: + final_idx_order.append(idx) + + self.data = [all_samples[idx] for idx in final_idx_order] + logger.info(f"Loaded {len(self.data)} samples.") + + def __getitem__(self, index): + sample = self.data[index] + if self.split == "train": + return TrainSample( + texts=[sample.query, sample.positive, sample.negative], label=1.0 + ) + elif self.split == "validation": + assert False, "E5Data does not have a validation split." diff --git a/llm2vec/dataset/__init__.py b/llm2vec/dataset/__init__.py new file mode 100644 index 0000000..6e1721d --- /dev/null +++ b/llm2vec/dataset/__init__.py @@ -0,0 +1 @@ +from .E5Data import E5Data \ No newline at end of file diff --git a/llm2vec/dataset/dataset.py b/llm2vec/dataset/dataset.py new file mode 100644 index 0000000..166d180 --- /dev/null +++ b/llm2vec/dataset/dataset.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +from typing import Union, List + +import torch + +@dataclass +class DataSample: + id_: int + query: str + positive: str + negative: str + task_name: str + +class TrainSample: + """ + Structure for one input example with texts, the label and a unique id + """ + def __init__(self, guid: str = '', texts: List[str] = None, label: Union[int, float] = 0): + """ + Creates one TrainSample with the given texts, guid and label + + + :param guid + id for the example + :param texts + the texts for the example. 
+ :param label + the label for the example + """ + self.guid = guid + self.texts = texts + self.label = label + + def __str__(self): + return " label: {}, texts: {}".format(str(self.label), "; ".join(self.texts)) + + +class Dataset(torch.utils.data.Dataset): + def load_data(self, file_path: str = None): + raise NotImplementedError() + + def __getitem__(self, index): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() diff --git a/llm2vec/dataset/utils.py b/llm2vec/dataset/utils.py new file mode 100644 index 0000000..4368177 --- /dev/null +++ b/llm2vec/dataset/utils.py @@ -0,0 +1,25 @@ +from ..dataset import E5Data + + +def load_dataset(dataset_name, split="validation", file_path=None, **kwargs): + """ + Loads a dataset by name. + + Args: + dataset_name (str): Name of the dataset to load. + split (str): Split of the dataset to load. + file_path (str): Path to the dataset file. + """ + dataset_mapping = { + "E5": E5Data, + } + + if dataset_name not in dataset_mapping: + raise NotImplementedError(f"Dataset name {dataset_name} not supported.") + + if split not in ["train", "validation", "test"]: + raise NotImplementedError(f"Split {split} not supported.") + + return dataset_mapping[dataset_name]( + split=split, file_path=file_path, **kwargs + ) diff --git a/llm2vec/experiment_utils.py b/llm2vec/experiment_utils.py new file mode 100644 index 0000000..42a022b --- /dev/null +++ b/llm2vec/experiment_utils.py @@ -0,0 +1,139 @@ +import gzip +import random +import re +import shutil +import urllib.request +from pathlib import Path +import numpy as np +import torch +import json +import argparse + +from tqdm.auto import tqdm +from tqdm.utils import CallbackIOWrapper + + +def generate_experiment_id( + name, + split, + model_name, + pooling_mode, + train_batch_size, + max_seq_length, + bidirectional, + epochs, + seed, + warmup_steps, + lr, + lora_r, +): + experiment_id = name + "_" + split + + if isinstance(model_name, str): + experiment_id += f"_m-{model_name}" + if isinstance(pooling_mode, str): + experiment_id += f"_p-{pooling_mode}" + if isinstance(train_batch_size, int): + experiment_id += f"_b-{train_batch_size}" + if isinstance(max_seq_length, int): + experiment_id += f"_l-{max_seq_length}" + if isinstance(bidirectional, bool): + experiment_id += f"_bidirectional-{bidirectional}" + if isinstance(epochs, int): + experiment_id += f"_e-{epochs}" + if isinstance(seed, int): + experiment_id += f"_s-{seed}" + if isinstance(warmup_steps, int): + experiment_id += f"_w-{warmup_steps}" + if isinstance(lr, float): + experiment_id += f"_lr-{lr}" + if isinstance(lora_r, int): + experiment_id += f"_lora_r-{lora_r}" + + return experiment_id + +def parse_experiment_id(experiment_id): + """ + Parses experiment identifier into key-value pairs. + + Args: + experiment_id (str): Unique experiment identifier to parse. + + Returns: + dict: Dictionary containing the parsed key-value pairs. 
+ """ + regex, post_regex = "", "" + if "/" in experiment_id: + regex = "([A-Za-z0-9-_./]*)/" + post_regex = "/([A-Za-z0-9-_./]*)" + regex += "([A-Za-z0-9-_.]+)" + regex += "_m-([A-Z-a-z0-9-_.]+)" + regex += "_p-([A-Z-a-z0-9-_.]+)" + regex += "_b-(\d+)" + regex += "_l-(\d+)" + regex += "_bidirectional-([A-Z-a-z0-9-_.]+)" + regex += "_e-(\d+)" + regex += "_s-(\d+)" + regex += "_w-(\d+)" + regex += "_lr-([A-Z-a-z0-9-_.]+)" + regex += "_lora_r-(\d+)" + regex += post_regex + + parts = re.match(regex, experiment_id).groups() + if post_regex != "": + parts = parts[1:-1] + + result = { + "name": parts[0], + "model_name_or_path": parts[1], + "pooling_mode": parts[2], + "train_batch_size": int(parts[3]), + "max_seq_length": int(parts[4]), + "bidirectional": parts[5] == "True", + "epochs": int(parts[6]), + "seed": int(parts[7]), + "warmup_steps": int(parts[8]), + "lr": float(parts[9]), + "lora_r": int(parts[10]), + } + + return result + +def log_commandline_args(args, logger=print): + for arg in vars(args): + logger(f" - {arg}: {getattr(args, arg)}") + +def set_seed(args): + seed = args.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.distributed.get_world_size() > 0: + torch.cuda.manual_seed_all(seed) + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def prepare_model_args(**kwargs): + args = {} + for k, v in kwargs.items(): + if k == "flash_attention_2" and v: + args["attn_implementation"] = "flash_attention_2" + if k == "bf16" and v: + args["torch_dtype"] = torch.bfloat16 + if k == "fp16" and v: + args["torch_dtype"] = torch.float16 + if (k == "load_in_8bit" or k == "load_in_4bit") and v: + args[k] = v + if k == "bidirectional": + args[k] = v + + return args diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py index 5fddb79..f9efcd7 100644 --- a/llm2vec/llm2vec.py +++ b/llm2vec/llm2vec.py @@ -403,3 +403,6 @@ def resize_token_embeddings( return self.model.resize_token_embeddings( new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of ) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git a/llm2vec/loss/HardNegativeNLLLoss.py b/llm2vec/loss/HardNegativeNLLLoss.py new file mode 100644 index 0000000..ddf9245 --- /dev/null +++ b/llm2vec/loss/HardNegativeNLLLoss.py @@ -0,0 +1,47 @@ +import torch +from torch import nn, Tensor +from .utils import cos_sim, mismatched_sizes_all_gather + +class HardNegativeNLLLoss(): + def __init__( + self, + scale: float = 20.0, + similarity_fct = cos_sim, + ): + self.scale = scale + self.similarity_fct = similarity_fct + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def __call__( + self, + q_reps: Tensor, + d_reps_pos: Tensor, + d_reps_neg: Tensor = None, + ): + if d_reps_neg is None: + d_reps_neg = d_reps_pos[:0, :] + + if torch.distributed.is_initialized(): + full_d_reps_pos = mismatched_sizes_all_gather(d_reps_pos) + full_d_reps_pos = torch.cat(full_d_reps_pos) + + full_q_reps = mismatched_sizes_all_gather(q_reps) + full_q_reps = torch.cat(full_q_reps) + + full_d_reps_neg = mismatched_sizes_all_gather(d_reps_neg) + full_d_reps_neg = torch.cat(full_d_reps_neg) + else: + full_d_reps_pos = d_reps_pos + full_q_reps = q_reps + full_d_reps_neg = d_reps_neg + + d_reps = 
torch.cat([full_d_reps_pos, full_d_reps_neg], dim=0) + scores = self.similarity_fct(full_q_reps, d_reps) * self.scale + labels = ( + torch.tensor( + range(len(scores)), dtype=torch.long, device=scores.device + ) + ) + + loss = self.cross_entropy_loss(scores, labels) + return loss diff --git a/llm2vec/loss/__init__.py b/llm2vec/loss/__init__.py new file mode 100644 index 0000000..6c342d6 --- /dev/null +++ b/llm2vec/loss/__init__.py @@ -0,0 +1 @@ +from .HardNegativeNLLLoss import HardNegativeNLLLoss \ No newline at end of file diff --git a/llm2vec/loss/utils.py b/llm2vec/loss/utils.py new file mode 100644 index 0000000..944e019 --- /dev/null +++ b/llm2vec/loss/utils.py @@ -0,0 +1,84 @@ +import torch +from torch import Tensor +from . import HardNegativeNLLLoss + +def load_loss(loss_class, *args, **kwargs): + if loss_class == "HardNegativeNLLLoss": + loss_cls = HardNegativeNLLLoss + else: + raise ValueError(f"Unknown loss class {loss_class}") + return loss_cls(*args, **kwargs) + +# from https://github.com/vlkit/vlkit/blob/master/vlkit/ops/distributed.py +class AllGather(torch.autograd.Function): + """ + all_gather with gradient back-propagation + """ + @staticmethod + def forward(ctx, tensor_list, tensor, group, async_op): + torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_list): + grad_list = list(grad_list) + rank = torch.distributed.get_rank() + + dist_ops = [ + torch.distributed.reduce(grad_list[i], i, async_op=True) for i in range(torch.distributed.get_world_size()) + ] + + for op in dist_ops: + op.wait() + + return None, grad_list[rank], None, None + +all_gather_with_grad = AllGather.apply + +def cos_sim(a: Tensor, b: Tensor): + """ + Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. + :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + a_norm = torch.nn.functional.normalize(a, p=2, dim=1) + b_norm = torch.nn.functional.normalize(b, p=2, dim=1) + return torch.mm(a_norm, b_norm.transpose(0, 1)) + +def mismatched_sizes_all_gather(tensor: Tensor, group=None, async_op=False, mismatched_axis=0): + # all_gather doesn't support tensor lists where the first dimension is mismatched. This does. + assert torch.distributed.is_initialized(), "torch.distributed not initialized" + world_size = torch.distributed.get_world_size() + # let's get the sizes for everyone + mismatched_sizes = torch.tensor([tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda") + sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)] + torch.distributed.all_gather(sizes, mismatched_sizes, group=group, async_op=async_op) + sizes = torch.cat(sizes).cpu().tolist() + # now pad to the max dim-0 size + max_size = max(sizes) + padded = torch.zeros((*tensor.shape[:mismatched_axis], max_size, *tensor.shape[mismatched_axis+1:]), + device=tensor.device, dtype=tensor.dtype) + # selects the place where we're adding information + padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis]) + padded_to_fill[...] 
= tensor + # gather the padded tensors + tensor_list = [torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) for _ in range(world_size)] + all_gather_with_grad(tensor_list, padded, group, async_op) + # trim off the padding + for rank in range(world_size): + # checks that the rest is 0 + assert not tensor_list[rank].narrow(mismatched_axis, sizes[rank], padded.shape[mismatched_axis]-sizes[rank]).count_nonzero().is_nonzero(), \ + "This would remove non-padding information" + tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) + return tensor_list \ No newline at end of file From 0cac15a509c7268bab991d95e9f7359cea29c77e Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Mon, 29 Apr 2024 15:19:36 +0000 Subject: [PATCH 2/9] rearranging loss module --- llm2vec/loss/HardNegativeNLLLoss.py | 2 +- llm2vec/loss/loss_utils.py | 75 +++++++++++++++++++++++++++ llm2vec/loss/utils.py | 78 +---------------------------- 3 files changed, 77 insertions(+), 78 deletions(-) create mode 100644 llm2vec/loss/loss_utils.py diff --git a/llm2vec/loss/HardNegativeNLLLoss.py b/llm2vec/loss/HardNegativeNLLLoss.py index ddf9245..c49b69d 100644 --- a/llm2vec/loss/HardNegativeNLLLoss.py +++ b/llm2vec/loss/HardNegativeNLLLoss.py @@ -1,6 +1,6 @@ import torch from torch import nn, Tensor -from .utils import cos_sim, mismatched_sizes_all_gather +from .loss_utils import cos_sim, mismatched_sizes_all_gather class HardNegativeNLLLoss(): def __init__( diff --git a/llm2vec/loss/loss_utils.py b/llm2vec/loss/loss_utils.py new file mode 100644 index 0000000..5d56fb5 --- /dev/null +++ b/llm2vec/loss/loss_utils.py @@ -0,0 +1,75 @@ +import torch +from torch import Tensor + +class AllGather(torch.autograd.Function): + """ + all_gather with gradient back-propagation + """ + @staticmethod + def forward(ctx, tensor_list, tensor, group, async_op): + torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_list): + grad_list = list(grad_list) + rank = torch.distributed.get_rank() + + dist_ops = [ + torch.distributed.reduce(grad_list[i], i, async_op=True) for i in range(torch.distributed.get_world_size()) + ] + + for op in dist_ops: + op.wait() + + return None, grad_list[rank], None, None + +all_gather_with_grad = AllGather.apply + +def cos_sim(a: Tensor, b: Tensor): + """ + Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. + :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + a_norm = torch.nn.functional.normalize(a, p=2, dim=1) + b_norm = torch.nn.functional.normalize(b, p=2, dim=1) + return torch.mm(a_norm, b_norm.transpose(0, 1)) + +def mismatched_sizes_all_gather(tensor: Tensor, group=None, async_op=False, mismatched_axis=0): + # all_gather doesn't support tensor lists where the first dimension is mismatched. This does. 
+ assert torch.distributed.is_initialized(), "torch.distributed not initialized" + world_size = torch.distributed.get_world_size() + # let's get the sizes for everyone + mismatched_sizes = torch.tensor([tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda") + sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)] + torch.distributed.all_gather(sizes, mismatched_sizes, group=group, async_op=async_op) + sizes = torch.cat(sizes).cpu().tolist() + # now pad to the max dim-0 size + max_size = max(sizes) + padded = torch.zeros((*tensor.shape[:mismatched_axis], max_size, *tensor.shape[mismatched_axis+1:]), + device=tensor.device, dtype=tensor.dtype) + # selects the place where we're adding information + padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis]) + padded_to_fill[...] = tensor + # gather the padded tensors + tensor_list = [torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) for _ in range(world_size)] + all_gather_with_grad(tensor_list, padded, group, async_op) + # trim off the padding + for rank in range(world_size): + # checks that the rest is 0 + assert not tensor_list[rank].narrow(mismatched_axis, sizes[rank], padded.shape[mismatched_axis]-sizes[rank]).count_nonzero().is_nonzero(), \ + "This would remove non-padding information" + tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) + return tensor_list diff --git a/llm2vec/loss/utils.py b/llm2vec/loss/utils.py index 944e019..8855b27 100644 --- a/llm2vec/loss/utils.py +++ b/llm2vec/loss/utils.py @@ -1,6 +1,4 @@ -import torch -from torch import Tensor -from . import HardNegativeNLLLoss +from .HardNegativeNLLLoss import HardNegativeNLLLoss def load_loss(loss_class, *args, **kwargs): if loss_class == "HardNegativeNLLLoss": @@ -8,77 +6,3 @@ def load_loss(loss_class, *args, **kwargs): else: raise ValueError(f"Unknown loss class {loss_class}") return loss_cls(*args, **kwargs) - -# from https://github.com/vlkit/vlkit/blob/master/vlkit/ops/distributed.py -class AllGather(torch.autograd.Function): - """ - all_gather with gradient back-propagation - """ - @staticmethod - def forward(ctx, tensor_list, tensor, group, async_op): - torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) - return tuple(tensor_list) - - @staticmethod - def backward(ctx, *grad_list): - grad_list = list(grad_list) - rank = torch.distributed.get_rank() - - dist_ops = [ - torch.distributed.reduce(grad_list[i], i, async_op=True) for i in range(torch.distributed.get_world_size()) - ] - - for op in dist_ops: - op.wait() - - return None, grad_list[rank], None, None - -all_gather_with_grad = AllGather.apply - -def cos_sim(a: Tensor, b: Tensor): - """ - Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. - :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) - """ - if not isinstance(a, torch.Tensor): - a = torch.tensor(a) - - if not isinstance(b, torch.Tensor): - b = torch.tensor(b) - - if len(a.shape) == 1: - a = a.unsqueeze(0) - - if len(b.shape) == 1: - b = b.unsqueeze(0) - - a_norm = torch.nn.functional.normalize(a, p=2, dim=1) - b_norm = torch.nn.functional.normalize(b, p=2, dim=1) - return torch.mm(a_norm, b_norm.transpose(0, 1)) - -def mismatched_sizes_all_gather(tensor: Tensor, group=None, async_op=False, mismatched_axis=0): - # all_gather doesn't support tensor lists where the first dimension is mismatched. This does. 
- assert torch.distributed.is_initialized(), "torch.distributed not initialized" - world_size = torch.distributed.get_world_size() - # let's get the sizes for everyone - mismatched_sizes = torch.tensor([tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda") - sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)] - torch.distributed.all_gather(sizes, mismatched_sizes, group=group, async_op=async_op) - sizes = torch.cat(sizes).cpu().tolist() - # now pad to the max dim-0 size - max_size = max(sizes) - padded = torch.zeros((*tensor.shape[:mismatched_axis], max_size, *tensor.shape[mismatched_axis+1:]), - device=tensor.device, dtype=tensor.dtype) - # selects the place where we're adding information - padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis]) - padded_to_fill[...] = tensor - # gather the padded tensors - tensor_list = [torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) for _ in range(world_size)] - all_gather_with_grad(tensor_list, padded, group, async_op) - # trim off the padding - for rank in range(world_size): - # checks that the rest is 0 - assert not tensor_list[rank].narrow(mismatched_axis, sizes[rank], padded.shape[mismatched_axis]-sizes[rank]).count_nonzero().is_nonzero(), \ - "This would remove non-padding information" - tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) - return tensor_list \ No newline at end of file From 021acf2c3f658442cd23cb3ff85e8c55d9642536 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Mon, 29 Apr 2024 15:20:02 +0000 Subject: [PATCH 3/9] merge peft option --- llm2vec/llm2vec.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py index f9efcd7..bec18b4 100644 --- a/llm2vec/llm2vec.py +++ b/llm2vec/llm2vec.py @@ -72,6 +72,7 @@ def from_pretrained( cls, base_model_name_or_path, peft_model_name_or_path=None, + merge_peft=False, enable_bidirectional=True, **kwargs, ): @@ -106,6 +107,8 @@ def from_pretrained( model, peft_model_name_or_path, ) + if merge_peft: + model = model.merge_and_unload() config = {} config_addr = ( From 4f7041273238e5573316738f74fc9df88d8de335 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Mon, 29 Apr 2024 15:20:34 +0000 Subject: [PATCH 4/9] add lora --- experiments/run_mntp.py | 13 ++--- experiments/run_supervised.py | 92 ++++++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/experiments/run_mntp.py b/experiments/run_mntp.py index 6e02d14..3be42a4 100644 --- a/experiments/run_mntp.py +++ b/experiments/run_mntp.py @@ -108,11 +108,10 @@ def initialize_peft( bias="none", task_type=None, ) - # model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPELBiModel, we have to apply PEFT to the inner model - peft_model = get_peft_model(model.get_model_for_peft(), config) + + model = get_peft_model(model, config) print(f"Model's Lora trainable parameters:") - peft_model.print_trainable_parameters() - model.set_model_for_peft(peft_model) + model.print_trainable_parameters() return model @@ -696,8 +695,10 @@ def main(): low_cpu_mem_usage=model_args.low_cpu_mem_usage, attn_implementation=model_args.attn_implementation, ) - model = initialize_peft( - model, + + # model organization is MODEL_TYPEBiForMNTP.model -> MODEL_TYPELBiModel, we have to apply PEFT to the inner model + model.model = initialize_peft( + model.model, lora_r=custom_args.lora_r, lora_alpha=2 * custom_args.lora_r, lora_dropout=custom_args.lora_dropout, diff --git a/experiments/run_supervised.py 
b/experiments/run_supervised.py index 1bddc1e..a6e27f6 100644 --- a/experiments/run_supervised.py +++ b/experiments/run_supervised.py @@ -1,7 +1,8 @@ import argparse import logging from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +import os +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -18,6 +19,8 @@ ) from transformers.trainer_utils import seed_worker +from peft import LoraConfig, get_peft_model + from llm2vec import LLM2Vec from llm2vec.dataset.utils import load_dataset from llm2vec.loss.utils import load_loss @@ -72,6 +75,15 @@ def on_step_end(self, args, state, control, **kwargs): class LLM2VecSupervisedTrainer(Trainer): + def __init__( + self, + *args, + loss_function=None, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.loss_function = loss_function + def compute_loss( self, model: nn.Module, @@ -79,12 +91,21 @@ def compute_loss( return_outputs: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: features, labels = inputs - loss = model(features, labels) + q_reps = self.model(features[0]) + d_reps = self.model(features[1]) + + d_reps_neg = None + if len(features) > 2: + d_reps_neg = self.model(features[2]) + + loss = self.loss_function(q_reps, d_reps, d_reps_neg) + if return_outputs: output = torch.cat( [model(row)["sentence_embedding"][:, None] for row in features], dim=1 ) return loss, output + return loss def get_train_dataloader(self) -> DataLoader: @@ -115,7 +136,56 @@ def get_train_dataloader(self) -> DataLoader: return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. 
+ output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + self.model.save(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + + +def initialize_peft( + model, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_modules: Optional[List[str]] = None, +): + if lora_modules is None and model.config.__class__.__name__ in [ + "LlamaConfig", + "MistralConfig", + ]: + lora_modules = [ + "q_proj", + "v_proj", + "k_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + elif lora_modules is None: + raise ValueError("lora_modules must be specified for this model.") + + config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_modules, + lora_dropout=lora_dropout, + bias="none", + task_type=None, + ) + + model = get_peft_model(model, config) + print(f"Model's Lora trainable parameters:") + model.print_trainable_parameters() + return model +# TODO: Parse these into JSON, organize same way as MNTP parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, default="t5-base") parser.add_argument("--peft_addr", type=str, default=None) @@ -141,7 +211,6 @@ def get_train_dataloader(self) -> DataLoader: parser.add_argument("--experiment_id", default=None, type=str) parser.add_argument("--grad_accumulation_steps", default=1, type=int) parser.add_argument("--lora_r", default=8, type=int) -parser.add_argument("--lora_alpha", default=16, type=int) parser.add_argument("--lora_dropout", default=0.05, type=float) parser.add_argument("--num_cpu_workers", default=4, type=int) parser.add_argument("--bidirectional", default=False, type=str2bool) @@ -231,7 +300,6 @@ def get_train_dataloader(self) -> DataLoader: split="train", file_path=args.dataset_file_path, effective_batch_size=args.train_batch_size * accelerator.num_processes, - shuffle_individual_datasets=args.shuffle_individual_datasets, ) train_examples = [ @@ -243,8 +311,6 @@ def get_train_dataloader(self) -> DataLoader: ) ] - # Load LLM2Vec Model, TODO: Enable bidirectional - model_args = prepare_model_args( bf16=args.bf16, fp16=args.fp16, @@ -253,15 +319,24 @@ def get_train_dataloader(self) -> DataLoader: load_in_4bit=args.load_in_4bit, ) - # TODO: Enable bidirectional, make trainable as an option model = LLM2Vec.from_pretrained( base_model_name_or_path=args.model_name, + enable_bidirectional=args.bidirectional, peft_model_name_or_path=args.peft_addr, + merge_peft=True, pooling_mode=args.pooling_mode, - max_seq_length=args.max_seq_length, + max_length=args.max_seq_length, **model_args, ) + # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model + model.model = initialize_peft( + model.model, + lora_r=args.lora_r, + lora_alpha=2 * args.lora_r, + lora_dropout=args.lora_dropout, + ) + tokenizer = model.tokenizer train_loss = load_loss(args.loss_cls, scale=args.loss_scale) @@ -297,6 +372,7 @@ def get_train_dataloader(self) -> DataLoader: train_dataset=train_examples, data_collator=data_collator, tokenizer=tokenizer, + loss_function=train_loss, ) if args.stop_after_n_steps is not None: From 44df1083829b047add620441cbc1b89e276e1688 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Mon, 29 Apr 2024 15:32:57 +0000 Subject: [PATCH 5/9] prepare for tokenization --- experiments/run_supervised.py | 27 
++++++++++++++++++++++++++- llm2vec/llm2vec.py | 34 +++++++++++++++++++++------------- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/experiments/run_supervised.py b/experiments/run_supervised.py index a6e27f6..72d4ee1 100644 --- a/experiments/run_supervised.py +++ b/experiments/run_supervised.py @@ -16,6 +16,8 @@ TrainingArguments, Trainer, TrainerCallback, + LlamaConfig, + MistralConfig, ) from transformers.trainer_utils import seed_worker @@ -51,7 +53,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: for example in batch: for idx, text in enumerate(example.texts): - text = self.model.prepare_for_tokenization(text) + text = prepare_for_tokenization(model.model, text, pooling_mode=model.pooling_mode) texts[idx].append(text) labels.append(example.label) labels = torch.tensor(labels) @@ -148,6 +150,29 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(self.args, os.path.join(output_dir, "training_args.bin")) +def prepare_for_tokenization(model, text, pooling_mode="mean"): + if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": + text = ( + "<|start_header_id|>user<|end_header_id|>\n\n" + + text.strip() + + "<|eot_id|>" + ) + return text + if model.config._name_or_path in [ + "mistralai/Mistral-7B-Instruct-v0.2", + "meta-llama/Llama-2-7b-chat-hf", + ]: + text = "[INST] " + text.strip() + " [/INST]" + if pooling_mode == "eos_token": + if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(model.config, LlamaConfig) or isinstance( + model.config, MistralConfig + ): + text = text.strip() + " " + + return text + def initialize_peft( model, lora_r: int = 8, diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py index bec18b4..ab5a791 100644 --- a/llm2vec/llm2vec.py +++ b/llm2vec/llm2vec.py @@ -127,20 +127,26 @@ def from_pretrained( return cls(model=model, tokenizer=tokenizer, **config) def prepare_for_tokenization(self, text): - def _is_instruct(name): - return ( - ("chat" in name.lower()) - or ("instruct" in name.lower()) - or ("sharegpt" in name.lower()) + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": + text = ( + "<|start_header_id|>user<|end_header_id|>\n\n" + + text.strip() + + "<|eot_id|>" ) - - if _is_instruct(self.model.config._name_or_path): + return text + if self.model.config._name_or_path in [ + "mistralai/Mistral-7B-Instruct-v0.2", + "meta-llama/Llama-2-7b-chat-hf", + ]: text = "[INST] " + text.strip() + " [/INST]" - if ( - isinstance(self.model.config, LlamaConfig) - or isinstance(self.model.config, MistralConfig) - ) and self.pooling_mode == "eos_token": - text = text.strip() + " " + if self.pooling_mode == "eos_token": + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(self.model.config, LlamaConfig) or isinstance( + self.model.config, MistralConfig + ): + text = text.strip() + " " + return text def tokenize(self, texts): @@ -408,4 +414,6 @@ def resize_token_embeddings( ) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs + ) From 51a3d8e261316f5a5801a52009d81dc814e6aaf6 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Mon, 29 Apr 2024 16:11:33 +0000 Subject: [PATCH 6/9] assigning 
config to llm2vec wrapper --- .gitignore | 3 ++- experiments/run_supervised.py | 2 +- llm2vec/llm2vec.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index f4911f8..d2b3239 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist/ *.egg-info **/__pycache__ wandb/** -output/** \ No newline at end of file +output/** +cache/** \ No newline at end of file diff --git a/experiments/run_supervised.py b/experiments/run_supervised.py index 72d4ee1..3eb6675 100644 --- a/experiments/run_supervised.py +++ b/experiments/run_supervised.py @@ -53,7 +53,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: for example in batch: for idx, text in enumerate(example.texts): - text = prepare_for_tokenization(model.model, text, pooling_mode=model.pooling_mode) + text = prepare_for_tokenization(model, text, pooling_mode=model.pooling_mode) texts[idx].append(text) labels.append(example.label) labels = torch.tensor(labels) diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py index ab5a791..79fffd6 100644 --- a/llm2vec/llm2vec.py +++ b/llm2vec/llm2vec.py @@ -53,6 +53,7 @@ def __init__( self.skip_instruction = skip_instruction self.max_length = max_length self.doc_max_length = doc_max_length + self.config = model.config @classmethod def _get_model_class(cls, config_class_name, enable_bidirectional): From 4d5771a2970d8cc74cf1ea59ce9e77c318748bd8 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Tue, 30 Apr 2024 03:22:39 +0000 Subject: [PATCH 7/9] change argeparse, add training config --- experiments/run_supervised.py | 457 +++++++++++++++----------- llm2vec/experiment_utils.py | 51 --- train_configs/supervised/Mistral.json | 29 ++ 3 files changed, 288 insertions(+), 249 deletions(-) create mode 100644 train_configs/supervised/Mistral.json diff --git a/experiments/run_supervised.py b/experiments/run_supervised.py index 3eb6675..2211622 100644 --- a/experiments/run_supervised.py +++ b/experiments/run_supervised.py @@ -1,7 +1,7 @@ -import argparse import logging -from dataclasses import dataclass +from dataclasses import dataclass, field import os +import sys from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -13,11 +13,14 @@ import transformers from transformers import ( + MODEL_FOR_MASKED_LM_MAPPING, + HfArgumentParser, TrainingArguments, Trainer, TrainerCallback, LlamaConfig, MistralConfig, + set_seed, ) from transformers.trainer_utils import seed_worker @@ -26,17 +29,199 @@ from llm2vec import LLM2Vec from llm2vec.dataset.utils import load_dataset from llm2vec.loss.utils import load_loss -from llm2vec.experiment_utils import ( - generate_experiment_id, - log_commandline_args, - set_seed, - str2bool, - prepare_model_args, -) +from llm2vec.experiment_utils import generate_experiment_id + from tqdm import tqdm transformers.logging.set_verbosity_error() +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = get_logger(__name__, log_level="INFO") +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def prepare_for_tokenization(model, text, pooling_mode="mean"): + if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": + text = ( + "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>" + ) + return text + if model.config._name_or_path in [ + "mistralai/Mistral-7B-Instruct-v0.2", + 
"meta-llama/Llama-2-7b-chat-hf", + ]: + text = "[INST] " + text.strip() + " [/INST]" + if pooling_mode == "eos_token": + if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(model.config, LlamaConfig) or isinstance( + model.config, MistralConfig + ): + text = text.strip() + " " + + return text + + +def initialize_peft( + model, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_modules: Optional[List[str]] = None, +): + if lora_modules is None and model.config.__class__.__name__ in [ + "LlamaConfig", + "MistralConfig", + ]: + lora_modules = [ + "q_proj", + "v_proj", + "k_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + elif lora_modules is None: + raise ValueError("lora_modules must be specified for this model.") + + config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_modules, + lora_dropout=lora_dropout, + bias="none", + task_type=None, + ) + + model = get_peft_model(model, config) + print(f"Model's Lora trainable parameters:") + model.print_trainable_parameters() + return model + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch." + ) + }, + ) + peft_model_name_or_path: Optional[str] = field( + default=None, + metadata={"help": ("The PEFT model checkpoint to add on top of base model.")}, + ) + bidirectional: Optional[bool] = field( + default=False, + metadata={ + "help": ( + "Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention." + ) + }, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + ) + }, + ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + attn_implementation: Optional[str] = field( + default="sdpa", + metadata={ + "help": ("The attention implementation to use in the model."), + "choices": ["eager", "sdpa", "flash_attention_2"], + }, + ) + pooling_mode: Optional[str] = field( + default="mean", + metadata={ + "help": ("The pooling mode to use in the model."), + "choices": ["mean", "weighted_mean", "eos_token"], + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use. Options: E5"}, + ) + dataset_file_path: Optional[str] = field( + default=None, metadata={"help": "The input training data file or folder."} + ) + # TODO: implement this + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ) + }, + ) + + +@dataclass +class CustomArguments: + """ + Custom arguments for the script + """ + + lora_dropout: float = field( + default=0.05, metadata={"help": "The dropout rate for lora"} + ) + + lora_r: int = field(default=8, metadata={"help": "The r value for lora"}) + + stop_after_n_steps: int = field( + default=10000, metadata={"help": "Stop training after n steps"} + ) + + experiment_id: Optional[str] = field( + default=None, metadata={"help": "The experiment id"} + ) + + loss_class: Optional[str] = field( + default="HardNegativeNLLLoss", + metadata={ + "help": "The loss class to use for training. Options: HardNegativeNLLLoss" + }, + ) + + loss_scale: float = field( + default=50.0, metadata={"help": "The loss scale for the loss function"} + ) + @dataclass class DefaultCollator: @@ -53,7 +238,9 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: for example in batch: for idx, text in enumerate(example.texts): - text = prepare_for_tokenization(model, text, pooling_mode=model.pooling_mode) + text = prepare_for_tokenization( + self.model, text, pooling_mode=self.model.pooling_mode + ) texts[idx].append(text) labels.append(example.label) labels = torch.tensor(labels) @@ -150,118 +337,24 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(self.args, os.path.join(output_dir, "training_args.bin")) -def prepare_for_tokenization(model, text, pooling_mode="mean"): - if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": - text = ( - "<|start_header_id|>user<|end_header_id|>\n\n" - + text.strip() - + "<|eot_id|>" - ) - return text - if model.config._name_or_path in [ - "mistralai/Mistral-7B-Instruct-v0.2", - "meta-llama/Llama-2-7b-chat-hf", - ]: - text = "[INST] " + text.strip() + " [/INST]" - if pooling_mode == "eos_token": - if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": - text = text.strip() + "<|end_of_text|>" - elif isinstance(model.config, LlamaConfig) or isinstance( - model.config, MistralConfig - ): - text = text.strip() + " " - - return text - -def initialize_peft( - model, - lora_r: int = 8, - lora_alpha: int = 16, - lora_dropout: float = 0.05, - lora_modules: Optional[List[str]] = None, -): - if lora_modules is None and model.config.__class__.__name__ in [ - "LlamaConfig", - "MistralConfig", - ]: - lora_modules = [ - "q_proj", - "v_proj", - "k_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ] - elif lora_modules is None: - raise ValueError("lora_modules must be specified for this model.") - - config = LoraConfig( - r=lora_r, - lora_alpha=lora_alpha, - target_modules=lora_modules, - lora_dropout=lora_dropout, - bias="none", - task_type=None, +def main(): + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments) ) - - model = get_peft_model(model, config) - print(f"Model's Lora trainable parameters:") - model.print_trainable_parameters() - return model - -# TODO: Parse these into JSON, organize same way as MNTP -parser = argparse.ArgumentParser() -parser.add_argument("--model_name", type=str, default="t5-base") -parser.add_argument("--peft_addr", type=str, default=None) -parser.add_argument("--dataset_name", type=str, default="E5") -parser.add_argument( - "--dataset_file_path", - type=str, - default="cache/echo-data", -) -parser.add_argument("--output_dir", type=str, default="output") -parser.add_argument("--train_batch_size", default=16, type=int) -parser.add_argument("--dev_batch_size", default=32, type=int) 
-parser.add_argument("--max_seq_length", default=384, type=int) -parser.add_argument("--epochs", default=40, type=int) -parser.add_argument("--warmup_ratio", default=0.1, type=float) -parser.add_argument("--warmup_steps", default=0, type=int) -parser.add_argument("--checkpoint_save_steps", default=10000, type=int) -parser.add_argument("--lr", default=2e-5, type=float) -parser.add_argument("--find_unused_parameters", default=False, type=str2bool) -parser.add_argument("--seed", default=42, type=int) -parser.add_argument("--pooling_mode", default="mean", type=str) -parser.add_argument("--checkpoint_save_total_limit", default=0, type=int) -parser.add_argument("--experiment_id", default=None, type=str) -parser.add_argument("--grad_accumulation_steps", default=1, type=int) -parser.add_argument("--lora_r", default=8, type=int) -parser.add_argument("--lora_dropout", default=0.05, type=float) -parser.add_argument("--num_cpu_workers", default=4, type=int) -parser.add_argument("--bidirectional", default=False, type=str2bool) -parser.add_argument("--stop_after_n_steps", default=None, type=int) -parser.add_argument("--fp16", default=False, type=str2bool) -parser.add_argument("--bf16", default=False, type=str2bool) -parser.add_argument("--flash_attention_2", default=False, type=str2bool) -parser.add_argument("--load_in_8bit", default=False, type=str2bool) -parser.add_argument("--load_in_4bit", default=False, type=str2bool) -parser.add_argument("--amp", default=False, type=str2bool) -parser.add_argument("--deepspeed", default=None, type=str) -parser.add_argument("--gradient_checkpointing", default=False, type=str2bool) -parser.add_argument("--loss_cls", default="HardNegativeNLLLoss", type=str) -parser.add_argument("--loss_scale", default=50.0, type=float) - -args = parser.parse_args() - -if __name__ == "__main__": - logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, - ) - logger = get_logger(__name__, log_level="INFO") - - if args.find_unused_parameters: + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args, custom_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + ( + model_args, + data_args, + training_args, + custom_args, + ) = parser.parse_args_into_dataclasses() + if training_args.ddp_find_unused_parameters: kwargs = [ DistributedDataParallelKwargs( dim=0, @@ -276,55 +369,44 @@ def initialize_peft( kwargs = [] accelerator = Accelerator(kwargs_handlers=kwargs) - set_seed(args) - - # if accelerator.is_main_process: - log_commandline_args(args, logger.info) - - if args.deepspeed: - assert ( - False - ), "DeepSpeed is not implemented yet. There will be a problem with model saving and loading." 
- - if args.flash_attention_2: - assert args.fp16 or args.bf16 or args.load_in_8bit or args.load_in_4bit + set_seed(training_args.seed) - gradient_checkpointing_kwargs = None - if args.gradient_checkpointing: - gradient_checkpointing_kwargs = {"use_reentrant": False} + if training_args.gradient_checkpointing: + training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} - if args.experiment_id is not None: - experiment_id = args.experiment_id + if custom_args.experiment_id is not None: + experiment_id = custom_args.experiment_id else: experiment_id = generate_experiment_id( - name=args.dataset_name, + name=data_args.dataset_name, split="train", model_name=( - args.model_name - if "/" not in args.model_name - else args.model_name.split("/")[-1] + model_args.model_name_or_path + if "/" not in model_args.model_name_or_path + else model_args.model_name_or_path.split("/")[-1] ), - pooling_mode=args.pooling_mode, - train_batch_size=args.train_batch_size + pooling_mode=model_args.pooling_mode, + train_batch_size=training_args.per_device_train_batch_size * accelerator.num_processes - * args.grad_accumulation_steps, - max_seq_length=args.max_seq_length, - bidirectional=args.bidirectional, - epochs=args.epochs, - seed=args.seed, - warmup_steps=args.warmup_steps, - lr=args.lr, - lora_r=args.lora_r, + * training_args.gradient_accumulation_steps, + max_seq_length=model_args.max_seq_length, + bidirectional=model_args.bidirectional, + epochs=training_args.num_train_epochs, + seed=training_args.seed, + warmup_steps=training_args.warmup_steps, + lr=training_args.learning_rate, + lora_r=custom_args.lora_r, ) - model_save_path = f"{args.output_dir}/{experiment_id}" + training_args.output_dir = f"{training_args.output_dir}/{experiment_id}" # TODO: can also pass separator arg here train_dataset = load_dataset( - args.dataset_name, + data_args.dataset_name, split="train", - file_path=args.dataset_file_path, - effective_batch_size=args.train_batch_size * accelerator.num_processes, + file_path=data_args.dataset_file_path, + effective_batch_size=training_args.per_device_train_batch_size + * accelerator.num_processes, ) train_examples = [ @@ -336,58 +418,33 @@ def initialize_peft( ) ] - model_args = prepare_model_args( - bf16=args.bf16, - fp16=args.fp16, - flash_attention_2=args.flash_attention_2, - load_in_8bit=args.load_in_8bit, - load_in_4bit=args.load_in_4bit, + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) ) - model = LLM2Vec.from_pretrained( - base_model_name_or_path=args.model_name, - enable_bidirectional=args.bidirectional, - peft_model_name_or_path=args.peft_addr, + base_model_name_or_path=model_args.model_name_or_path, + enable_bidirectional=model_args.bidirectional, + peft_model_name_or_path=model_args.peft_model_name_or_path, merge_peft=True, - pooling_mode=args.pooling_mode, - max_length=args.max_seq_length, - **model_args, + pooling_mode=model_args.pooling_mode, + max_length=model_args.max_seq_length, + torch_dtype=torch_dtype, + attn_implementation=model_args.attn_implementation, ) # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model model.model = initialize_peft( model.model, - lora_r=args.lora_r, - lora_alpha=2 * args.lora_r, - lora_dropout=args.lora_dropout, + lora_r=custom_args.lora_r, + lora_alpha=2 * custom_args.lora_r, + lora_dropout=custom_args.lora_dropout, ) tokenizer = model.tokenizer - train_loss = load_loss(args.loss_cls, scale=args.loss_scale) - - 
training_args = TrainingArguments( - output_dir=model_save_path, - num_train_epochs=args.epochs, - seed=args.seed, - per_device_train_batch_size=args.train_batch_size, - gradient_accumulation_steps=args.grad_accumulation_steps, - learning_rate=args.lr, - warmup_ratio=args.warmup_ratio, - warmup_steps=args.warmup_steps, - logging_dir=model_save_path + "/logs", - logging_steps=50, - save_steps=args.checkpoint_save_steps, - save_total_limit=args.checkpoint_save_total_limit, - remove_unused_columns=False, - disable_tqdm=False, - save_only_model=True, - fp16=args.amp, - gradient_checkpointing=args.gradient_checkpointing, - gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, - deepspeed=args.deepspeed, - ddp_find_unused_parameters=args.find_unused_parameters, - ) + train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale) data_collator = DefaultCollator(model) @@ -400,7 +457,11 @@ def initialize_peft( loss_function=train_loss, ) - if args.stop_after_n_steps is not None: - trainer.add_callback(StopTrainingCallback(args.stop_after_n_steps)) + if custom_args.stop_after_n_steps is not None: + trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps)) trainer.train() + + +if __name__ == "__main__": + main() diff --git a/llm2vec/experiment_utils.py b/llm2vec/experiment_utils.py index 42a022b..c2b97a7 100644 --- a/llm2vec/experiment_utils.py +++ b/llm2vec/experiment_utils.py @@ -1,16 +1,4 @@ -import gzip -import random import re -import shutil -import urllib.request -from pathlib import Path -import numpy as np -import torch -import json -import argparse - -from tqdm.auto import tqdm -from tqdm.utils import CallbackIOWrapper def generate_experiment_id( @@ -98,42 +86,3 @@ def parse_experiment_id(experiment_id): } return result - -def log_commandline_args(args, logger=print): - for arg in vars(args): - logger(f" - {arg}: {getattr(args, arg)}") - -def set_seed(args): - seed = args.seed - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.distributed.get_world_size() > 0: - torch.cuda.manual_seed_all(seed) - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - -def prepare_model_args(**kwargs): - args = {} - for k, v in kwargs.items(): - if k == "flash_attention_2" and v: - args["attn_implementation"] = "flash_attention_2" - if k == "bf16" and v: - args["torch_dtype"] = torch.bfloat16 - if k == "fp16" and v: - args["torch_dtype"] = torch.float16 - if (k == "load_in_8bit" or k == "load_in_4bit") and v: - args[k] = v - if k == "bidirectional": - args[k] = v - - return args diff --git a/train_configs/supervised/Mistral.json b/train_configs/supervised/Mistral.json new file mode 100644 index 0000000..e0859d1 --- /dev/null +++ b/train_configs/supervised/Mistral.json @@ -0,0 +1,29 @@ +{ + "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2", + "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", + "bidirectional": true, + "pooling_mode": "mean", + "dataset_name": "E5", + "dataset_file_path": "cache/echo-data", + "remove_unused_columns": false, + "learning_rate": 2e-4, + "num_train_epochs": 3, + "warmup_steps": 300, + "per_device_train_batch_size": 64, + "per_device_eval_batch_size": 64, + "gradient_accumulation_steps": 1, + "do_train": true, + "disable_tqdm": false, + "max_seq_length": 512, + 
"overwrite_output_dir": true, + "output_dir": "output/supervised/Mistral-7B-Instruct-v0.2", + "logging_steps": 50, + "save_steps": 200, + "save_only_model": true, + "stop_after_n_steps": 1000, + "lora_r": 16, + "gradient_checkpointing": true, + "torch_dtype": "bfloat16", + "attn_implementation": "flash_attention_2", + "seed": 42 +} \ No newline at end of file From 5492ca5849a0dc3a07063eed60f988940d60df98 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Tue, 30 Apr 2024 22:29:07 +0000 Subject: [PATCH 8/9] bug fix --- llm2vec/models/bidirectional_llama.py | 34 ++++++++++++++++----------- llm2vec/version.py | 2 +- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/llm2vec/models/bidirectional_llama.py b/llm2vec/models/bidirectional_llama.py index 78037fb..2e91c9f 100644 --- a/llm2vec/models/bidirectional_llama.py +++ b/llm2vec/models/bidirectional_llama.py @@ -32,6 +32,14 @@ def is_transformers_attn_greater_or_equal_4_38(): "4.38.0" ) +def is_transformers_attn_greater_or_equal_4_40(): + if not _is_package_available("transformers"): + return False + + return version.parse(importlib.metadata.version("transformers")) >= version.parse( + "4.40.0" + ) + class ModifiedLlamaAttention(LlamaAttention): def __init__(self, *args, **kwargs): @@ -99,12 +107,20 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if is_transformers_attn_greater_or_equal_4_40() and self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -116,7 +132,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) - else cache_position[-1] + 1 + else (cache_position[-1] + 1 if not is_transformers_attn_greater_or_equal_4_40() else past_seen_tokens + sequence_length + 1) ) causal_mask = torch.zeros( @@ -164,19 +180,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): and attention_mask is not None and attention_mask.device.type == "cuda" ): - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(input_tensor, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype ) - if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype - ) return causal_mask diff --git a/llm2vec/version.py b/llm2vec/version.py index 3dc1f76..1276d02 100644 --- a/llm2vec/version.py +++ b/llm2vec/version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.1.5" From a0eb38b6930d30bfdb340fdab0f0f299734e1a27 Mon Sep 17 00:00:00 2001 From: vaibhavad Date: Tue, 30 Apr 2024 22:29:19 +0000 Subject: [PATCH 9/9] minor config change --- train_configs/supervised/Mistral.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_configs/supervised/Mistral.json b/train_configs/supervised/Mistral.json index e0859d1..0354f64 100644 --- a/train_configs/supervised/Mistral.json +++ b/train_configs/supervised/Mistral.json @@ -4,7 +4,7 @@ "bidirectional": true, "pooling_mode": "mean", "dataset_name": "E5", - "dataset_file_path": "cache/echo-data", + "dataset_file_path": "cache/E5-data", "remove_unused_columns": false, "learning_rate": 2e-4, "num_train_epochs": 3,
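
For reference, the main() entry point introduced in patch 1/9 above dispatches between a single JSON config file (such as train_configs/supervised/Mistral.json) and ordinary command-line flags via HfArgumentParser. Below is a minimal, self-contained sketch of that dispatch pattern; it is not part of the patch, and ExampleArguments is a hypothetical stand-in for the ModelArguments, DataTrainingArguments, and CustomArguments dataclasses, while TrainingArguments is the real transformers class.

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ExampleArguments:
    # Hypothetical stand-in for the argument dataclasses used in the patch.
    model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.2")
    dataset_name: str = field(default="E5")
    lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
    experiment_id: Optional[str] = field(default=None)


def parse_args():
    parser = HfArgumentParser((ExampleArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single JSON path is parsed directly; its keys are matched against
        # the fields of the dataclasses passed to the parser.
        return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    # Otherwise fall back to ordinary --flag style command-line parsing.
    return parser.parse_args_into_dataclasses()


if __name__ == "__main__":
    example_args, training_args = parse_args()
    print(example_args.model_name_or_path, training_args.output_dir)

Invoked with a single *.json argument, the values in the file populate the dataclasses (which is how the Mistral.json config above is consumed); with any other arguments it behaves like a regular --flag parser.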