Implement merge method for GroupedGemmLoraLayer to merge LoRA weights into base model #75

Merged 4 commits on Nov 17, 2024
aria/lora/layers.py: 75 changes (74 additions, 1 deletion)
@@ -17,10 +17,11 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Union
from typing import Any, Optional, Union

import torch
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import check_adapters_to_merge
from torch import nn

from aria.model import GroupedGEMM
@@ -149,3 +150,75 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        result = result.to(torch_result_dtype)

        return result

    def merge(
        self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> None:
        """
        Merge the active adapter weights into the base weights.

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed on a copy of the original weights and checked for
                NaNs before the weights are merged. This is useful if you want to verify that the merge will not
                produce NaNs. Defaults to `False`.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged.
                Defaults to `None`.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if safe_merge:
                    raise NotImplementedError(
                        "Safe merge is not supported for GroupedGemmLoraLayer; call merge() with safe_merge=False instead."
                    )
                else:
                    delta_weight = self.get_delta_weight(active_adapter)
                    if not self.use_dora[active_adapter]:
                        base_layer.weight.data += delta_weight
                    else:
                        raise NotImplementedError(
                            "DoRA is not supported for GroupedGemmLoraLayer; disable use_dora for this layer instead."
                        )

                self.merged_adapters.append(active_adapter)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.
        """
        device = self.lora_B[adapter].weight.device
        dtype = self.lora_B[adapter].weight.dtype

        # In case users want to merge adapter weights that are in float16 while being on CPU, we need to cast the
        # weights to float32, perform the merge, and then cast back to float16, because the `@` and matmul
        # operations in general are not supported for torch + cpu + fp16.
        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

        weight_A = self.lora_A[adapter].weight
        weight_B = self.lora_B[adapter].weight

        if cast_to_fp32:
            weight_A = weight_A.float()
            weight_B = weight_B.float()

        output_tensor = torch.matmul(weight_A, weight_B) * self.scaling[adapter]

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

            # cast back the weights
            self.lora_A[adapter].weight.data = weight_A.to(dtype)
            self.lora_B[adapter].weight.data = weight_B.to(dtype)

        return output_tensor
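
For intuition, here is a minimal, self-contained sketch of the identity the merge relies on. The expert count, feature sizes, rank, and scaling below are made up for illustration and are not taken from the PR; the point is that adding `A @ B * scaling` (what `get_delta_weight` computes) to each expert's weight reproduces the unmerged base-plus-LoRA output.

```python
import torch

torch.manual_seed(0)
num_experts, in_features, out_features, rank, scaling = 4, 8, 16, 2, 0.5

base = torch.randn(num_experts, in_features, out_features)    # grouped expert weights
lora_A = torch.randn(num_experts, in_features, rank)          # low-rank factor A per expert
lora_B = torch.randn(num_experts, rank, out_features)         # low-rank factor B per expert
x = torch.randn(num_experts, 3, in_features)                  # 3 tokens routed to each expert

# Unmerged path: base projection plus the scaled low-rank LoRA path.
unmerged = x @ base + (x @ lora_A @ lora_B) * scaling

# Merged path: fold the delta into the base weight once, as merge() does.
delta = torch.matmul(lora_A, lora_B) * scaling
merged = x @ (base + delta)

assert torch.allclose(unmerged, merged, atol=1e-4)
print("merged and unmerged outputs match")
```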
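
In practice the new method would typically be reached through peft's generic merging entry point rather than called directly. The following is a hedged sketch only: the checkpoint name, adapter path, and auto class are assumptions rather than anything specified by this PR, and it presumes the adapter was created with this repo's GroupedGemmLoraLayer mapping.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder identifiers; substitute the actual base checkpoint and adapter directory.
base_model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria", torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base_model, "path/to/lora_adapter")

# merge_and_unload() walks the wrapped modules, calls merge() on each LoRA layer
# (now including GroupedGemmLoraLayer), and strips the adapter wrappers.
merged = model.merge_and_unload()
merged.save_pretrained("path/to/merged_checkpoint")
```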