From 022463cfe4d8322597647616277b3335c8b604fc Mon Sep 17 00:00:00 2001
From: Thanh-Nguyen
Date: Sun, 17 Nov 2024 04:46:08 +0700
Subject: [PATCH 1/4] Implement merge method for GroupedGemmLoraLayer to merge LoRA weights into base model

---
 aria/lora/layers.py | 70 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 69 insertions(+), 1 deletion(-)

diff --git a/aria/lora/layers.py b/aria/lora/layers.py
index 6c286af..793e469 100644
--- a/aria/lora/layers.py
+++ b/aria/lora/layers.py
@@ -17,10 +17,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Union
+from typing import Any, Union, Optional
 
 import torch
 from peft.tuners.lora import LoraLayer
+from peft.utils.other import transpose
+from peft.tuners.tuners_utils import check_adapters_to_merge
 from torch import nn
 
 from aria.model import GroupedGEMM
@@ -149,3 +151,69 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         result = result.to(torch_result_dtype)
 
         return result
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                before merging the weights. This is useful if you want to check if the merge operation will produce
+                NaNs. Defaults to `False`.
+            adapter_names (`list[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
+        """
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        for active_adapter in adapter_names:
+            if active_adapter in self.lora_A.keys():
+                base_layer = self.get_base_layer()
+                if safe_merge:
+                    raise NotImplementedError("Safe merge is not supported for GroupedGemmLoraLayer, try not using it instead.")
+                else:
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    if not self.use_dora[active_adapter]:
+                        base_layer.weight.data += delta_weight
+                    else:
+                        raise NotImplementedError("Dora is not supported for GroupedGemmLoraLayer, try not using it instead.")
+
+                self.merged_adapters.append(active_adapter)
+
+    def get_delta_weight(self, adapter) -> torch.Tensor:
+        """
+        Compute the delta weight for the given adapter.
+
+        Args:
+            adapter (str):
+                The name of the adapter for which the delta weight should be computed.
+        """
+        device = self.lora_B[adapter].weight.device
+        dtype = self.lora_B[adapter].weight.dtype
+
+        # In case users wants to merge the adapter weights that are in
+        # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
+        # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
+        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16
+
+        weight_A = self.lora_A[adapter].weight
+        weight_B = self.lora_B[adapter].weight
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
+        output_tensor = torch.matmul(weight_A, weight_B) * self.scaling[adapter]
+
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_A[adapter].weight.data = weight_A.to(dtype)
+            self.lora_B[adapter].weight.data = weight_B.to(dtype)
+
+        return output_tensor

From cb656bf46a835e62f51ad4f6b36d8ac535d4104a Mon Sep 17 00:00:00 2001
From: Thanh-Nguyen
Date: Sun, 17 Nov 2024 04:53:20 +0700
Subject: [PATCH 2/4] Remove unused import

---
 aria/lora/layers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aria/lora/layers.py b/aria/lora/layers.py
index 793e469..f7426d7 100644
--- a/aria/lora/layers.py
+++ b/aria/lora/layers.py
@@ -21,7 +21,6 @@
 
 import torch
 from peft.tuners.lora import LoraLayer
-from peft.utils.other import transpose
 from peft.tuners.tuners_utils import check_adapters_to_merge
 from torch import nn
 

From d5fd7a8998b4ed1df16a158404637f440bb0ceb3 Mon Sep 17 00:00:00 2001
From: Thanh-Nguyen
Date: Sun, 17 Nov 2024 05:00:39 +0700
Subject: [PATCH 3/4] Fix import ordering

---
 aria/lora/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aria/lora/layers.py b/aria/lora/layers.py
index f7426d7..21a2e73 100644
--- a/aria/lora/layers.py
+++ b/aria/lora/layers.py
@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Union, Optional
+from typing import Any, Optional, Union
 
 import torch
 from peft.tuners.lora import LoraLayer

From 33ca485e451aebdc2ffb696c3e669a0c084b5ffe Mon Sep 17 00:00:00 2001
From: Thanh-Nguyen
Date: Sun, 17 Nov 2024 05:05:02 +0700
Subject: [PATCH 4/4] do code formatting

---
 aria/lora/layers.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/aria/lora/layers.py b/aria/lora/layers.py
index 21a2e73..b8f0b6e 100644
--- a/aria/lora/layers.py
+++ b/aria/lora/layers.py
@@ -151,7 +151,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
 
         return result
 
-    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+    def merge(
+        self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+    ) -> None:
         """
         Merge the active adapter weights into the base weights
 
@@ -173,13 +175,17 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
             if active_adapter in self.lora_A.keys():
                 base_layer = self.get_base_layer()
                 if safe_merge:
-                    raise NotImplementedError("Safe merge is not supported for GroupedGemmLoraLayer, try not using it instead.")
+                    raise NotImplementedError(
+                        "Safe merge is not supported for GroupedGemmLoraLayer, try not using it instead."
+                    )
                 else:
                     delta_weight = self.get_delta_weight(active_adapter)
                     if not self.use_dora[active_adapter]:
                         base_layer.weight.data += delta_weight
                     else:
-                        raise NotImplementedError("Dora is not supported for GroupedGemmLoraLayer, try not using it instead.")
+                        raise NotImplementedError(
+                            "Dora is not supported for GroupedGemmLoraLayer, try not using it instead."
+                        )
 
                 self.merged_adapters.append(active_adapter)
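
The sketch below is not part of the patch series; it is a minimal illustration of how the new merge() implementation is typically exercised through PEFT. It assumes the LoRA adapter was produced by this repo's training setup, so that GroupedGEMM expert weights end up wrapped in GroupedGemmLoraLayer when the adapter is re-attached; the model id and checkpoint paths are placeholders rather than values taken from the patch.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base Aria model and re-attach the trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",                # placeholder base model id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")  # placeholder adapter path

# merge_and_unload() visits each LoraLayer and calls its merge(); with this
# patch, GroupedGemmLoraLayer folds matmul(lora_A.weight, lora_B.weight) *
# scaling into the GroupedGEMM base weight instead of being unimplemented,
# so the returned model no longer needs PEFT at inference time.
merged = model.merge_and_unload()
merged.save_pretrained("path/to/merged-aria")  # placeholder output path

Because the patch raises NotImplementedError for safe_merge and DoRA, the sketch only covers the plain merge path with PEFT's default arguments.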