diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index aae6310a2a213..d3ca7f878191a 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts -@fork_new_process_for_each_test -def test_llama_lora(sql_lora_files): - - llm = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1) - +def generate_and_test(llm, sql_lora_files): print("lora adapter created") assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT @@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files): print("removing lora") +@fork_new_process_for_each_test +def test_llama_lora(sql_lora_files): + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + generate_and_test(llm, sql_lora_files) + + @fork_new_process_for_each_test def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and @@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files): max_loras=4, tensor_parallel_size=4, ) - - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT - - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT - - print("removing lora") + generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): tensor_parallel_size=4, fully_sharded_loras=True, ) - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + generate_and_test(llm, sql_lora_files) - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files): - print("removing lora") + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + fully_sharded_loras=True, + enable_lora_bias=True, + ) + generate_and_test(llm, sql_lora_files) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index f5c2eced9d2bb..5f2d32defe030 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -73,6 +73,7 @@ def apply(self, x: torch.Tensor, self.punica_wrapper.add_expand(output, buffer, self.lora_b_stacked, + self.bias_stacked, add_input=True) # now have column partitioned output @@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): layer.lora_a_stacked[idx], 1.0) buffers = tensor_model_parallel_all_gather(buffers) - left_offset = 0 - for idx in range(n): - shard_size = layer.lora_b_stacked[idx].shape[2] - - if layer.bias_stacked is not None: - bias = layer.bias_stacked[idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[layer.punica_wrapper.token_lora_indices] - bias[layer.punica_wrapper.token_lora_indices == -1] = 0 - output[:, left_offset:left_offset 
+ shard_size] += bias - - layer.punica_wrapper.add_expand_slice( - output, - buffers[idx], - layer.lora_b_stacked[idx], - left_offset, - shard_size, - add_input=True, - ) - left_offset += shard_size + layer.punica_wrapper.add_expand_packed_nslice( + output, + buffers, + layer.lora_b_stacked, + layer.bias_stacked, + 1.0, + layer.output_slices, + ) output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -234,6 +222,7 @@ def apply(self, x: torch.Tensor, self.punica_wrapper.add_expand(output, buffer, self.lora_b_stacked, + self.bias_all, add_input=True) # now have column partitioned output output = output.view(*out_orig_shape) @@ -350,15 +339,9 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # reduced before being used shard_size = self.lora_b_stacked.shape[2] start_idx = self.tp_rank * shard_size - - if self.bias_stacked is not None: - bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1]) - bias = bias[self.punica_wrapper.token_lora_indices] - bias[self.punica_wrapper.token_lora_indices == -1] = 0 - output += bias - self.punica_wrapper.add_expand_slice(output, buffer, - self.lora_b_stacked, start_idx, + self.lora_b_stacked, + self.bias_stacked, start_idx, shard_size) output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3701988ff692f..73748b5ce511e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -67,63 +67,6 @@ def dec(*args, **kwargs): return dec -def apply_bias( - indices: torch.Tensor, - output: torch.Tensor, - bias_stacked: torch.Tensor, -): - """Applies bias to output - - Input shapes: - bias_stacked: (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, output_dim) - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) - bias_stacked = bias_stacked[indices] - bias_stacked[indices == -1] = 0 - output += bias_stacked - - return output.view_as(org_output) - - -def apply_bias_packed_nslice( - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], -): - """Applies bias to output - - Input shapes: - bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - - offset_left += slice - - return output.view_as(org_output) - - @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False @@ -311,6 +254,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.punica_wrapper.add_expand(full_output, full_lora_a_embeddings, self.lora_b_stacked, + bias_all=None, add_input=True) return full_output.view_as(full_output_org) @@ -399,15 +343,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, 
- output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): @@ -576,15 +514,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): @@ -687,8 +619,8 @@ def create_lora_weights( ) for _ in range(n_slices)) else: self.bias_stacked = None - self.output_dim = self.lora_b_stacked[0].shape[2] + self.output_slices = (self.output_dim, self.output_dim) def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -772,17 +704,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - (self.output_dim, self.output_dim), - self.bias_stacked, - ) self.punica_wrapper.add_lora_packed_nslice( - output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, - (self.output_dim, self.output_dim)) + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.bias_stacked, 1.0, (self.output_dim, self.output_dim)) return output @classmethod @@ -1129,17 +1053,10 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - self.output_slices, - self.bias_stacked, - ) self.punica_wrapper.add_lora_packed_nslice(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0, + self.lora_b_stacked, + self.bias_stacked, 1.0, self.output_slices) return output @@ -1264,15 +1181,9 @@ def set_lora( def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 082041f390750..3f775b7ba363e 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -450,6 +450,62 @@ def expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) + def apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + bias_stacked: torch.Tensor, + ): + """Applies bias to output + + Input shapes: + bias_stacked: (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) + bias_stacked = bias_stacked[indices] + 
bias_stacked[indices == -1] = 0 + output += bias_stacked + + return output.view_as(org_output) + + def apply_bias_packed_nslice( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + def add_shrink( self, y: torch.Tensor, @@ -474,16 +530,19 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], add_input: bool = True, ): """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the + Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the GEMM of lora'b. When `is_prefill` is true, it indicates that it is currently the prefill stage, and the `expand_prefill` function should be called. Otherwise, it is the decode stage, and the expand_decode function should be called. """ + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) expand_fun: Callable = (self.expand_prefill if self.is_prefill else self.expand_decode) @@ -493,23 +552,54 @@ def add_expand_slice(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], y_offset: Optional[int], y_slice_size: Optional[int], add_input: bool = True): """ Similar to `add_expand` """ + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) expand_slice_fun: Callable = (self.expand_slice_prefill if self.is_prefill else self.expand_slice_decode) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_b_stacked: Tuple[torch.Tensor, ...], + bias_stacked: Optional[Tuple[torch.Tensor, + ...]], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Similar to `add_expand` + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + if bias_stacked is not None: + self.apply_bias_packed_nslice(self.token_lora_indices, y, + output_slices, bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self.add_expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + None, + offset_left, + output_slices[slice_idx], + add_input=True) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + def add_lora(self, y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, wb_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], scale: float, y_offset: Optional[int] = None, y_slice_size: Optional[int] = None, @@ -522,12 +612,13 @@ def add_lora(self, @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) * scale - ).squeeze(0) + ).squeeze(0)+bias[i] Args: y (torch.Tensor): Output tensor. Will be changed in-place. 
x (torch.Tensor): Input tensor wa_t_all (torch.Tensor): lora_a's weight wb_t_all (torch.Tensor): lora_b's weight + bias_all: (torch.Tensor): lora's bias scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. @@ -544,27 +635,26 @@ def add_lora(self, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) self.add_shrink(buffer, x, wa_t_all, scale) if y_offset is None and y_slice_size is None: - self.add_expand(y, buffer, wb_t_all, add_input=True) + self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True) else: self.add_expand_slice(y, buffer, wb_t_all, + None, y_offset, y_slice_size, add_input=True) y = y.view_as(y_org) def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - scale: float, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + bias_all: Tuple[Optional[torch.Tensor], + ...], scale: float, output_slices: Tuple[int, ...]) -> None: """ Applies lora to each input. Similar to add_lora, This method is @@ -575,10 +665,13 @@ def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, x = x.view(-1, x.shape[-1]) y = y.view(-1, y.shape[-1]) offset_left = 0 + if bias_all is not None: + y = self.apply_bias_packed_nslice(self.token_lora_indices, y, + output_slices, bias_all) # TODO fuse these kernels for slice_idx in range(len(output_slices)): self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], scale, offset_left, + lora_b_stacked[slice_idx], None, scale, offset_left, output_slices[slice_idx]) offset_left += output_slices[slice_idx]
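
Note (not part of the patch): a minimal sketch of the gather-and-add that the new PunicaWrapper.apply_bias helper performs, assuming the shapes documented in its docstring above; the standalone function name and the toy tensors are hypothetical and exist only for illustration.

import torch

def apply_lora_bias(output: torch.Tensor, bias_stacked: torch.Tensor,
                    token_lora_indices: torch.Tensor) -> torch.Tensor:
    # output: (num_tokens, output_dim)
    # bias_stacked: (num_loras, output_dim), one bias row per adapter slot
    # token_lora_indices: (num_tokens,), -1 marks tokens with no active LoRA
    bias = bias_stacked.view(-1, bias_stacked.shape[-1])[token_lora_indices]
    bias[token_lora_indices == -1] = 0  # tokens without a LoRA contribute no bias
    return output + bias

# toy usage: two adapters, four tokens, the third token has no LoRA
out = torch.zeros(4, 8)
biases = torch.randn(2, 8)
idx = torch.tensor([0, 1, -1, 0])
out = apply_lora_bias(out, biases, idx)

Folding this bias application into PunicaWrapper (replacing the per-layer apply_bias helpers removed from vllm/lora/layers.py) lets add_lora, add_expand, and the packed-nslice variants add the bias once before the expand kernels run, which is why the layer-level apply() methods in this patch now simply pass self.bias_stacked down to the wrapper.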