Commit

dd
jfrery committed Dec 11, 2024
1 parent 372307f commit d2a25cf
Showing 4 changed files with 433 additions and 52 deletions.
195 changes: 143 additions & 52 deletions src/concrete/ml/torch/lora.py
@@ -1,12 +1,14 @@
"""This module contains classes for LoRA (Low-Rank Adaptation) FHE training and custom layers."""

from typing import List, Tuple, Union
from pathlib import Path

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import logging
from typing import Optional, Callable, Dict
from .hybrid_backprop_linear import CustomLinear
from .hybrid_model import HybridFHEModel

@@ -25,6 +27,28 @@
# pylint: disable=arguments-differ


def setup_logger(log_file: str, level=logging.INFO):
"""Set up a logger that logs to both console and a file."""
logger = logging.getLogger(__name__)
logger.setLevel(level)
logger.handlers.clear()

# Console handler
ch = logging.StreamHandler()
ch.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)

# File handler
fh = logging.FileHandler(log_file, mode="a", encoding="utf-8")
fh.setLevel(level)
fh.setFormatter(formatter)

logger.addHandler(ch)
logger.addHandler(fh)
return logger


def try_dict(obj):
"""Try to convert the object to a dict.
@@ -233,8 +257,11 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, None]]:
class LoraTrainer:
"""Trainer class for LoRA fine-tuning with FHE support.
This class handles the training loop, optimizer, scheduler,
and integrates with the hybrid model.
This class handles:
- Training loop
- Periodic logging and evaluation
- Loss tracking
- Integration with hybrid FHE model
Args:
model (nn.Module): The base model with LoRA layers to be fine-tuned.
@@ -245,6 +272,11 @@ class LoraTrainer:
n_layers_to_skip_for_backprop (int): Number of initial linear layers to keep as standard
layers. Since the first layer doesn't need backpropagation (no previous layer to
update), we typically skip 1 layer. Defaults to 1.
eval_loader (DataLoader, optional): DataLoader for evaluation data.
eval_metric_fn (callable, optional): Function(model, eval_loader) -> dict of metrics.
logging_steps (int, optional): Log loss every N training steps. Defaults to 100.
eval_steps (int, optional): Evaluate on eval set every N training steps. Defaults to 500.
train_log_path (str, optional): Path to a log file for training. Defaults to "training.log".
checkpoint_dir (str, optional): Directory where training checkpoints are saved. Defaults to "checkpoints".
"""

def __init__(
@@ -255,14 +287,30 @@ def __init__(
lr_scheduler=None,
training_args=None,
n_layers_to_skip_for_backprop=1,
eval_loader: Optional[DataLoader] = None,
eval_metric_fn: Optional[Callable] = None,
logging_steps: int = 100,
eval_steps: int = 500,
train_log_path: str = "training.log",
checkpoint_dir: str = "checkpoints"
):
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.training_args = training_args or {}
self.gradient_accumulation_steps = self.training_args.get("gradient_accumulation_steps", 1)
self.max_grad_norm = self.training_args.get("max_grad_norm", None)

# Create the LoraTraining module
self.eval_loader = eval_loader
self.eval_metric_fn = eval_metric_fn
self.logging_steps = logging_steps
self.eval_steps = eval_steps
self.training_losses = []

# Set up logging
self.logger = setup_logger(train_log_path)
self.logger.info("=== Starting new training session ===")

# Create the LoraTraining module and hybrid model
self.lora_training_module = LoraTraining(
model, n_layers_to_skip_for_backprop=n_layers_to_skip_for_backprop, loss_fn=loss_fn
)
@@ -275,6 +323,9 @@ def __init__(
self.lora_training_module, module_names=self.remote_names
)

self.checkpoint_dir = checkpoint_dir
Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

def compile(self, inputset, n_bits=8):
"""Compile the hybrid model with the given input set.
@@ -285,41 +336,79 @@ def compile(self, inputset, n_bits=8):
self.lora_training_module.toggle_calibrate(enable=True)
self.hybrid_model.compile_model(inputset, n_bits=n_bits)
self.lora_training_module.toggle_calibrate(enable=False)
self.logger.info("Compilation complete.")

def _evaluate(self, step: int):
"""Run eval_metric_fn on the evaluation set and log the results for the given step."""
if self.eval_loader and self.eval_metric_fn:
self.logger.info(f"Running evaluation at step {step}...")
self.lora_training_module.inference_model.eval()
metrics: Dict[str, float] = self.eval_metric_fn(self.lora_training_module.inference_model, self.eval_loader)
metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
self.logger.info(f"[Evaluation at step {step}] {metrics_str}")
self.lora_training_module.inference_model.train() # back to train mode
else:
self.logger.info("No evaluation data or metric function provided.")

def train(
self,
train_loader: DataLoader,
num_epochs: int = 10,
fhe: str = "simulate",
):
"""Train the model using the hybrid FHE model.
def save_checkpoint(self, epoch: int, global_step: int):
"""Save a training checkpoint.
Args:
train_loader (DataLoader): DataLoader for training data.
num_epochs (int): Number of epochs to train.
fhe (str): FHE mode ('disable', 'simulate', 'execute' or 'torch').
epoch (int): The current epoch number.
global_step (int): The current global step number.
"""
checkpoint_path = Path(self.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
save_dict = {
"model_state_dict": self.lora_training_module.inference_model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": self.lr_scheduler.state_dict() if self.lr_scheduler else None,
"training_losses": self.training_losses,
"global_step": global_step,
"epoch": epoch
}
torch.save(save_dict, checkpoint_path)
self.logger.info(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(self, checkpoint_path: str):
"""Load a training checkpoint and restore model, optimizer, etc."""
if not Path(checkpoint_path).is_file():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.lora_training_module.inference_model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.lr_scheduler and checkpoint["lr_scheduler_state_dict"] is not None:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
self.training_losses = checkpoint.get("training_losses", [])
global_step = checkpoint.get("global_step", 0)
epoch = checkpoint.get("epoch", 0)
self.logger.info(f"Checkpoint loaded from {checkpoint_path} (Epoch {epoch}, Step {global_step})")
return epoch, global_step

def train(self, train_loader: DataLoader, num_epochs: int = 10, fhe: str = "simulate", resume_from_checkpoint: Optional[str] = None):
"""Train the model using the hybrid FHE model, optionally resuming from a saved checkpoint."""
device = torch.device("cpu")
self.lora_training_module.to(device)
self.lora_training_module.inference_model.train()

# Set the loss scaling factor for gradient accumulation
self.lora_training_module.set_loss_scaling_factor(self.gradient_accumulation_steps)

epoch_pbar = tqdm(range(1, num_epochs + 1), desc="Training", unit="epoch")
start_epoch = 1
global_step = 0

# Optionally resume training
if resume_from_checkpoint:
start_epoch, global_step = self.load_checkpoint(resume_from_checkpoint)
start_epoch += 1 # continue from next epoch

for epoch in epoch_pbar:
for epoch in range(start_epoch, num_epochs + 1):
total_loss = 0.0
self.optimizer.zero_grad() # Zero gradients at the start of the epoch
self.optimizer.zero_grad()

for step, batch in enumerate(train_loader):
epoch_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}", leave=False)
for step, batch in epoch_bar:
global_step += 1

# Convert the batch to a tuple of inputs on the device.
# Move batch to device
if batch_dict := try_dict(batch):
batch = batch_dict
# Convert dict to tuple of values and move them to the device
batch = tuple(
v.to(device) if isinstance(v, torch.Tensor) else v for v in batch.values()
v.to(device) if isinstance(v, torch.Tensor) else v for v in batch_dict.values()
)
elif isinstance(batch, (tuple, list)):
# Move tuple/list elements to the device
@@ -328,46 +417,48 @@ def train(
for item in batch
)
else:
# If it is a single non-tensor item, wrap it in a tuple
batch = (batch,)

# Forward pass through the hybrid model
loss, _ = self.hybrid_model(batch, fhe=fhe)
total_loss += loss.item()
self.training_losses.append(loss.item())

# Loss scaling and backward is done inside LoraTraining
# Logging
if global_step % self.logging_steps == 0:
# Average over the steps seen so far in the current epoch
avg_loss = total_loss / (step + 1)
self.logger.info(f"Step {global_step}: loss={loss.item():.4f}, avg_loss={avg_loss:.4f}")

# Accumulate loss for logging
total_loss += loss.item()
# Evaluation
if global_step % self.eval_steps == 0:
self._evaluate(global_step)

# Update weights after gradient accumulation steps
if (step + 1) % self.gradient_accumulation_steps == 0 or (step + 1) == len(
train_loader
):
# Gradient accumulation steps
if ((step + 1) % self.gradient_accumulation_steps == 0) or (step + 1 == len(train_loader)):
if self.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
self.lora_training_module.parameters(), self.max_grad_norm
)

# Optimizer step
torch.nn.utils.clip_grad_norm_(self.lora_training_module.parameters(), self.max_grad_norm)
self.optimizer.step()

# Scheduler step
if self.lr_scheduler is not None:
if self.lr_scheduler:
self.lr_scheduler.step()

# Zero gradients
self.optimizer.zero_grad()

avg_loss = total_loss / len(train_loader)
epoch_pbar.set_postfix(
{
"Epoch": epoch,
"Avg Loss": f"{avg_loss:.4f}",
"FHE Mode": fhe,
}
)
epoch_bar.set_postfix({"loss": f"{loss.item():.4f}"})

avg_epoch_loss = total_loss / len(train_loader)
self.logger.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}, FHE Mode: {fhe}")

print(f"Training completed. Final Avg Loss: {avg_loss:.4f}, FHE Mode: {fhe}")
# Save checkpoint after each epoch
self.save_checkpoint(epoch, global_step)

self.logger.info(f"Training completed. Final Avg Loss: {avg_epoch_loss:.4f}, FHE Mode: {fhe}")

def get_training_losses(self):
"""Return all recorded training losses.
Returns:
List[float]: All recorded training losses.
"""
return self.training_losses

def save_and_clear_private_info(self, path):
"""Save the model and remove private information.
@@ -376,7 +467,7 @@ def save_and_clear_private_info(self, path):
path (str): The path to save the model.
"""
self.hybrid_model.save_and_clear_private_info(path)

self.logger.info(f"Model saved at {path}")

def get_remote_names(model: nn.Module, include_embedding_layers: bool = False) -> List[str]:
"""Get names of modules to be executed remotely.
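For readers of this commit, a minimal usage sketch of the extended LoraTrainer API is given below. It is illustrative only: `model`, `train_loader`, `eval_loader`, and `calibration_inputset` are assumed to be defined elsewhere, and the metric function, optimizer, and loss choices are placeholders; only LoraTrainer, its new constructor arguments, compile(), train(), and the checkpoint_epoch_{N}.pth naming come from the diff above.

    # Hypothetical usage sketch of the LoraTrainer API added in this commit.
    # `model`, `train_loader`, `eval_loader` and `calibration_inputset` are assumed
    # to exist; the metric function and hyper-parameters are illustrative.
    import torch
    from concrete.ml.torch.lora import LoraTrainer

    def eval_accuracy(inference_model, eval_loader):
        # Shape expected by `eval_metric_fn`: (model, eval_loader) -> dict of metrics.
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in eval_loader:  # assumes (inputs, labels) batches
                logits = inference_model(inputs)
                correct += (logits.argmax(dim=-1) == labels).sum().item()
                total += labels.numel()
        return {"accuracy": correct / total}

    trainer = LoraTrainer(
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
        loss_fn=torch.nn.CrossEntropyLoss(),
        eval_loader=eval_loader,
        eval_metric_fn=eval_accuracy,
        logging_steps=100,
        eval_steps=500,
        train_log_path="training.log",
        checkpoint_dir="checkpoints",
    )
    trainer.compile(calibration_inputset, n_bits=8)

    # First run: a checkpoint is written after every epoch
    # (e.g. checkpoints/checkpoint_epoch_2.pth after epoch 2).
    trainer.train(train_loader, num_epochs=2, fhe="simulate")

    # Later run: resume from that checkpoint; training continues at epoch 3.
    trainer.train(
        train_loader,
        num_epochs=4,
        fhe="simulate",
        resume_from_checkpoint="checkpoints/checkpoint_epoch_2.pth",
    )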
24 changes: 24 additions & 0 deletions training_log.txt
@@ -0,0 +1,24 @@
2024-12-11 15:29:49,112 - INFO - === Starting new training session ===
2024-12-11 15:31:25,592 - INFO - Compilation complete.
2024-12-11 15:31:49,795 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 15:31:49,796 - INFO - Running evaluation at step 1...
2024-12-11 15:33:05,740 - INFO - === Starting new training session ===
2024-12-11 15:35:09,497 - INFO - Compilation complete.
2024-12-11 15:36:14,512 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 15:36:14,513 - INFO - Running evaluation at step 1...
2024-12-11 15:38:27,575 - INFO - === Starting new training session ===
2024-12-11 15:40:02,969 - INFO - Compilation complete.
2024-12-11 15:40:27,877 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 15:40:27,878 - INFO - Running evaluation at step 1...
2024-12-11 15:45:36,720 - INFO - === Starting new training session ===
2024-12-11 15:47:08,304 - INFO - Compilation complete.
2024-12-11 15:47:30,295 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 15:47:30,296 - INFO - Running evaluation at step 1...
2024-12-11 15:48:58,999 - INFO - === Starting new training session ===
2024-12-11 15:50:35,269 - INFO - Compilation complete.
2024-12-11 15:50:57,426 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 15:50:57,427 - INFO - Running evaluation at step 1...
2024-12-11 16:18:49,085 - INFO - === Starting new training session ===
2024-12-11 16:20:22,945 - INFO - Compilation complete.
2024-12-11 16:20:49,231 - INFO - Step 1: loss=1.6456, avg_loss=1.6456
2024-12-11 16:20:49,232 - INFO - Running evaluation at step 1...