Skip to content

Commit

Permalink
feat: add logs + evaluation llama lora
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Jan 9, 2025
1 parent 4ff6644 commit d2a401d
Show file tree
Hide file tree
Showing 2 changed files with 453 additions and 24 deletions.
214 changes: 190 additions & 24 deletions src/concrete/ml/torch/lora.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""This module contains classes for LoRA (Low-Rank Adaptation) FHE training and custom layers."""

import logging
from collections import UserDict
from typing import Any, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
Expand All @@ -23,6 +25,40 @@
LINEAR_LAYERS = LINEAR_LAYERS + (TransformerConv1D,)


# pylint: disable=abstract-method
# pylint: disable=arguments-differ


def setup_logger(log_file: str, level=logging.INFO):
"""Set up a logger that logs to both console and a file.
Args:
log_file (str): The path to the log file.
level (int): The logging level.
Returns:
logging.Logger: The logger instance.
"""
logger = logging.getLogger(__name__)
logger.setLevel(level)
logger.handlers.clear()

# Console handler
ch = logging.StreamHandler()
ch.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)

# File handler
fh = logging.FileHandler(log_file, mode="a", encoding="utf-8")
fh.setLevel(level)
fh.setFormatter(formatter)

logger.addHandler(ch)
logger.addHandler(fh)
return logger


# pylint: disable=protected-access
def grad_to(param, device: str) -> None:
"""Move parameter gradient to device.
Expand Down Expand Up @@ -295,11 +331,15 @@ def process_inputs(self, inputs: Any) -> Tuple[Optional[torch.Tensor], Optional[
return attention_mask, labels


# pylint: disable=too-many-instance-attributes
class LoraTrainer:
"""Trainer class for LoRA fine-tuning with FHE support.
This class handles the training loop, optimizer, scheduler,
and integrates with the hybrid model.
This class handles:
- Training loop
- Periodic logging and evaluation
- Loss tracking
- Integration with hybrid FHE model
Args:
model (nn.Module): The base model with LoRA layers to be fine-tuned.
Expand All @@ -310,8 +350,14 @@ class LoraTrainer:
n_layers_to_skip_for_backprop (int): Number of initial linear layers to keep as standard
layers. Since the first layer doesn't need backpropagation (no previous layer to
update), we typically skip 1 layer. Defaults to 1.
eval_loader (DataLoader, optional): DataLoader for evaluation data.
eval_metric_fn (callable, optional): Function(model, eval_loader) -> dict of metrics.
logging_steps (int, optional): Log loss every N training steps. Defaults to 1.
eval_steps (int, optional): Evaluate on eval set every N training steps. Defaults to 10.
train_log_path (str, optional): Path to a log file for training. Defaults to "training.log".
"""

# pylint: disable=too-many-arguments
def __init__(
self,
model,
Expand All @@ -320,13 +366,29 @@ def __init__(
lr_scheduler=None,
training_args=None,
n_layers_to_skip_for_backprop=1,
eval_loader: Optional[DataLoader] = None,
eval_metric_fn: Optional[Callable] = None,
logging_steps: int = 1,
eval_steps: int = 10,
train_log_path: str = "training.log",
checkpoint_dir: str = "checkpoints",
):
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.training_args = training_args or {}
self.gradient_accumulation_steps = self.training_args.get("gradient_accumulation_steps", 1)
self.max_grad_norm = self.training_args.get("max_grad_norm", None)

self.eval_loader = eval_loader
self.eval_metric_fn = eval_metric_fn
self.logging_steps = logging_steps
self.eval_steps = eval_steps
self.training_losses: List[float] = []

# Set up logging
self.logger = setup_logger(train_log_path)
self.logger.info("=== Starting new training session ===")

assert_true(
training_args is None
or "use_cpu" not in training_args
Expand All @@ -351,6 +413,9 @@ def __init__(
self.lora_training_module, module_names=self.remote_names
)

self.checkpoint_dir = checkpoint_dir
Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

def compile(self, inputset, n_bits=8):
"""Compile the hybrid model with the given input set.
Expand All @@ -361,13 +426,76 @@ def compile(self, inputset, n_bits=8):
self.lora_training_module.toggle_calibrate(enable=True)
self.hybrid_model.compile_model(inputset, n_bits=n_bits)
self.lora_training_module.toggle_calibrate(enable=False)
self.logger.info("Compilation complete.")

def _evaluate(self, step: int):
if self.eval_loader and self.eval_metric_fn:
self.logger.info("Running evaluation at step %d...", step)
self.lora_training_module.inference_model.eval()
metrics: Dict[str, float] = self.eval_metric_fn(
self.lora_training_module.inference_model, self.eval_loader
)
metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
self.logger.info("[Evaluation at step %d] %s", step, metrics_str)
self.lora_training_module.inference_model.train() # back to train mode
else:
self.logger.info("No evaluation data or metric function provided.")

def save_checkpoint(self, epoch: int, global_step: int):
"""Save a training checkpoint.
Args:
epoch (int): The current epoch number.
global_step (int): The current global step number.
"""
checkpoint_path = Path(self.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pth"
save_dict = {
"model_state_dict": self.lora_training_module.inference_model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": (
self.lr_scheduler.state_dict() if self.lr_scheduler else None
),
"training_losses": self.training_losses,
"global_step": global_step,
"epoch": epoch,
}
torch.save(save_dict, checkpoint_path)
self.logger.info("Checkpoint saved at %s", checkpoint_path)

def load_checkpoint(self, checkpoint_path: str):
"""Load a training checkpoint and restore model, optimizer, and lr_scheduler.
Args:
checkpoint_path (str): Path to the checkpoint file.
Returns:
Tuple[int, int]: The epoch and global step of the checkpoint.
Raises:
FileNotFoundError: If the checkpoint file is not found.
"""
if not Path(checkpoint_path).is_file():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.lora_training_module.inference_model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.lr_scheduler and checkpoint["lr_scheduler_state_dict"] is not None:
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
self.training_losses = checkpoint.get("training_losses", [])
global_step = checkpoint.get("global_step", 0)
epoch = checkpoint.get("epoch", 0)
self.logger.info(
"Checkpoint loaded from %s (Epoch %d, Step %d)", checkpoint_path, epoch, global_step
)
return epoch, global_step

def train(
self,
train_loader: DataLoader,
num_epochs: int = 10,
fhe: str = "simulate",
device: str = "cpu",
resume_from_checkpoint: Optional[str] = None,
):
"""Train the model using the hybrid FHE model.
Expand All @@ -377,6 +505,7 @@ def train(
fhe (str): FHE mode ('disable', 'simulate', 'execute' or 'torch').
device (str): A device string that is compatible with PyTorch, used for
client-side computations.
resume_from_checkpoint (str, optional): Path to a checkpoint to resume training from.
"""

self.hybrid_model.model.to(device)
Expand All @@ -387,13 +516,23 @@ def train(
# Set the loss scaling factor for gradient accumulation
self.lora_training_module.set_loss_scaling_factor(self.gradient_accumulation_steps)

epoch_pbar = tqdm(range(1, num_epochs + 1), desc="Training", unit="epoch")
start_epoch = 1
global_step = 0

# Optionally resume training
if resume_from_checkpoint:
start_epoch, global_step = self.load_checkpoint(resume_from_checkpoint)
start_epoch += 1 # continue from next epoch

for epoch in epoch_pbar:
for epoch in range(start_epoch, num_epochs + 1):
total_loss = 0.0
self.optimizer.zero_grad() # Zero gradients at the start of the epoch
self.optimizer.zero_grad()

for step, batch in enumerate(train_loader):
epoch_bar = tqdm(
enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}", leave=False
)
for step, batch in epoch_bar:
global_step += 1
if isinstance(batch, (UserDict, dict)):
# Convert dict to tuple of values and move them to the device
batch = {k: v.to(device) for k, v in batch.items()}
Expand All @@ -412,36 +551,62 @@ def train(

# Accumulate loss for logging
total_loss += loss.item()
self.training_losses.append(loss.item())

# Logging
if global_step % self.logging_steps == 0:
avg_loss = total_loss / global_step
self.logger.info(
"Step %d: loss=%f, avg_loss=%f",
global_step,
loss.item(),
avg_loss,
)

# Update weights after gradient accumulation steps
if (step + 1) % self.gradient_accumulation_steps == 0 or (step + 1) == len(
train_loader
# Evaluation
if global_step % self.eval_steps == 0:
self._evaluate(global_step)

# Gradient accumulation steps
if ((step + 1) % self.gradient_accumulation_steps == 0) or (
step + 1 == len(train_loader)
):
if self.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
self.lora_training_module.parameters(), self.max_grad_norm
)

# Optimizer step
self.optimizer.step()

# Scheduler step
if self.lr_scheduler is not None:
if self.lr_scheduler:
self.lr_scheduler.step()

# Zero gradients
self.optimizer.zero_grad()

avg_loss = total_loss / len(train_loader)
epoch_pbar.set_postfix(
{
"Epoch": epoch,
"Avg Loss": f"{avg_loss:.4f}",
"FHE Mode": fhe,
}
epoch_bar.set_postfix({"loss": f"{loss.item():.4f}"})

avg_epoch_loss = total_loss / len(train_loader)
self.logger.info(
"Epoch %d completed. Avg Loss: %f, FHE Mode: %s",
epoch,
avg_epoch_loss,
fhe,
)

print(f"Training completed. Final Avg Loss: {avg_loss:.4f}, FHE Mode: {fhe}")
# Save checkpoint after each epoch
self.save_checkpoint(epoch, global_step)

self.logger.info(
"Training completed. Final Avg Loss: %f, FHE Mode: %s",
avg_epoch_loss,
fhe,
)

def get_training_losses(self):
"""Return all recorded training losses.
Returns:
List[float]: All recorded training losses.
"""
return self.training_losses

def save_and_clear_private_info(self, path):
"""Save the model and remove private information.
Expand All @@ -450,6 +615,7 @@ def save_and_clear_private_info(self, path):
path (str): The path to save the model.
"""
self.hybrid_model.save_and_clear_private_info(path)
self.logger.info("Model saved at %s", path)


def get_remote_names(model: nn.Module, include_embedding_layers: bool = False) -> List[str]:
Expand Down
Loading

0 comments on commit d2a401d

Please sign in to comment.