diff --git a/src/concrete/ml/torch/lora.py b/src/concrete/ml/torch/lora.py
index 30eec4ef0..d80cdab0a 100644
--- a/src/concrete/ml/torch/lora.py
+++ b/src/concrete/ml/torch/lora.py
@@ -195,17 +195,29 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, Non
             ValueError: If the model does not return a loss and no loss function is provided.
         """
         assert (
-            len(inputs) >= 2
-        ), "Expected at least two inputs in the tuple: inputs (x) and targets (y)"
+            len(inputs) >= 2 and len(inputs) <= 3
+        ), "Expected two or three inputs: input_ids (x), labels (y), and optional attention_mask"
 
-        # FIXME:
-        # Remove when hybrid model supports multiple inputs modules
-
-        # Unpack model inputs and labels
-        *model_inputs, y = inputs
+        # Unpack depending on how many inputs we have
+        if len(inputs) == 2:
+            input_ids, labels = inputs
+            attention_mask = None
+        else:
+            input_ids, labels, attention_mask = inputs
+
+        # Validate the attention mask when one is provided
+        assert attention_mask is None or torch.all(
+            torch.logical_or(attention_mask == 0, attention_mask == 1)
+        ), "Invalid attention mask provided. Attention mask should only contain 0s and 1s."
 
         if self.loss_fn is None:
             # Pass inputs and labels to the model
-            outputs = self.inference_model(*model_inputs, labels=y)
+            if attention_mask is not None:
+                outputs = self.inference_model(
+                    input_ids, labels=labels, attention_mask=attention_mask
+                )
+            else:
+                outputs = self.inference_model(input_ids, labels=labels)
 
             # Check if outputs is a dict and retrieve the loss
             if isinstance(outputs, dict):
@@ -219,10 +231,16 @@ def forward(self, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, Union[Tensor, Non
                 )
         else:
             # Forward pass without labels; compute loss manually
-            outputs = self.inference_model(*model_inputs)
-            if isinstance(outputs, dict) and "logits" in outputs:
-                outputs = outputs["logits"]
-            loss = self.loss_fn(outputs, y)
+            if attention_mask is not None:
+                logits = self.inference_model(input_ids, attention_mask=attention_mask)
+            else:
+                logits = self.inference_model(input_ids)
+
+            # If logits is a dict with 'logits' key, extract it
+            if isinstance(logits, dict) and "logits" in logits:
+                logits = logits["logits"]
+
+            loss = self.loss_fn(logits, labels)
 
         # Scale the loss for gradient accumulation
         scaled_loss = loss / self.loss_scaling_factor
diff --git a/tests/torch/test_lora.py b/tests/torch/test_lora.py
index 4f4da5f23..dfdff971b 100644
--- a/tests/torch/test_lora.py
+++ b/tests/torch/test_lora.py
@@ -431,31 +431,73 @@ def test_forward_backward_module():
     assert grad_input.shape == x.shape
 
 
+def test_lora_training_forward_with_loss_fn_and_attention_mask():
+    """Test LoraTraining forward using a custom loss_fn and attention_mask."""
+
+    class ModelWithAttention(nn.Module):
+        """Model that supports attention_mask for testing."""
+
+        def __init__(self):
+            super().__init__()
+            self.lora_a = nn.Parameter(torch.randn(10, 10))
+            self.linear = nn.Linear(10, 10)
+
+        def forward(self, x, attention_mask=None):
+            """Forward pass."""
+            if attention_mask is not None:
+                return {"logits": self.linear(x + attention_mask)}
+            return {"logits": self.linear(x)}
+
+    # Define a simple loss function
+    def simple_loss_fn(logits, labels):
+        return nn.MSELoss()(logits, labels)
+
+    model = ModelWithAttention()
+
+    # Instantiate LoraTraining with a custom loss_fn
+    lora_training = LoraTraining(model, loss_fn=simple_loss_fn)
+
+    x = torch.randn(5, 10)
+    y = torch.randn(5, 10)
+    attention_mask = torch.randint(0, 2, (5, 10)).float()
+
+    # Call forward with (input_ids, labels, attention_mask)
+    loss, _ = lora_training((x, y, attention_mask))
+    assert isinstance(loss, torch.Tensor)
+
+
 def test_lora_training_forward_with_additional_inputs():
"""Test LoraTraining forward with additional inputs.""" - class ModelWithAdditionalInputs(nn.Module): - """Model with additional inputs for testing.""" + class ModelWithAttention(nn.Module): + """Model with attention input for testing.""" def __init__(self): super().__init__() self.lora_a = nn.Parameter(torch.randn(10, 10)) self.linear = nn.Linear(10, 10) - def forward(self, x, extra_input, labels=None): - """Forward pass with additional inputs.""" - logits = self.linear(x + extra_input) + def forward(self, x, attention_mask=None, labels=None): + """Forward pass with an attention mask.""" + # Just treat the attention_mask as an extra input + # and add it to x before passing through linear. + if attention_mask is not None: + logits = self.linear(x + attention_mask) + else: + logits = self.linear(x) + if labels is not None: loss = nn.functional.mse_loss(logits, labels) return {"loss": loss} return {"logits": logits} - model = ModelWithAdditionalInputs() + model = ModelWithAttention() lora_training = LoraTraining(model) x = torch.randn(5, 10) y = torch.randn(5, 10) - extra_input = torch.randn(5, 10) - loss, _ = lora_training((x, extra_input, y)) + attention_mask = torch.randn(5, 10) + + loss, _ = lora_training((x, y, attention_mask)) assert isinstance(loss, torch.Tensor) diff --git a/use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb b/use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb index c229016a0..51a99b21b 100644 --- a/use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb +++ b/use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb @@ -309,7 +309,10 @@ "num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n", "max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)\n", "\n", - "trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)" + "trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)\n", + "\n", + "lr_scheduler = trainer.lr_scheduler\n", + "optimizer = trainer.optimizer" ] }, { @@ -381,27 +384,11 @@ "\n", " # Training loop\n", " peft_model.train()\n", - " lora_training.run_optimizer = True\n", " total_epochs = int(training_args.num_train_epochs)\n", " epoch_pbar = tqdm(total=total_epochs, desc=\"Training Progress\", position=0)\n", "\n", - " # Initialize optimizer and scheduler here instead\n", - " optimizer = torch.optim.AdamW(\n", - " hybrid_model.model.parameters(),\n", - " lr=training_args.learning_rate,\n", - " weight_decay=training_args.weight_decay,\n", - " )\n", - "\n", - " num_training_steps = total_epochs * len(train_dataloader)\n", - " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", - " optimizer,\n", - " start_factor=1.0,\n", - " end_factor=0.0,\n", - " total_iters=num_training_steps,\n", - " )\n", - "\n", " total_batched_samples = 0\n", - " epoch_losses = [] # List to store the loss for each epoch\n", + " epoch_losses = []\n", "\n", " # Generate text before the first epoch\n", " print(\"Generating text before the first epoch:\\n\")\n", @@ -415,17 +402,24 @@ " grad_norms = []\n", "\n", " for _, batch in enumerate(train_dataloader):\n", - "\n", " total_batched_samples += 1\n", - "\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", "\n", + " # Zero the gradients\n", + " optimizer.zero_grad()\n", + "\n", + " # Forward pass\n", " loss, grad_norm = hybrid_model(\n", " (batch[\"input_ids\"], batch[\"labels\"], batch[\"attention_mask\"]), fhe=fhe\n", " )\n", "\n", - " total_loss += loss.item()\n", + " # Optimizer step\n", + " optimizer.step()\n", "\n", + " # Learning 
rate scheduler step\n", + " lr_scheduler.step()\n", + "\n", + " total_loss += loss.item()\n", " if grad_norm is not None:\n", " grad_norms.append(grad_norm)\n", "\n", diff --git a/use_case_examples/lora_finetuning/LLamaFineTuning.ipynb b/use_case_examples/lora_finetuning/LLamaFineTuning.ipynb index ab29c525f..7ee8a6810 100644 --- a/use_case_examples/lora_finetuning/LLamaFineTuning.ipynb +++ b/use_case_examples/lora_finetuning/LLamaFineTuning.ipynb @@ -326,7 +326,7 @@ "outputs": [], "source": [ "# Save the fine-tuned model\n", - "save_path = Path(\"deployment/gpt2_lora_finetuned\")\n", + "save_path = Path(\"deployment/llama_lora_finetuned\")\n", "if save_path.is_dir() and any(save_path.iterdir()):\n", " shutil.rmtree(save_path)\n", "lora_trainer.save_and_clear_private_info(save_path)\n", diff --git a/use_case_examples/lora_finetuning/README.md b/use_case_examples/lora_finetuning/README.md index 36ae88b3e..a1513298f 100644 --- a/use_case_examples/lora_finetuning/README.md +++ b/use_case_examples/lora_finetuning/README.md @@ -6,13 +6,12 @@ This use case demonstrates how to fine-tune language models (GPT-2 and LLaMA) us Fine-tuning large language models typically requires access to sensitive data, which can raise privacy concerns. By leveraging FHE, we can perform computations on encrypted foundation model weights, ensuring that the data remain private throughout the training process. The LoRA weights are kept in clear on the client side. - ## Key Features - **LoRA Fine-Tuning**: Fine-tune language models by adapting low-rank weights - **Hybrid Model**: Combine encrypted foundation model weights with clear LoRA weights for optimal performance - **Low Memory Requirements**: Minimal client-side memory needed for LoRA weights -- **Multiple Approaches**: +- **Multiple Approaches**: - Custom training implementation for GPT-2 - Simplified API-based approach for LLaMA using the `LoraTrainer` @@ -34,12 +33,14 @@ pip install -r requirements.txt The repository includes two example notebooks: -1. **GPT2FineTuneHybrid.ipynb**: +1. **GPT2FineTuneHybrid.ipynb**: + - Uses a custom training implementation - Fine-tunes GPT-2 on a small Q&A data-set about FHE - Shows low-level control over the training process -2. **LLamaFineTuning.ipynb**: +1. **LLamaFineTuning.ipynb**: + - Uses Concrete ML's `LoraTrainer` API for simplified implementation - Fine-tunes LLaMA on Concrete ML code examples - Shows how to use the high-level API for encrypted fine-tuning @@ -47,6 +48,7 @@ The repository includes two example notebooks: ### Prepare the data-set Each notebook includes its own data-set: + - GPT-2 uses a small Q&A data-set about FHE in `data_finetune/what_is_fhe.txt` - LLaMA uses Concrete ML code examples in `data_finetune/data-set.jsonl` @@ -67,8 +69,8 @@ In a deployment or production scenario, the model can be fine-tuned as follows: ## Results - ### GPT-2 Results + After fine-tuning, the model's weights are distributed between the client and server as follows: - Total weights removed from the server: 68.24% diff --git a/use_case_examples/lora_finetuning/requirements.txt b/use_case_examples/lora_finetuning/requirements.txt index e99a87ffe..da6495fef 100644 --- a/use_case_examples/lora_finetuning/requirements.txt +++ b/use_case_examples/lora_finetuning/requirements.txt @@ -1,9 +1,9 @@ -e ../../. 
-transformers==4.41.2
-peft==0.11.1
+transformers==4.46.3
+peft==0.12.0
 Jinja2==3.1.4
 matplotlib==3.7.5
-datasets==3.0.1
+datasets==3.1.0
 accelerate==1.2.0
-jupyter==1.0.0
-tqdm==4.66.5
\ No newline at end of file
+jupyter==1.1.1
+tqdm==4.67.1
\ No newline at end of file
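
For reference, a minimal usage sketch of the calling convention this change introduces: `LoraTraining.forward` now accepts either `(input_ids, labels)` or `(input_ids, labels, attention_mask)`, and any provided mask must contain only 0s and 1s. The sketch mirrors the updated tests rather than a real GPT-2/LLaMA setup, so the toy module, tensor shapes, and loss function are illustrative assumptions, and the import path is assumed to match `src/concrete/ml/torch/lora.py` as used by the test suite.

```python
# Sketch of the new (input_ids, labels[, attention_mask]) tuple accepted by LoraTraining.
# Toy module and shapes are assumptions that mirror the updated tests above.
import torch
import torch.nn as nn

from concrete.ml.torch.lora import LoraTraining  # assumed import path, as in the tests


class ToyModelWithAttention(nn.Module):
    """Toy module that accepts an optional attention_mask, like the test models."""

    def __init__(self):
        super().__init__()
        # A parameter with "lora" in its name, so LoraTraining has LoRA weights to train
        self.lora_a = nn.Parameter(torch.randn(10, 10))
        self.linear = nn.Linear(10, 10)

    def forward(self, x, attention_mask=None):
        if attention_mask is not None:
            # Zero out masked positions before the linear layer
            return {"logits": self.linear(x * attention_mask)}
        return {"logits": self.linear(x)}


# A custom loss_fn makes LoraTraining compute the loss from the returned logits
lora_training = LoraTraining(ToyModelWithAttention(), loss_fn=nn.MSELoss())

x = torch.randn(5, 10)
y = torch.randn(5, 10)
attention_mask = torch.randint(0, 2, (5, 10)).float()  # only 0s and 1s, as required

# Two-element tuple: attention_mask defaults to None inside forward
loss_without_mask, _ = lora_training((x, y))

# Three-element tuple: the mask is validated and forwarded to the wrapped model
loss_with_mask, _ = lora_training((x, y, attention_mask))

print(loss_without_mask.item(), loss_with_mask.item())
```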