
[Quantization/Parameter] WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes #11622

Open
cennn wants to merge 10 commits into main

Conversation

@cennn (Contributor) commented Dec 30, 2024

FIX: issue #10612, PR #10609

Problem:

Parameter subclasses are not fully compatible in the following respects:
torch.compile: parameter subclasses cause compatibility issues when modules are compiled with torch.compile.
offloadedTensor: parameter subclasses also do not compose well with tensor subclasses, such as the ones used for CPU offloading.

Solution:

Remove all parameter subclasses and instead attach the required attributes and functions directly to a raw nn.Parameter, so that quantization parameters keep the same behavior. This mainly means rewriting the code that defines (and inherits from) the parameter subclasses as shown below; the code that constructs these parameters needs only minimal changes.

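To illustrate the mechanism this relies on (a minimal sketch, not code from this PR): a raw nn.Parameter accepts arbitrary Python attributes, and its type remains plain nn.Parameter, which is what torch.compile and tensor wrappers see.

import torch
from torch import nn

# Minimal sketch: metadata attached directly to a plain nn.Parameter.
weight = nn.Parameter(torch.empty(8, 8), requires_grad=False)
weight.packed_dim = 0
weight.packed_factor = 8

assert type(weight) is nn.Parameter   # no custom subclass involved
assert weight.packed_factor == 8      # the metadata is still accessible
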
Example Code Changes:

Original Definition:

class PackedvLLMParameter(ModelWeightParameter):
    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)

New Definition:

def PackedvLLMParameter(data: torch.Tensor, **kwargs) -> Parameter:
    param = Parameter(data, requires_grad=False)
    wrap_base_vllm_parameter(param, **kwargs)
    wrap_column_vllm_parameter(param, **kwargs)
    wrap_row_vllm_parameter(param, **kwargs)
    wrap_packed_vllm_parameter(param, **kwargs)
    return param


def wrap_packed_vllm_parameter(param: Parameter,
                               packed_factor: Union[int, Fraction],
                               packed_dim: int,
                               marlin_tile_size: Optional[int] = None,
                               **kwargs) -> None:
    def adjust_shard_indexes_for_packing(shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=packed_factor,
            marlin_tile_size=marlin_tile_size)

    param.packed_factor = packed_factor
    param.packed_dim = packed_dim
    param.marlin_tile_size = marlin_tile_size
    param.adjust_shard_indexes_for_packing = adjust_shard_indexes_for_packing
    add_param_feature(param, Features.Packed)

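The helpers wrap_base_vllm_parameter, wrap_column_vllm_parameter, wrap_row_vllm_parameter, add_param_feature, and Features are defined elsewhere in this PR; the snippet below is only a minimal sketch of how the feature-tagging part could look (the has_param_feature helper and the enum members are assumptions for illustration, not the PR's actual implementation):

from enum import Enum, auto

from torch.nn import Parameter


class Features(Enum):
    Base = auto()
    Column = auto()
    Row = auto()
    Packed = auto()


def add_param_feature(param: Parameter, feature: Features) -> None:
    # Record a capability flag on the parameter so callers can test for
    # features instead of relying on isinstance() with a subclass.
    if not hasattr(param, "param_features"):
        param.param_features = set()
    param.param_features.add(feature)


def has_param_feature(param: Parameter, feature: Features) -> bool:
    return feature in getattr(param, "param_features", set())
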
Unchanged Call Sites:

Code that constructs these parameters does not need to change, because the factory functions keep the original names and keyword arguments. For example:

qweight = PackedvLLMParameter(
    data=torch.empty(
        input_size_per_partition // self.quant_config.pack_factor,
        output_size_per_partition,
        dtype=torch.int32,
    ),
    input_dim=0,
    output_dim=1,
    packed_dim=0,
    packed_factor=self.quant_config.pack_factor,
    weight_loader=weight_loader)

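Code that dispatches on the parameter type (rather than constructing one) is where the "minimal modifications" come in: presumably an isinstance check becomes a feature check. A hedged sketch, reusing the hypothetical has_param_feature/Features helpers from the sketch above:

def adjust_for_packing(param, shard_size, shard_offset):
    # Hypothetical dispatch helper (not taken from the PR): consult the
    # attached feature flag instead of isinstance(param, PackedvLLMParameter).
    if has_param_feature(param, Features.Packed):
        return param.adjust_shard_indexes_for_packing(
            shard_size=shard_size, shard_offset=shard_offset)
    return shard_size, shard_offset
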
Verified Tests:

vllm serve Qwen/Qwen2.5-0.5B-Instruct --quantization fp8
vllm serve Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4
vllm serve Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4 --quantization gptq
vllm serve Qwen/Qwen2-1.5B-Instruct-AWQ
vllm serve Qwen/Qwen2-1.5B-Instruct-AWQ --quantization awq
vllm serve nm-testing/tinyllama-oneshot-w4a16-channel-v2
vllm serve nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t
vllm serve nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@cennn cennn changed the title Replace parameter subclasses with raw nn.Parameter with additional attributes WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes Dec 30, 2024
@cennn cennn changed the title WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes [Quantization/Parameter] WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes Dec 30, 2024
@youkaichao (Member)

As discussed, please fix the format.

@youkaichao (Member)

@dsikka can you please take a look?

@dsikka (Contributor) left a comment

Please keep the docstrings and typing that were added for all the original parameters.

(Nine inline review threads on vllm/model_executor/parameter.py; several marked outdated or resolved.)

mergify bot commented Jan 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cennn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 9, 2025
@mergify mergify bot removed the needs-rebase label Jan 13, 2025
@dsikka (Contributor) commented Jan 14, 2025

Will re-review soon.

@youkaichao (Member)

@dsikka Upon reflection, I find this wrap_row_vllm_parameter approach quite confusing and tricky for newcomers. I'm thinking about another approach, where we keep the current class, but it does not inherit from nn.Parameter.

previous code:

weight_g_idx = RowvLLMParameter(
    data=torch.empty(input_size_per_partition, dtype=torch.int32),
    input_dim=0,
    weight_loader=weight_loader)

intended:

weight_g_idx = nn.Parameter(
    torch.empty(input_size_per_partition, dtype=torch.int32))
weight_g_idx.vllm_parameter = RowvLLMParameter(
    data=weight_g_idx, input_dim=0, weight_loader=weight_loader)

This way, the only change is that RowvLLMParameter is no longer a subclass of nn.Parameter.

does this sound better to you?

@dsikka (Contributor) commented Jan 16, 2025

This looks better.
Would RowvLLMParameter inherit from anything, or would it just be standalone?
You would need to update how the weight loader is called within the weight_loader_v2 methods in linear.py.
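
A minimal sketch of that weight-loader change, assuming the existing load_row_parallel_weight API on RowvLLMParameter (illustrative only, not code from either PR):

import torch
from torch import nn

def load_row_parallel(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # With RowvLLMParameter attached as param.vllm_parameter instead of being
    # the Parameter itself, the loading logic is reached through the wrapper.
    param.vllm_parameter.load_row_parallel_weight(loaded_weight=loaded_weight)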

@cennn (Contributor, Author) commented Jan 18, 2025

@dsikka All the quantization class inheritances are kept, except that BasevLLMParameter no longer inherits from nn.Parameter. Here is the PR implemented in this new way (it has already been tested), so the two implementations can be compared directly. Please take a look and see whether this approach is clearer:

#12158

Development

Successfully merging this pull request may close these issues.

tracking torch.compile compatibility with cpu offloading
3 participants