def setup_logger(log_file: str, level: int = logging.INFO) -> logging.Logger:
    """Set up a logger that logs to both console and a file.

    The logger is the module-level one (``logging.getLogger(__name__)``), so
    every call reconfigures the same object: handlers installed by a previous
    call are detached and closed, then a fresh console handler and a fresh
    file handler (append mode, UTF-8) are attached.

    Args:
        log_file (str): The path to the log file.
        level (int): The logging level applied to the logger and both handlers.

    Returns:
        logging.Logger: The logger instance.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(level)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Build the new handlers first: if the log file cannot be opened the
    # exception propagates while the logger still has its previous handlers.
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")

    # Detach AND close the old handlers. Merely clearing the list would leave
    # the previous FileHandler's stream open for the lifetime of the process.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    for handler in (console_handler, file_handler):
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
+ This class handles: + - Training loop + - Periodic logging and evaluation + - Loss tracking + - Integration with hybrid FHE model Args: model (nn.Module): The base model with LoRA layers to be fine-tuned. @@ -310,8 +350,14 @@ class LoraTrainer: n_layers_to_skip_for_backprop (int): Number of initial linear layers to keep as standard layers. Since the first layer doesn't need backpropagation (no previous layer to update), we typically skip 1 layer. Defaults to 1. + eval_loader (DataLoader, optional): DataLoader for evaluation data. + eval_metric_fn (callable, optional): Function(model, eval_loader) -> dict of metrics. + logging_steps (int, optional): Log loss every N training steps. Defaults to 1. + eval_steps (int, optional): Evaluate on eval set every N training steps. Defaults to 10. + train_log_path (str, optional): Path to a log file for training. Defaults to "training.log". """ + # pylint: disable=too-many-arguments def __init__( self, model, @@ -320,6 +366,12 @@ def __init__( lr_scheduler=None, training_args=None, n_layers_to_skip_for_backprop=1, + eval_loader: Optional[DataLoader] = None, + eval_metric_fn: Optional[Callable] = None, + logging_steps: int = 1, + eval_steps: int = 10, + train_log_path: str = "training.log", + checkpoint_dir: str = "checkpoints", ): self.optimizer = optimizer self.lr_scheduler = lr_scheduler @@ -327,6 +379,16 @@ def __init__( self.gradient_accumulation_steps = self.training_args.get("gradient_accumulation_steps", 1) self.max_grad_norm = self.training_args.get("max_grad_norm", None) + self.eval_loader = eval_loader + self.eval_metric_fn = eval_metric_fn + self.logging_steps = logging_steps + self.eval_steps = eval_steps + self.training_losses: List[float] = [] + + # Set up logging + self.logger = setup_logger(train_log_path) + self.logger.info("=== Starting new training session ===") + assert_true( training_args is None or "use_cpu" not in training_args @@ -351,6 +413,9 @@ def __init__( self.lora_training_module, 
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, int]:
    """Load a training checkpoint and restore model, optimizer, and lr_scheduler.

    The model and optimizer state dicts are mandatory. The scheduler state,
    the recorded training losses, the global step and the epoch are optional
    and fall back to "nothing to restore", ``[]``, ``0`` and ``0`` respectively
    when the checkpoint does not contain them.

    Args:
        checkpoint_path (str): Path to the checkpoint file.

    Returns:
        Tuple[int, int]: The epoch and global step of the checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint file is not found.
    """
    if not Path(checkpoint_path).is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded. Consider passing weights_only=True once the minimum
    # supported torch version is confirmed to accept it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    self.lora_training_module.inference_model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # The scheduler entry is optional: use .get so a checkpoint written
    # without that key does not fail with KeyError when a scheduler is set.
    scheduler_state = checkpoint.get("lr_scheduler_state_dict")
    if self.lr_scheduler is not None and scheduler_state is not None:
        self.lr_scheduler.load_state_dict(scheduler_state)

    self.training_losses = checkpoint.get("training_losses", [])
    global_step = checkpoint.get("global_step", 0)
    epoch = checkpoint.get("epoch", 0)

    self.logger.info(
        "Checkpoint loaded from %s (Epoch %d, Step %d)", checkpoint_path, epoch, global_step
    )
    return epoch, global_step
""" self.hybrid_model.model.to(device) @@ -387,13 +516,23 @@ def train( # Set the loss scaling factor for gradient accumulation self.lora_training_module.set_loss_scaling_factor(self.gradient_accumulation_steps) - epoch_pbar = tqdm(range(1, num_epochs + 1), desc="Training", unit="epoch") + start_epoch = 1 + global_step = 0 + + # Optionally resume training + if resume_from_checkpoint: + start_epoch, global_step = self.load_checkpoint(resume_from_checkpoint) + start_epoch += 1 # continue from next epoch - for epoch in epoch_pbar: + for epoch in range(start_epoch, num_epochs + 1): total_loss = 0.0 - self.optimizer.zero_grad() # Zero gradients at the start of the epoch + self.optimizer.zero_grad() - for step, batch in enumerate(train_loader): + epoch_bar = tqdm( + enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}", leave=False + ) + for step, batch in epoch_bar: + global_step += 1 if isinstance(batch, (UserDict, dict)): # Convert dict to tuple of values and move them to the device batch = {k: v.to(device) for k, v in batch.items()} @@ -412,36 +551,62 @@ def train( # Accumulate loss for logging total_loss += loss.item() + self.training_losses.append(loss.item()) + + # Logging + if global_step % self.logging_steps == 0: + avg_loss = total_loss / global_step + self.logger.info( + "Step %d: loss=%f, avg_loss=%f", + global_step, + loss.item(), + avg_loss, + ) - # Update weights after gradient accumulation steps - if (step + 1) % self.gradient_accumulation_steps == 0 or (step + 1) == len( - train_loader + # Evaluation + if global_step % self.eval_steps == 0: + self._evaluate(global_step) + + # Gradient accumulation steps + if ((step + 1) % self.gradient_accumulation_steps == 0) or ( + step + 1 == len(train_loader) ): if self.max_grad_norm is not None: torch.nn.utils.clip_grad_norm_( self.lora_training_module.parameters(), self.max_grad_norm ) - # Optimizer step self.optimizer.step() - - # Scheduler step - if self.lr_scheduler is not None: + if 
def get_training_losses(self):
    """Expose the loss history stored on the trainer.

    The returned object is the trainer's own ``training_losses`` list, not a
    copy, so later training steps that append to it remain visible to the
    caller.

    Returns:
        List[float]: All recorded training losses.
    """
    loss_history = self.training_losses
    return loss_history
""" self.hybrid_model.save_and_clear_private_info(path) + self.logger.info("Model saved at %s", path) def get_remote_names(model: nn.Module, include_embedding_layers: bool = False) -> List[str]: diff --git a/use_case_examples/lora_finetuning/eval.py b/use_case_examples/lora_finetuning/eval.py new file mode 100644 index 000000000..45c139dd0 --- /dev/null +++ b/use_case_examples/lora_finetuning/eval.py @@ -0,0 +1,263 @@ +import math +import random +import shutil +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from datasets import load_dataset +from peft import LoraConfig, get_peft_model +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + Trainer, + TrainingArguments, +) +from utils_lora import generate_and_print + +from concrete.ml.torch.lora import LoraTrainer + +# Set seed for reproducibility +SEED = 0 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(SEED) + +# Load the model and tokenizer +model_name = "meta-llama/Llama-3.2-1B" +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name) + +# Ensure the tokenizer has a pad token +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token +model.config.pad_token_id = model.config.eos_token_id + +# Freeze the original model weights +for param in model.parameters(): + param.requires_grad = False + +# Print initial generation with base model +PROMPT = "What is 2+2?\n" +print("Initial generation with base model:") +print(generate_and_print(PROMPT, model, tokenizer, seed=SEED)) + +# Apply LoRA configuration +peft_config = LoraConfig( + r=8, + lora_alpha=32, + lora_dropout=0.01, + bias="none", + task_type="CAUSAL_LM", + target_modules="all-linear", +) +peft_model = get_peft_model(model, peft_config) + +# Load the 
def processed(example):
    """Tokenize one question/answer pair and build answer-only labels.

    The question and answer are stripped, joined with a newline and tokenized
    to exactly ``MAX_LENGTH`` ids (truncated and padded). Labels are a copy of
    the input ids in which the question prefix and every padding position are
    replaced by ``-100`` so that only real answer tokens contribute to the loss.

    Args:
        example (dict): A dataset row with string fields "question" and "answer".

    Returns:
        The tokenizer output for the combined text, with an added "labels" list.
    """
    question = example["question"].strip()
    answer = example["answer"].strip()

    # truncation=True with padding="max_length" always yields exactly
    # MAX_LENGTH ids, so there is no over-length example left to discard.
    tokens = tokenizer(
        question + "\n" + answer,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors=None,
    )
    input_ids = tokens["input_ids"]

    # Length of the question prefix, measured by tokenizing it on its own.
    question_length = len(
        tokenizer(question, truncation=True, max_length=MAX_LENGTH, padding=False)["input_ids"]
    )
    newline_length = len(tokenizer("\n", add_special_tokens=False)["input_ids"])

    # Clamp so a question that fills the whole window cannot index past the end.
    # NOTE(review): index-based masking assumes the tokenizer pads on the right
    # - confirm the tokenizer's padding side.
    question_boundary = min(question_length + newline_length, len(input_ids))

    # Padding must not be trained on. The attention mask is used rather than
    # the pad token id because the pad token may be the same as the EOS token.
    # If the tokenizer returns no attention mask, nothing is treated as padding.
    attention_mask = tokens.get("attention_mask")
    if attention_mask is None:
        attention_mask = [1] * len(input_ids)

    tokens["labels"] = [
        token_id if position >= question_boundary and is_real_token else -100
        for position, (token_id, is_real_token) in enumerate(zip(input_ids, attention_mask))
    ]
    return tokens
def causal_lm_loss(logits, labels, ignore_index=-100):
    """Compute the mean next-token cross-entropy for a causal language model.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``: the last logit step and the first label are dropped, both are
    flattened, and positions whose label equals ``ignore_index`` are excluded
    from the mean.

    Args:
        logits: Tensor whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        labels: Tensor of target token ids whose last dimension is the
            sequence position.
        ignore_index (int): Label value to exclude from the loss.

    Returns:
        Scalar tensor holding the mean cross-entropy over non-ignored targets.
    """
    vocab_size = logits.size(-1)

    # Drop the final prediction (nothing follows it) and the first label
    # (nothing predicts it), then flatten to (positions, vocab) / (positions,).
    predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    targets = labels[..., 1:].contiguous().view(-1)

    return torch.nn.functional.cross_entropy(
        predictions, targets, ignore_index=ignore_index, reduction="mean"
    )
+ model=peft_model, + args=training_args, + train_dataset=train_dataset, + data_collator=data_collator, +) +train_dataloader = hf_trainer.get_train_dataloader() +hf_trainer.create_optimizer_and_scheduler(num_training_steps=len(train_dataloader) * EPOCHS) + +optimizer = hf_trainer.optimizer +lr_scheduler = hf_trainer.lr_scheduler + +# Prepare input data for calibration +BLOCK_SIZE = MAX_LENGTH +input_tensor = torch.randint( + 0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long +) +label_tensor = torch.randint( + 0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long +) +attention_mask = torch.ones((PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long) +inputset = (input_tensor, label_tensor, attention_mask) + +# Prepare eval loader +eval_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=data_collator) + +lora_trainer = LoraTrainer( + model=peft_model, + optimizer=optimizer, + loss_fn=causal_lm_loss, + lr_scheduler=lr_scheduler, + training_args=vars(training_args), + n_layers_to_skip_for_backprop=3, + eval_loader=eval_loader, + eval_metric_fn=metric_fn, + logging_steps=1, + eval_steps=10, + train_log_path="training_log.txt", +) + +# Compile the model with FHE +lora_trainer.compile(inputset, n_bits=16) + +# Train the model using LoraTrainer +print("Starting training using LoraTrainer...") +lora_trainer.train(train_dataloader, num_epochs=EPOCHS, fhe="disable") + +# After training, retrieve all training losses +all_losses = lora_trainer.get_training_losses() +print("Recorded training losses:", all_losses) + +# Evaluate the original model (disabling adapter layers) +peft_model.disable_adapter_layers() +orig_metrics = metric_fn(peft_model, eval_loader) +print("Evaluation on original layers (adapter disabled):", orig_metrics) + +# Evaluate the fine-tuned model (enabling adapter layers) +peft_model.enable_adapter_layers() +finetuned_metrics = metric_fn(peft_model, eval_loader) 
print("Evaluation on fine-tuned model (adapter enabled):", finetuned_metrics)

# Compare generation before and after fine-tuning: first with the adapter
# layers disabled (base model), then with them re-enabled (fine-tuned model).
for banner, toggle_adapters in (
    ("Original model generation:", peft_model.disable_adapter_layers),
    ("Fine-tuned model generation:", peft_model.enable_adapter_layers),
):
    toggle_adapters()
    print(banner)
    print(generate_and_print(PROMPT, peft_model, tokenizer, seed=SEED))

# Save the fine-tuned model, removing a previous non-empty export first.
save_path = Path("deployment/gpt2_lora_finetuned")
has_previous_export = save_path.is_dir() and any(save_path.iterdir())
if has_previous_export:
    shutil.rmtree(save_path)
lora_trainer.save_and_clear_private_info(save_path)
print("Model saved to:", save_path)