chore: fix gpt2 custom training + review
- fix wrong unpacking of inputs in LoraTraining + add check
- add optimizer step in gpt2
- typo in llama notebook
- update version in requirements
jfrery committed Dec 19, 2024
1 parent: 8014ec5 · commit: abfb76a
Showing 6 changed files with 106 additions and 50 deletions.
38 changes: 28 additions & 10 deletions src/concrete/ml/torch/lora.py
@@ -195,17 +195,29 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, None]]:
ValueError: If the model does not return a loss and no loss function is provided.
"""
assert (
len(inputs) >= 2
len(inputs) >= 2 and len(inputs) <= 3
), "Expected at least two inputs in the tuple: inputs (x) and targets (y)"

# FIXME:
# Remove when hybrid model supports multiple inputs modules
# Unpack model inputs and labels
*model_inputs, y = inputs
# Unpack depending on how many inputs we have
if len(inputs) == 2:
input_ids, labels = inputs
attention_mask = None
else:
input_ids, labels, attention_mask = inputs

# Validate attention mask
assert torch.all(
torch.logical_or(attention_mask == 0, attention_mask == 1)
), "Invalid attention mask provided. Attention mask should only contain 0s and 1s."

if self.loss_fn is None:
# Pass inputs and labels to the model
outputs = self.inference_model(*model_inputs, labels=y)
if attention_mask is not None:
outputs = self.inference_model(
input_ids, labels=labels, attention_mask=attention_mask
)
else:
outputs = self.inference_model(input_ids, labels=labels)

# Check if outputs is a dict and retrieve the loss
if isinstance(outputs, dict):
@@ -219,10 +231,16 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, None]]:
)
else:
# Forward pass without labels; compute loss manually
outputs = self.inference_model(*model_inputs)
if isinstance(outputs, dict) and "logits" in outputs:
outputs = outputs["logits"]
loss = self.loss_fn(outputs, y)
if attention_mask is not None:
logits = self.inference_model(input_ids, attention_mask=attention_mask)
else:
logits = self.inference_model(input_ids)

# If logits is a dict with 'logits' key, extract it
if isinstance(logits, dict) and "logits" in logits:
logits = logits["logits"]

loss = self.loss_fn(logits, labels)

# Scale the loss for gradient accumulation
scaled_loss = loss / self.loss_scaling_factor
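
For reference, the reworked forward now expects the labels as the second element of the tuple and an optional attention mask as the third. A minimal, self-contained sketch of the new calling convention (a toy module and random tensors mirroring the tests below, not a real LoRA setup):

    import torch
    from torch import nn
    from concrete.ml.torch.lora import LoraTraining

    class TinyLoraModel(nn.Module):
        """Toy stand-in for a LoRA-adapted model (mirrors the test models below)."""

        def __init__(self):
            super().__init__()
            self.lora_a = nn.Parameter(torch.randn(10, 10))
            self.linear = nn.Linear(10, 10)

        def forward(self, x, attention_mask=None):
            return {"logits": self.linear(x)}

    lora_training = LoraTraining(TinyLoraModel(), loss_fn=nn.MSELoss())

    input_ids = torch.randn(5, 10)        # placeholder "input_ids"
    labels = torch.randn(5, 10)
    attention_mask = torch.ones(5, 10)    # the new check expects only 0s and 1s

    # New ordering: (input_ids, labels) or (input_ids, labels, attention_mask)
    loss, _ = lora_training((input_ids, labels, attention_mask))
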
58 changes: 50 additions & 8 deletions tests/torch/test_lora.py
@@ -431,31 +431,73 @@ def test_forward_backward_module():
assert grad_input.shape == x.shape


def test_lora_training_forward_with_loss_fn_and_attention_mask():
"""Test LoraTraining forward using a custom loss_fn and attention_mask."""

class ModelWithAttention(nn.Module):
"""Model that supports attention_mask for testing."""

def __init__(self):
super().__init__()
self.lora_a = nn.Parameter(torch.randn(10, 10))
self.linear = nn.Linear(10, 10)

def forward(self, x, attention_mask=None):
"""Forward pass."""
if attention_mask is not None:
return {"logits": self.linear(x + attention_mask)}
return {"logits": self.linear(x)}

# Define a simple loss function
def simple_loss_fn(logits, labels):
return nn.MSELoss()(logits, labels)

model = ModelWithAttention()

# Instantiate LoraTraining with a custom loss_fn
lora_training = LoraTraining(model, loss_fn=simple_loss_fn)

x = torch.randn(5, 10)
y = torch.randn(5, 10)
attention_mask = torch.randn(5, 10)

# Call forward with (input_ids, labels, attention_mask)
loss, _ = lora_training((x, y, attention_mask))
assert isinstance(loss, torch.Tensor)


def test_lora_training_forward_with_additional_inputs():
"""Test LoraTraining forward with additional inputs."""

class ModelWithAdditionalInputs(nn.Module):
"""Model with additional inputs for testing."""
class ModelWithAttention(nn.Module):
"""Model with attention input for testing."""

def __init__(self):
super().__init__()
self.lora_a = nn.Parameter(torch.randn(10, 10))
self.linear = nn.Linear(10, 10)

def forward(self, x, extra_input, labels=None):
"""Forward pass with additional inputs."""
logits = self.linear(x + extra_input)
def forward(self, x, attention_mask=None, labels=None):
"""Forward pass with an attention mask."""
# Just treat the attention_mask as an extra input
# and add it to x before passing through linear.
if attention_mask is not None:
logits = self.linear(x + attention_mask)
else:
logits = self.linear(x)

if labels is not None:
loss = nn.functional.mse_loss(logits, labels)
return {"loss": loss}
return {"logits": logits}

model = ModelWithAdditionalInputs()
model = ModelWithAttention()
lora_training = LoraTraining(model)
x = torch.randn(5, 10)
y = torch.randn(5, 10)
extra_input = torch.randn(5, 10)
loss, _ = lora_training((x, extra_input, y))
attention_mask = torch.randn(5, 10)

loss, _ = lora_training((x, y, attention_mask))
assert isinstance(loss, torch.Tensor)


36 changes: 15 additions & 21 deletions use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb
@@ -309,7 +309,10 @@
"num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n",
"max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)\n",
"\n",
"trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)"
"trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)\n",
"\n",
"lr_scheduler = trainer.lr_scheduler\n",
"optimizer = trainer.optimizer"
]
},
{
@@ -381,27 +384,11 @@
"\n",
" # Training loop\n",
" peft_model.train()\n",
" lora_training.run_optimizer = True\n",
" total_epochs = int(training_args.num_train_epochs)\n",
" epoch_pbar = tqdm(total=total_epochs, desc=\"Training Progress\", position=0)\n",
"\n",
" # Initialize optimizer and scheduler here instead\n",
" optimizer = torch.optim.AdamW(\n",
" hybrid_model.model.parameters(),\n",
" lr=training_args.learning_rate,\n",
" weight_decay=training_args.weight_decay,\n",
" )\n",
"\n",
" num_training_steps = total_epochs * len(train_dataloader)\n",
" lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
" optimizer,\n",
" start_factor=1.0,\n",
" end_factor=0.0,\n",
" total_iters=num_training_steps,\n",
" )\n",
"\n",
" total_batched_samples = 0\n",
" epoch_losses = [] # List to store the loss for each epoch\n",
" epoch_losses = []\n",
"\n",
" # Generate text before the first epoch\n",
" print(\"Generating text before the first epoch:\\n\")\n",
@@ -415,17 +402,24 @@
" grad_norms = []\n",
"\n",
" for _, batch in enumerate(train_dataloader):\n",
"\n",
" total_batched_samples += 1\n",
"\n",
" batch = {k: v.to(device) for k, v in batch.items()}\n",
"\n",
" # Zero the gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # Forward pass\n",
" loss, grad_norm = hybrid_model(\n",
" (batch[\"input_ids\"], batch[\"labels\"], batch[\"attention_mask\"]), fhe=fhe\n",
" )\n",
"\n",
" total_loss += loss.item()\n",
" # Optimizer step\n",
" optimizer.step()\n",
"\n",
" # Learning rate scheduler step\n",
" lr_scheduler.step()\n",
"\n",
" total_loss += loss.item()\n",
" if grad_norm is not None:\n",
" grad_norms.append(grad_norm)\n",
"\n",
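
Putting the notebook hunks above together, the corrected training loop now follows the standard order of operations. A condensed sketch (variable names as in the notebook, logging and progress bars omitted, and assuming `trainer`, `max_steps`, `hybrid_model`, `train_dataloader`, `device` and `fhe` are set up as in the earlier cells):

    trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)
    lr_scheduler = trainer.lr_scheduler
    optimizer = trainer.optimizer

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        # Zero gradients, run the forward pass (the LoRA/hybrid wrapper returns the
        # loss and gradient norm, so no explicit loss.backward() appears here), then
        # step the optimizer and the learning-rate scheduler.
        optimizer.zero_grad()
        loss, grad_norm = hybrid_model(
            (batch["input_ids"], batch["labels"], batch["attention_mask"]), fhe=fhe
        )
        optimizer.step()
        lr_scheduler.step()
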
2 changes: 1 addition & 1 deletion use_case_examples/lora_finetuning/LLamaFineTuning.ipynb
@@ -326,7 +326,7 @@
"outputs": [],
"source": [
"# Save the fine-tuned model\n",
"save_path = Path(\"deployment/gpt2_lora_finetuned\")\n",
"save_path = Path(\"deployment/llama_lora_finetuned\")\n",
"if save_path.is_dir() and any(save_path.iterdir()):\n",
" shutil.rmtree(save_path)\n",
"lora_trainer.save_and_clear_private_info(save_path)\n",
12 changes: 7 additions & 5 deletions use_case_examples/lora_finetuning/README.md
@@ -6,13 +6,12 @@ This use case demonstrates how to fine-tune language models (GPT-2 and LLaMA) us

Fine-tuning large language models typically requires access to sensitive data, which can raise privacy concerns. By leveraging FHE, we can perform computations on encrypted foundation model weights, ensuring that the data remain private throughout the training process. The LoRA weights are kept in clear on the client side.


## Key Features

- **LoRA Fine-Tuning**: Fine-tune language models by adapting low-rank weights
- **Hybrid Model**: Combine encrypted foundation model weights with clear LoRA weights for optimal performance
- **Low Memory Requirements**: Minimal client-side memory needed for LoRA weights
- **Multiple Approaches**:
- **Multiple Approaches**:
- Custom training implementation for GPT-2
- Simplified API-based approach for LLaMA using the `LoraTrainer`

@@ -34,19 +33,22 @@ pip install -r requirements.txt

The repository includes two example notebooks:

1. **GPT2FineTuneHybrid.ipynb**:
1. **GPT2FineTuneHybrid.ipynb**:

- Uses a custom training implementation
- Fine-tunes GPT-2 on a small Q&A data-set about FHE
- Shows low-level control over the training process

2. **LLamaFineTuning.ipynb**:
1. **LLamaFineTuning.ipynb**:

- Uses Concrete ML's `LoraTrainer` API for simplified implementation
- Fine-tunes LLaMA on Concrete ML code examples
- Shows how to use the high-level API for encrypted fine-tuning

### Prepare the data-set

Each notebook includes its own data-set:

- GPT-2 uses a small Q&A data-set about FHE in `data_finetune/what_is_fhe.txt`
- LLaMA uses Concrete ML code examples in `data_finetune/data-set.jsonl`

@@ -67,8 +69,8 @@ In a deployment or production scenario, the model can be fine-tuned as follows:

## Results


### GPT-2 Results

After fine-tuning, the model's weights are distributed between the client and server as follows:

- Total weights removed from the server: 68.24%
10 changes: 5 additions & 5 deletions use_case_examples/lora_finetuning/requirements.txt
@@ -1,9 +1,9 @@
-e ../../.
transformers==4.41.2
peft==0.11.1
transformers==4.46.3
peft==0.12.0
Jinja2==3.1.4
matplotlib==3.7.5
datasets==3.0.1
datasets==3.1.0
accelerate==1.2.0
jupyter==1.0.0
tqdm==4.66.5
jupyter==1.1.1
tqdm==4.67.1
