[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Dec 2, 2024
1 parent a4c4daf commit b45f0d7
Showing 4 changed files with 156 additions and 175 deletions.
60 changes: 27 additions & 33 deletions tests/lora/test_llama_tp.py
@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)

def generate_and_test(llm, sql_lora_files):
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
print("removing lora")


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
generate_and_test(llm, sql_lora_files)


@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
)

print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT

print("removing lora")
generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
tensor_parallel_size=4,
fully_sharded_loras=True,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
generate_and_test(llm, sql_lora_files)

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):

print("removing lora")
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
)
generate_and_test(llm, sql_lora_files)
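
For reference, the shared helper collapsed behind the first fold of this file presumably reads like the duplicated blocks it replaces in each test; a sketch under that assumption (do_sample and the EXPECTED_* constants come from the existing test module):

def generate_and_test(llm, sql_lora_files):
    # Common body of the TP=1, TP=4, and fully-sharded TP=4 LoRA tests:
    # alternate between no LoRA and two LoRA adapters and check the outputs.
    print("lora adapter created")
    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT

    print("no lora")
    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 2")
    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT

    print("removing lora")
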
41 changes: 12 additions & 29 deletions vllm/lora/fully_sharded_layers.py
@@ -73,6 +73,7 @@ def apply(self, x: torch.Tensor,
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output

@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
layer.lora_a_stacked[idx], 1.0)

buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]

if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias

layer.punica_wrapper.add_expand_slice(
output,
buffers[idx],
layer.lora_b_stacked[idx],
left_offset,
shard_size,
add_input=True,
)
left_offset += shard_size
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
@@ -234,6 +222,7 @@ def apply(self, x: torch.Tensor,
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_all,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
@@ -350,15 +339,9 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size

if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias

self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
output = output.view(*out_orig_shape)
return output
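
The punica.py side of this refactor is one of the four changed files, but its diff is not loaded above, so the new add_expand_packed_nslice entry point is only visible through its call site in _mcp_apply. Below is a minimal sketch of what such a PunicaWrapper method might look like, assuming the per-slice loop and hand-rolled bias handling removed above moved there largely verbatim; the signature is inferred from the call site and everything else is an assumption, not the actual implementation:

def add_expand_packed_nslice(self, y, x, lora_b_stacked, bias_stacked,
                             scale, output_slices):
    # Expand each LoRA-B slice into its column range of y, folding in the
    # per-slice bias that _mcp_apply applied by hand before this commit.
    # scale is 1.0 at the call site above and is left implicit here.
    offset_left = 0
    for slice_idx, shard_size in enumerate(output_slices):
        if bias_stacked is not None and bias_stacked[slice_idx] is not None:
            bias = bias_stacked[slice_idx]
            bias = bias.view(-1, bias.shape[-1])[self.token_lora_indices]
            bias[self.token_lora_indices == -1] = 0
            y[:, offset_left:offset_left + shard_size] += bias
        # Bias has already been added above, so None is passed through here.
        self.add_expand_slice(y, x[slice_idx], lora_b_stacked[slice_idx],
                              None, offset_left, shard_size, add_input=True)
        offset_left += shard_size
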
113 changes: 12 additions & 101 deletions vllm/lora/layers.py
@@ -67,63 +67,6 @@ def dec(*args, **kwargs):
return dec


def apply_bias(
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)

bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked

return output.view_as(org_output)


def apply_bias_packed_nslice(
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)

offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias

offset_left += slice

return output.view_as(org_output)


@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
@@ -311,6 +254,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.punica_wrapper.add_expand(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
bias_all=None,
add_input=True)
return full_output.view_as(full_output_org)

@@ -399,15 +343,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
@@ -576,15 +514,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
@@ -687,8 +619,8 @@ def create_lora_weights(
) for _ in range(n_slices))
else:
self.bias_stacked = None

self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -772,17 +704,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
(self.output_dim, self.output_dim),
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim))
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
return output

@classmethod
@@ -1129,17 +1053,10 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
self.output_slices,
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked,
self.lora_b_stacked, 1.0,
self.lora_b_stacked,
self.bias_stacked, 1.0,
self.output_slices)
return output

@@ -1264,15 +1181,9 @@ def set_lora(

def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
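
Per the commit title, the apply_bias / apply_bias_packed_nslice helpers deleted above are relocated rather than dropped. Since the punica.py hunk is not shown, the following is only a hedged sketch of how that logic might reappear as a PunicaWrapper helper that add_lora, add_lora_packed_nslice, and the expand variants could call before launching the LoRA kernels; the method name and placement are assumptions, while the body mirrors the deleted code:

def _apply_bias(self, indices, output, output_slices, bias_stacked):
    """Add the stacked LoRA bias for each output slice, in place.

    bias_stacked:  one (num_loras, output_dim) tensor (or None) per slice
    indices:       (batch_size,) LoRA index per token, -1 meaning "no LoRA"
    output:        (batch_size, sum(output_slices))
    output_slices: width of each packed slice, e.g. (q_size, kv_size, kv_size)
    """
    org_output = output
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)

    offset_left = 0
    for slice_idx, slice_size in enumerate(output_slices):
        bias = bias_stacked[slice_idx]
        if bias is not None:
            bias = bias.view(-1, bias.shape[-1])
            bias = bias[indices]
            bias[indices == -1] = 0
            output[:, offset_left:offset_left + slice_size] += bias
        offset_left += slice_size

    return output.view_as(org_output)

The single-tensor apply_bias case is then just this loop with a single slice.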