feat: add logs + evaluation llama lora
jfrery committed Dec 17, 2024
1 parent 8014ec5 commit bddfa13
Showing 2 changed files with 454 additions and 29 deletions.
219 changes: 190 additions & 29 deletions src/concrete/ml/torch/lora.py
@@ -1,6 +1,8 @@
"""This module contains classes for LoRA (Low-Rank Adaptation) FHE training and custom layers."""

from typing import List, Tuple, Union
import logging
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
@@ -25,6 +27,36 @@
# pylint: disable=arguments-differ


def setup_logger(log_file: str, level=logging.INFO):
"""Set up a logger that logs to both console and a file.
Args:
log_file (str): The path to the log file.
level (int): The logging level.
Returns:
logging.Logger: The logger instance.
"""
logger = logging.getLogger(__name__)
logger.setLevel(level)
logger.handlers.clear()

# Console handler
ch = logging.StreamHandler()
ch.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)

# File handler
fh = logging.FileHandler(log_file, mode="a", encoding="utf-8")
fh.setLevel(level)
fh.setFormatter(formatter)

logger.addHandler(ch)
logger.addHandler(fh)
return logger
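
# Illustrative usage sketch (not part of this diff); the file name is a placeholder.
logger = setup_logger("training.log")
logger.info("Loss at step %d: %f", 10, 0.42)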


def try_dict(obj):
"""Try to convert the object to a dict.
@@ -236,11 +268,15 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, None]]:
return loss.detach(), None


# pylint: disable=too-many-instance-attributes
class LoraTrainer:
"""Trainer class for LoRA fine-tuning with FHE support.
This class handles the training loop, optimizer, scheduler,
and integrates with the hybrid model.
This class handles:
- Training loop
- Periodic logging and evaluation
- Loss tracking
- Integration with hybrid FHE model
Args:
model (nn.Module): The base model with LoRA layers to be fine-tuned.
@@ -251,8 +287,14 @@ class LoraTrainer:
n_layers_to_skip_for_backprop (int): Number of initial linear layers to keep as standard
layers. Since the first layer doesn't need backpropagation (no previous layer to
update), we typically skip 1 layer. Defaults to 1.
eval_loader (DataLoader, optional): DataLoader for evaluation data.
eval_metric_fn (callable, optional): Function(model, eval_loader) -> dict of metrics.
logging_steps (int, optional): Log loss every N training steps. Defaults to 1.
eval_steps (int, optional): Evaluate on eval set every N training steps. Defaults to 10.
train_log_path (str, optional): Path to a log file for training. Defaults to "training.log".
"""

# pylint: disable=too-many-arguments
def __init__(
self,
model,
@@ -261,14 +303,30 @@ def __init__(
lr_scheduler=None,
training_args=None,
n_layers_to_skip_for_backprop=1,
eval_loader: Optional[DataLoader] = None,
eval_metric_fn: Optional[Callable] = None,
logging_steps: int = 1,
eval_steps: int = 10,
train_log_path: str = "training.log",
checkpoint_dir: str = "checkpoints",
):
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.training_args = training_args or {}
self.gradient_accumulation_steps = self.training_args.get("gradient_accumulation_steps", 1)
self.max_grad_norm = self.training_args.get("max_grad_norm", None)

# Create the LoraTraining module
self.eval_loader = eval_loader
self.eval_metric_fn = eval_metric_fn
self.logging_steps = logging_steps
self.eval_steps = eval_steps
self.training_losses: List[float] = []

# Set up logging
self.logger = setup_logger(train_log_path)
self.logger.info("=== Starting new training session ===")

# Create the LoraTraining module and hybrid model
self.lora_training_module = LoraTraining(
model, n_layers_to_skip_for_backprop=n_layers_to_skip_for_backprop, loss_fn=loss_fn
)
@@ -281,6 +339,9 @@ def __init__(
self.lora_training_module, module_names=self.remote_names
)

self.checkpoint_dir = checkpoint_dir
Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

def compile(self, inputset, n_bits=8):
"""Compile the hybrid model with the given input set.
@@ -291,19 +352,83 @@ def compile(self, inputset, n_bits=8):
self.lora_training_module.toggle_calibrate(enable=True)
self.hybrid_model.compile_model(inputset, n_bits=n_bits)
self.lora_training_module.toggle_calibrate(enable=False)
self.logger.info("Compilation complete.")

def _evaluate(self, step: int):
if self.eval_loader and self.eval_metric_fn:
self.logger.info("Running evaluation at step %d...", step)
self.lora_training_module.inference_model.eval()
metrics: Dict[str, float] = self.eval_metric_fn(
self.lora_training_module.inference_model, self.eval_loader
)
metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
self.logger.info("[Evaluation at step %d] %s", step, metrics_str)
self.lora_training_module.inference_model.train() # back to train mode
else:
self.logger.info("No evaluation data or metric function provided.")

def save_checkpoint(self, epoch: int, global_step: int):
"""Save a training checkpoint.
Args:
epoch (int): The current epoch number.
global_step (int): The current global step number.
"""
checkpoint_path = Path(self.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
save_dict = {
"model_state_dict": self.lora_training_module.inference_model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": (
self.lr_scheduler.state_dict() if self.lr_scheduler else None
),
"training_losses": self.training_losses,
"global_step": global_step,
"epoch": epoch,
}
torch.save(save_dict, checkpoint_path)
self.logger.info("Checkpoint saved at %s", checkpoint_path)

def load_checkpoint(self, checkpoint_path: str):
"""Load a training checkpoint and restore model, optimizer, and lr_scheduler.
Args:
checkpoint_path (str): Path to the checkpoint file.
Returns:
Tuple[int, int]: The epoch and global step of the checkpoint.
Raises:
FileNotFoundError: If the checkpoint file is not found.
"""
if not Path(checkpoint_path).is_file():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.lora_training_module.inference_model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.lr_scheduler and checkpoint["lr_scheduler_state_dict"] is not None:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
self.training_losses = checkpoint.get("training_losses", [])
global_step = checkpoint.get("global_step", 0)
epoch = checkpoint.get("epoch", 0)
self.logger.info(
"Checkpoint loaded from %s (Epoch %d, Step %d)", checkpoint_path, epoch, global_step
)
return epoch, global_step
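
Taken together with save_checkpoint, resuming a run goes through train(); a sketch (not part of this diff, with placeholder paths and an already-constructed trainer) could be:

# Illustrative only: resume training from the checkpoint written after epoch 3.
trainer.train(
    train_loader,
    num_epochs=10,
    fhe="simulate",
    resume_from_checkpoint="checkpoints/checkpoint_epoch_3.pth",
)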

def train(
self,
train_loader: DataLoader,
num_epochs: int = 10,
fhe: str = "simulate",
resume_from_checkpoint: Optional[str] = None,
):
"""Train the model using the hybrid FHE model.
Args:
train_loader (DataLoader): DataLoader for training data.
num_epochs (int): Number of epochs to train.
fhe (str): FHE mode ('disable', 'simulate', 'execute' or 'torch').
resume_from_checkpoint (str, optional): Path to a checkpoint to resume training from.
"""
device = torch.device("cpu")
self.lora_training_module.to(device)
@@ -312,20 +437,29 @@ def train(
# Set the loss scaling factor for gradient accumulation
self.lora_training_module.set_loss_scaling_factor(self.gradient_accumulation_steps)

epoch_pbar = tqdm(range(1, num_epochs + 1), desc="Training", unit="epoch")
start_epoch = 1
global_step = 0

# Optionally resume training
if resume_from_checkpoint:
start_epoch, global_step = self.load_checkpoint(resume_from_checkpoint)
start_epoch += 1 # continue from next epoch

for epoch in epoch_pbar:
for epoch in range(start_epoch, num_epochs + 1):
total_loss = 0.0
self.optimizer.zero_grad() # Zero gradients at the start of the epoch
self.optimizer.zero_grad()

for step, batch in enumerate(train_loader):
epoch_bar = tqdm(
enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}", leave=False
)
for step, batch in epoch_bar:
global_step += 1

# Convert the batch to a tuple of inputs on the device.
# Move batch to device
if batch_dict := try_dict(batch):
batch = batch_dict
# Convert dict to tuple of values and move them to the device
batch = tuple(
v.to(device) if isinstance(v, torch.Tensor) else v for v in batch.values()
v.to(device) if isinstance(v, torch.Tensor) else v
for v in batch_dict.values()
)
elif isinstance(batch, (tuple, list)):
# Move tuple/list elements to the device
@@ -344,36 +478,62 @@ def train(

# Accumulate loss for logging
total_loss += loss.item()
self.training_losses.append(loss.item())

# Logging
if global_step % self.logging_steps == 0:
avg_loss = total_loss / global_step
self.logger.info(
"Step %d: loss=%f, avg_loss=%f",
global_step,
loss.item(),
avg_loss,
)

# Evaluation
if global_step % self.eval_steps == 0:
self._evaluate(global_step)

# Update weights after gradient accumulation steps
if (step + 1) % self.gradient_accumulation_steps == 0 or (step + 1) == len(
train_loader
# Gradient accumulation steps
if ((step + 1) % self.gradient_accumulation_steps == 0) or (
step + 1 == len(train_loader)
):
if self.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
self.lora_training_module.parameters(), self.max_grad_norm
)

# Optimizer step
self.optimizer.step()

# Scheduler step
if self.lr_scheduler is not None:
if self.lr_scheduler:
self.lr_scheduler.step()

# Zero gradients
self.optimizer.zero_grad()

avg_loss = total_loss / len(train_loader)
epoch_pbar.set_postfix(
{
"Epoch": epoch,
"Avg Loss": f"{avg_loss:.4f}",
"FHE Mode": fhe,
}
epoch_bar.set_postfix({"loss": f"{loss.item():.4f}"})

avg_epoch_loss = total_loss / len(train_loader)
self.logger.info(
"Epoch %d completed. Avg Loss: %f, FHE Mode: %s",
epoch,
avg_epoch_loss,
fhe,
)

print(f"Training completed. Final Avg Loss: {avg_loss:.4f}, FHE Mode: {fhe}")
# Save checkpoint after each epoch
self.save_checkpoint(epoch, global_step)

self.logger.info(
"Training completed. Final Avg Loss: %f, FHE Mode: %s",
avg_epoch_loss,
fhe,
)

def get_training_losses(self):
"""Return all recorded training losses.
Returns:
List[float]: All recorded training losses.
"""
return self.training_losses
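
A sketch of how the recorded losses might be inspected after training (illustrative, not part of this diff; matplotlib and an existing trainer are assumptions):

# Illustrative only: plot the per-step training losses recorded by the trainer.
import matplotlib.pyplot as plt

losses = trainer.get_training_losses()
plt.plot(losses)
plt.xlabel("Training step")
plt.ylabel("Loss")
plt.title("LoRA fine-tuning loss")
plt.show()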

def save_and_clear_private_info(self, path):
"""Save the model and remove private information.
@@ -382,6 +542,7 @@ def save_and_clear_private_info(self, path):
path (str): The path to save the model.
"""
self.hybrid_model.save_and_clear_private_info(path)
self.logger.info("Model saved at %s", path)


def get_remote_names(model: nn.Module, include_embedding_layers: bool = False) -> List[str]:
