Skip to content

Commit

Permalink
Revise the code according to the reviewer's feedback
Browse files Browse the repository at this point in the history
Signed-off-by: xffxff <[email protected]>
  • Loading branch information
xffxff committed Nov 21, 2024
1 parent 1a68b48 commit 0a668ab
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 27 deletions.
37 changes: 10 additions & 27 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

Check failure on line 21 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/models/aria.py:21:81: E501 Line too long (82 > 80)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput, SamplingMetadata

Check failure on line 22 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/models/aria.py:22:81: E501 Line too long (87 > 80)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear

Check failure on line 24 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/models/aria.py:24:81: E501 Line too long (85 > 80)
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import (
LlamaAttention,
Expand Down Expand Up @@ -54,29 +55,10 @@
import torch

Check failure on line 55 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E402)

vllm/model_executor/models/aria.py:55:1: E402 Module level import not at top of file

Check failure on line 55 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (F811)

vllm/model_executor/models/aria.py:55:8: F811 Redefinition of unused `torch` from line 5
import torch.nn as nn

Check failure on line 56 in vllm/model_executor/models/aria.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E402)

vllm/model_executor/models/aria.py:56:1: E402 Module level import not at top of file
from torch.nn.init import trunc_normal_
from transformers.activations import ACT2FN
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig
from vllm.config import QuantizationConfig
from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer


class AriaVisionConfig(Idefics2VisionConfig):
    """Configuration for the Aria vision tower.

    Inherits every field and default from ``Idefics2VisionConfig``; the only
    override is ``model_type``, which registers this class under a distinct
    name in the HF config/model-type registry.
    """

    model_type = "aria_vision_model"


class IdentityOp(torch.nn.Module):
    """No-op module whose ``forward`` hands back its first argument untouched.

    Acts as a drop-in placeholder (e.g. for a disabled layer) so the module
    graph keeps a consistent architecture; any extra positional or keyword
    arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Swallow construction arguments so this can stand in for any layer.
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Identity: extra args/kwargs are deliberately discarded.
        return x
from vllm.transformers_utils.configs.aria import AriaVisionConfig
from vllm.model_executor.layers.activation import get_act_fn


class AriaVisionTransformer(Idefics2VisionTransformer):
Expand All @@ -88,7 +70,7 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__(config, quant_config, prefix)
self.post_layernorm = IdentityOp()
self.post_layernorm = nn.Identity()


class AriaVisionModel(nn.Module):
class FFN(nn.Module):
    """Position-wise feed-forward block for the Aria multimodal projector.

    Projects ``embed_dim -> ff_dim`` with a column-parallel linear layer,
    applies the "gelu_new" activation, then projects ``ff_dim -> output_dim``
    with a row-parallel linear layer. Both projections are bias-free.

    Note: vLLM's parallel linear layers return an ``(output, bias)`` tuple;
    the bias slot is discarded here since ``bias=False``.
    """

    def __init__(self, embed_dim, ff_dim, output_dim):
        super().__init__()
        self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
        self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states):
        # Parallel linear layers return (tensor, output_bias); ignore bias.
        hidden_states, _ = self.linear_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_out(hidden_states)
        return hidden_states


Expand Down
5 changes: 5 additions & 0 deletions vllm/transformers_utils/configs/aria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig


class AriaVisionConfig(Idefics2VisionConfig):
    """Configuration for the Aria vision tower.

    Inherits every field and default from ``Idefics2VisionConfig``; the only
    override is ``model_type``, which registers this class under a distinct
    name in the HF config/model-type registry.
    """

    model_type = "aria_vision_model"

0 comments on commit 0a668ab

Please sign in to comment.