diff --git a/requirements.txt b/requirements.txt
index b31fbdcb..2d77d1e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ py-cpuinfo
 # we set this to be above 0a0 so that it doesn't
 # replace custom pytorch images with the 2.3.0
 torch>=2.3.0a0
-transformers>=4.41.2
+transformers>=4.45.2
 accelerate>=0.34.2
 datasets>=2.15.0
 numba
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 05fe4792..51e9752e 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -171,8 +171,9 @@ class TrainingArgs(BaseModel):
     save_samples: int
     learning_rate: float
     warmup_steps: int
-    is_padding_free: bool
     random_seed: int = 42
+    use_dolomite: bool = False
+    is_padding_free: bool = False  # TODO: deprecate
     checkpoint_at_epoch: bool = True
     accelerate_full_state_at_epoch: bool = True
diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 4bd7c789..10214e9d 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -199,7 +199,7 @@ def print_masked_samples(data, tokenizer, is_pretrain, num_proc):
     def get_masked_and_orig_text(sample):
         labels = sample["labels"]
         input_ids = sample["input_ids"]
-        mask_id = get_sp_token(tokenizer, "<MASK>")[0]
+        mask_id = get_sp_token(tokenizer, "<|MASK|>")[0]
         label = [mask_id if tk == -100 else tk for tk in labels]
         text = tokenizer.decode(label)
         orig_text = tokenizer.decode(input_ids)
@@ -239,7 +239,7 @@ def main(args: DataProcessArgs):

     # Adding after tokenizer setup as these are temp tokens, not to be saved
     tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<MASK>"]}
+        {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
     )

     try:
@@ -347,9 +347,26 @@ def main(args: DataProcessArgs):
     )
     # extract only labels and messages formatted into a new dataset
-    data_with_labels = data_with_labels.select_columns(["labels", "input_ids"])
+    data_with_labels = data_with_labels.map(
+        lambda x: {
+            "len": len(x["input_ids"]),
+        },
+        num_proc=NUM_PROC,
+    )
+    data_with_labels = data_with_labels.select_columns(["labels", "input_ids", "len"])
+    # MASK and both pretrain tokens should not be in the final tokens, those are special tokens added only for data processing purposes.
+    max_id = len(tokenizer) - 3
+    final_valid_data = data_with_labels.filter(
+        lambda x: all(tk < max_id for tk in x["labels"]), num_proc=NUM_PROC
+    )
+    # Dropping samples that could break training due to oob ids
+    if len(final_valid_data) < len(data_with_labels):
+        dropped_samples = len(data_with_labels) - len(final_valid_data)
+        print(
+            f"\033[93mWarning: {dropped_samples} samples were dropped because they contained token IDs greater than or equal to {max_id}.\033[0m"
+        )

     # use path to get the stem of the file
-    data_with_labels.to_json(Path(args.data_output_path) / f"data.jsonl")
+    final_valid_data.to_json(Path(args.data_output_path) / "data.jsonl")


 if __name__ == "__main__":
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index d1ae0e01..ab59282f 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -41,8 +41,10 @@
     StreamablePopen,
     add_noisy_embeddings,
     apply_gradient_checkpointing,
+    check_flash_attn_enabled,
+    check_valid_train_args,
     convert_loss_to_reduce_sum,
-    ensure_loadable_granite_checkpoint,
+    ensure_loadable_dolomite_checkpoint,
     get_projection_layer_names,
     load_latest_full_state,
     prepare_peft_model,
@@ -84,7 +86,7 @@ def setup_optimizer(args, model):
     return optimizer


-def setup_model(args, tokenizer, train_loader, grad_accum):
+def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
     bnb_config = None
     if args.lora_r > 0 and args.lora_quant_bits == 4:
         # Third Party
@@ -102,15 +104,11 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         "torch_dtype": torch.bfloat16,
         "quantization_config": bnb_config,
     }
-    if not args.disable_flash_attn:
+    if flash_enabled:
         base_model_args["attn_implementation"] = "flash_attention_2"
-    elif args.is_granite:
-        raise RuntimeError(
-            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
-        )

-    if args.is_granite:
-        with ensure_loadable_granite_checkpoint(
+    if args.use_dolomite:
+        with ensure_loadable_dolomite_checkpoint(
             args.model_name_or_path, args.output_dir
         ) as path:
             base_model_args["pretrained_model_name_or_path"] = path
@@ -165,9 +163,10 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         "Starcoder2ForCausalLM",
         "GemmaForCausalLM",
         "MixtralForCausalLM",
+        "GraniteForCausalLM",
     ], f"Model class name: {model.__class__.__name__} is not supported."

-    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
+    model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
     model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)

     # handling of gradient checkpointing
@@ -212,15 +211,15 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             target_modules=args.lora_target_modules,
         )
         model = prepare_peft_model(
-            model, peft_config, gradient_checkpointing=not args.is_granite
+            model, peft_config, gradient_checkpointing=not args.use_dolomite
         )

-    elif not args.is_granite:
+    elif not args.use_dolomite:
         model.gradient_checkpointing_enable()

     # granite gradient checkpointing is handled uniformly
     # for both lora and full here
-    if args.is_granite:
+    if args.use_dolomite:
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
             model,
@@ -252,6 +251,9 @@ def make_inputs_require_grad(module, input, output):
         deepcopy(train_loader),
         lr_scheduler,
     )
+    # Necessary so that Accelerate does not step once per GPU
+    # see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69
+    lr_scheduler.split_batches = True

     return model, lr_scheduler, optimizer, accelerator
@@ -381,8 +383,8 @@ def train(
             num_loss_counted_tokens = float(
                 torch.tensor([batch.pop("num_loss_counted_tokens")])
             )
-            micro_batch_size = float(len(batch["input_ids"]))
-            if not args.is_granite:
+            micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
+            if not args.use_dolomite:
                 for k in batch:
                     batch[k] = batch[k].to(local_rank)
             output = model(
@@ -453,7 +455,7 @@ def train(
                         "batch_size": int(micro_batch_size),
                         "total_loss": float(log_loss / num_loss_counted_tokens),
                         "samples_seen": samples_seen,
-                        # "gradnorm": global_grad_norm,
+                        "gradnorm": global_grad_norm,
                         # "weight_norm": weight_norm,
                     }
                 )
@@ -535,6 +537,8 @@ def main(args):
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()

+    flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)
+
     dataset = setup_dataset(
         args.data_path,
         mock=args.mock_data,
@@ -547,7 +551,7 @@ def main(args):
         avg_sample_len=dataset.get_lengths().mean(),
         effective_batch_size=args.effective_batch_size,
         max_batch_len_per_gpu=args.max_batch_len,
-        is_padding=not args.is_granite,
+        is_padding=not (args.use_dolomite or flash_enabled),
         dataset=dataset,
         seed=args.seed,
     )
@@ -570,7 +574,8 @@ def main(args):
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
+        flash_enabled=flash_enabled,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
         samples_per_gpu=args.samples_per_gpu,
@@ -589,7 +594,8 @@ def main(args):
         dataset,
         tokenizer.pad_token_id,
         num_workers=8,
-        is_granite=args.is_granite,
+        use_dolomite=args.use_dolomite,
+        flash_enabled=flash_enabled,
         max_batch_len=args.max_batch_len,
         packing_max_batch_len=packing_max_batch_len,
         samples_per_gpu=args.samples_per_gpu,
@@ -613,7 +619,7 @@ def main(args):
     )

     model, lr_scheduler, optimizer, accelerator = setup_model(
-        args, tokenizer, train_loader, grad_accum
+        args, tokenizer, train_loader, grad_accum, flash_enabled
     )

     load_latest_full_state(args=args, accelerator=accelerator)
@@ -639,11 +645,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
""" - # early validation logic here - if train_args.max_batch_len < train_args.max_seq_len: - raise ValueError( - f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}" - ) + check_valid_train_args(train_args) if train_args.process_data: dp.main( @@ -697,14 +699,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.mock_len: command.append(f"--mock_len={train_args.mock_len}") - if train_args.is_padding_free: - command.append("--is_granite") + if train_args.use_dolomite: + command.append("--use_dolomite") if train_args.disable_flash_attn: - if train_args.is_padding_free: - raise RuntimeError( - "ERROR: Trying to use padding-free transformer without flash attention is not supported" - ) command.append("--disable_flash_attn") if train_args.lora: @@ -888,7 +886,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default="SHARD_GRAD_OP", help="Sharding strategy to be used for FSDP distributed training.", ) - parser.add_argument("--is_granite", action="store_true") + parser.add_argument("--use_dolomite", action="store_true") parser.add_argument("--lora_r", type=int, default=0) # set to > 0 to activate lora parser.add_argument("--lora_alpha", type=int, default=32) parser.add_argument("--lora_dropout", type=float, default=0.1) @@ -977,7 +975,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: --save_samples=250000 \ --log_level="INFO" \ --fsdp_sharding_strategy="SHARD_GRAD_OP" \ ---is_granite \ +--use_dolomite \ --max_batch_len 70000 \ --seed=42 """ diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py index 9d46607e..fda9a751 100644 --- a/src/instructlab/training/token_dataset.py +++ b/src/instructlab/training/token_dataset.py @@ -17,12 +17,15 @@ class TokenDataset(Dataset): def __init__(self, data_path): self.data = load_dataset("json", data_files=data_path, split="train") - self.lengths = np.array( - self.data.map( - lambda x: {"len": len(x["input_ids"])}, - num_proc=8, - )["len"] - ) + if "len" not in self.data.column_names: + self.lengths = np.array( + self.data.map( + lambda x: {"len": len(x["input_ids"])}, + num_proc=8, + )["len"] + ) + else: + self.lengths = np.array(self.data["len"]) def __len__(self): return len(self.data) @@ -87,7 +90,8 @@ def setup_dataloader( dataset: Dataset, pad_token_id: int, num_workers: int = 8, - is_granite=False, + use_dolomite=False, + flash_enabled=True, max_batch_len=60000, packing_max_batch_len=60000, samples_per_gpu=None, @@ -95,7 +99,10 @@ def setup_dataloader( seed=47, ) -> DataLoader: collate_fn = make_collate_fn( - pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len + pad_token_id, + use_dolomite=use_dolomite, + flash_enabled=flash_enabled, + max_batch_len=max_batch_len, ) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) @@ -108,7 +115,7 @@ def setup_dataloader( num_replicas=world_size, rank=rank, seed=seed, - padding=not is_granite, + padding=not flash_enabled, ) sampler = {"batch_sampler": sampler} elif sampler == "distributed": diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 6d79d897..d685d212 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -10,6 +10,7 @@ from typing import Any, List, Optional import importlib import inspect +import json import logging import os import random @@ -40,6 +41,44 @@ import torch import 
 import torch.nn.functional as F

+# First Party
+from instructlab.training.config import TrainingArgs
+
+
+def check_valid_train_args(train_args: TrainingArgs):
+    # early validation logic here
+    if train_args.max_batch_len < train_args.max_seq_len:
+        raise ValueError(
+            f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
+        )
+
+    if os.path.exists(train_args.model_path):
+        if not os.path.isdir(train_args.model_path):
+            raise FileNotFoundError(
+                "Model path does not appear to be a directory. Please make sure that you're passing a Hugging Face Transformers compatible directory checkpoint."
+            )
+    else:
+        raise FileNotFoundError(
+            f"Provided path to model does not exist. Please make sure that you've passed a valid model and that it has appropriate permissions: {train_args.model_path}"
+        )
+
+    if train_args.use_dolomite:
+        with open(Path(train_args.model_path) / "config.json") as conf_json:
+            model_conf = json.load(conf_json)
+        if model_conf["model_type"] == "granite":
+            raise RuntimeError(
+                "Converting Granite models to Dolomite format is currently unsupported."
+            )
+        if train_args.disable_flash_attn:
+            raise RuntimeError(
+                "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
+            )
+
+    if train_args.is_padding_free:
+        print(
+            "\033[33m WARNING: is_padding_free is being deprecated due to adoption of the default padding-free support in Hugging Face Transformers. As such, this flag is non-functional in 0.6.0 and beyond. If you would like to use the older Dolomite padding-free implementation, please set use_dolomite moving forward.\033[0m"
+        )
+
+
 def retrieve_chat_template(chat_tmpl_path):
     try:
@@ -111,9 +150,39 @@ def listen(self):
                     break


-def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000):
+def supports_flash_attention(device_id=0):
+    """Check if a GPU supports FlashAttention."""
+    major, minor = torch.cuda.get_device_capability(device_id)
+    # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+    dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
+    is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
+    return is_sm8x or is_sm90 or is_compat_amd
+
+
+def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
+    if not disable_flash_attn:
+        if supports_flash_attention():
+            flash_enabled = True
+        else:
+            raise RuntimeError(
+                "ERROR: Trying to use Flash Attention on unsupported hardware. Please set disable_flash_attn to True."
+            )
+    elif use_dolomite:
+        raise RuntimeError(
+            "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
+        )
+    else:
+        flash_enabled = False
+    return flash_enabled
+
+
+def make_collate_fn(
+    pad_token_id, use_dolomite=False, flash_enabled=True, max_batch_len=60000
+):
     rank = int(os.environ["RANK"])
-    if is_granite:
+    if use_dolomite:

         def pad_collate_fn(batch):
             lens = np.array([len(item["input_ids"]) for item in batch])
@@ -140,70 +209,108 @@ def pad_collate_fn(batch):
                 "input_ids": input_ids,
                 "labels": labels,
                 "num_loss_counted_tokens": num_loss_counted_tokens,
+                "num_samples": len(batch),
             }

     else:
+        if flash_enabled:
+
+            def pad_collate_fn(batch):
+                input_ids = []
+                labels = []
+                position_ids = []
+                total_len = 0
+                num_loss_counted_tokens = 0
+
+                for num_samples, item in enumerate(batch):
+                    item_len = len(item["input_ids"])
+                    if total_len + item_len > max_batch_len:
+                        break
+
+                    input_ids.extend(item["input_ids"].tolist())
+                    labels.extend(item["labels"].tolist())
+                    position_ids.extend(range(total_len, total_len + item_len))
+
+                    total_len += item_len
+                    num_loss_counted_tokens += (item["labels"] != -100).sum().item()
+
+                print(
+                    f"\033[96m total length: {total_len} "
+                    f"num samples {len(batch)} - rank: {rank} "
+                    f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
+                )
-
-        def pad_collate_fn(batch):
-            lens = np.array([len(item["input_ids"]) for item in batch])
-            max_len = max(lens)
-
-            input_ids = torch.stack(
-                [
-                    F.pad(
-                        item["input_ids"],
-                        (max_len - len(item["input_ids"]), 0),
-                        mode="constant",
-                        value=pad_token_id,
-                    )
-                    for item in batch
-                ]
-            )
-            labels = torch.stack(
-                [
-                    F.pad(
-                        item["labels"],
-                        (max_len - len(item["labels"]), 0),
-                        mode="constant",
-                        value=-100,
-                    )
-                    for item in batch
-                ]
-            )
-            num_loss_counted_tokens = (labels != -100).sum()
-
-            attention_mask = torch.stack(
-                [
-                    F.pad(
-                        item["attention_mask"],
-                        (max_len - len(item["attention_mask"]), 0),
-                        mode="constant",
-                        value=0,
-                    )
-                    for item in batch
-                ]
-            )
-            print(
-                f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
-                f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
-                f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
-            )
+                return {
+                    "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                    "labels": torch.tensor([labels], dtype=torch.long),
+                    "position_ids": torch.tensor([position_ids], dtype=torch.long),
+                    "num_loss_counted_tokens": num_loss_counted_tokens,
+                    "num_samples": num_samples + 1,  # pylint: disable=W0631
+                }

-            return {
-                "input_ids": input_ids,
-                "labels": labels,
-                "num_loss_counted_tokens": num_loss_counted_tokens,
-                "attention_mask": attention_mask,
-            }
+        else:
+
+            def pad_collate_fn(batch):
+                lens = np.array([len(item["input_ids"]) for item in batch])
+                max_len = max(lens)
+
+                input_ids = torch.stack(
+                    [
+                        F.pad(
+                            item["input_ids"],
+                            (max_len - len(item["input_ids"]), 0),
+                            mode="constant",
+                            value=pad_token_id,
+                        )
+                        for item in batch
+                    ]
+                )
+                labels = torch.stack(
+                    [
+                        F.pad(
+                            item["labels"],
+                            (max_len - len(item["labels"]), 0),
+                            mode="constant",
+                            value=-100,
+                        )
+                        for item in batch
+                    ]
+                )
+                num_loss_counted_tokens = (labels != -100).sum()
+
+                attention_mask = torch.stack(
+                    [
+                        F.pad(
+                            item["attention_mask"],
+                            (max_len - len(item["attention_mask"]), 0),
+                            mode="constant",
+                            value=0,
+                        )
+                        for item in batch
+                    ]
+                )
+                print(
+                    f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
+                    f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
+                    f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
+                )
+
+                return {
+                    "input_ids": input_ids,
+                    "labels": labels,
+                    "num_loss_counted_tokens": num_loss_counted_tokens,
+                    "attention_mask": attention_mask,
+                    "num_samples": len(batch),
+                }

     return pad_collate_fn


-def convert_loss_to_reduce_sum(model, is_granite=False):
+def convert_loss_to_reduce_sum(model, use_dolomite=False):
     """
     this is necessary because multipack changes the samples per gpu, which biases the gradients to be larger for batches with less samples but longer lengths.
     """
-    if is_granite:
+    if use_dolomite:

         def get_autoregressive_language_modeling_loss(
             lm_logits: torch.Tensor,
@@ -489,7 +596,7 @@ class UniversalCheckpointArgs:


 @contextmanager
-def ensure_loadable_granite_checkpoint(
+def ensure_loadable_dolomite_checkpoint(
     model_name_or_path: str,
     tmpdir: str,
 ):
@@ -662,7 +769,7 @@ def save_hf_format_accelerate(
     tokenizer,
     accelerator: Accelerator,
     samples_seen,
-    convert_granite=True,
+    convert_dolomite=True,
     is_lora=False,
 ):
     log_rank_0(
@@ -672,7 +779,7 @@ def save_hf_format_accelerate(
     start = time.time()

     final_output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
-    if args.is_granite and convert_granite:
+    if args.use_dolomite and convert_dolomite:
         tmpdir = TemporaryDirectory("w")  # pylint: disable=consider-using-with
         output_dir = Path(tmpdir.name)
     else:
@@ -694,7 +801,7 @@ def _get_state_dict_patched(model, unwrap=False):
         model_state = model.module.state_dict()

     output_dir.mkdir(parents=True, exist_ok=True)
-    if not model.module.config.architectures and convert_granite:
+    if not model.module.config.architectures and convert_dolomite:
         model.module.config.architectures = ["LlamaForCausalLM"]
         warnings.warn(
             f"Adding architectures to ckpt: {model.module.config.architectures}",
@@ -720,7 +827,7 @@ def _get_state_dict_patched(model, unwrap=False):
             safe_serialization=True,
         )

-    if args.is_granite and convert_granite and accelerator.is_main_process:
+    if args.use_dolomite and convert_dolomite and accelerator.is_main_process:
        # export doesnt like the directory to exist
        if final_output_dir.exists():
            shutil.rmtree(final_output_dir)
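
Editor's note (appended; not part of the diff above): the data_process.py hunk drops any sample whose labels reference the three temporary special tokens (<|pretrain|>, <|/pretrain|>, <|MASK|>) appended during preprocessing, since those ids sit at the very top of the tokenizer and must not leak into training data. A minimal sketch of that filter, using made-up toy values in place of the real tokenizer and dataset:

# Toy illustration of the out-of-range filter added in data_process.py.
# `tokenizer_len` and `samples` are hypothetical stand-ins; the real code
# filters a Hugging Face dataset with .filter(..., num_proc=NUM_PROC).
tokenizer_len = 12          # stands in for len(tokenizer) after add_special_tokens
max_id = tokenizer_len - 3  # ids 9, 10, 11 belong to the temporary special tokens

samples = [
    {"labels": [-100, 4, 7], "input_ids": [2, 4, 7]},  # kept: all label ids < max_id
    {"labels": [2, 11, 5], "input_ids": [2, 11, 5]},   # dropped: 11 >= max_id
]

final_valid = [s for s in samples if all(tk < max_id for tk in s["labels"])]
dropped = len(samples) - len(final_valid)
print(f"kept {len(final_valid)} sample(s), dropped {dropped} (label ids >= {max_id})")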
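
A second appended sketch (also not part of the diff): the shape of what the new flash-enabled pad_collate_fn in utils.py returns. Samples are concatenated into a single row with running position_ids and no padding or attention_mask, and num_samples / num_loss_counted_tokens travel with the batch. The toy batch below is hypothetical and the max_batch_len cutoff is omitted; only the packing logic mirrors the added code:

import torch

# Hypothetical toy batch; real items come from TokenDataset.
batch = [
    {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([-100, 6, 7])},
    {"input_ids": torch.tensor([8, 9]), "labels": torch.tensor([8, 9])},
]

input_ids, labels, position_ids = [], [], []
total_len = 0
num_loss_counted_tokens = 0
for item in batch:
    item_len = len(item["input_ids"])
    input_ids.extend(item["input_ids"].tolist())
    labels.extend(item["labels"].tolist())
    position_ids.extend(range(total_len, total_len + item_len))
    total_len += item_len
    num_loss_counted_tokens += (item["labels"] != -100).sum().item()

packed = {
    "input_ids": torch.tensor([input_ids], dtype=torch.long),        # shape (1, 5)
    "labels": torch.tensor([labels], dtype=torch.long),
    "position_ids": torch.tensor([position_ids], dtype=torch.long),  # [[0, 1, 2, 3, 4]]
    "num_loss_counted_tokens": num_loss_counted_tokens,              # 4 supervised tokens
    "num_samples": len(batch),
}
print({k: (v.shape if torch.is_tensor(v) else v) for k, v in packed.items()})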