make format happy
xffxff committed Nov 19, 2024
1 parent 2f84a50 commit 38dd2ac
Showing 1 changed file with 10 additions and 10 deletions.
aria/vllm/aria.py
@@ -26,15 +26,14 @@
 from transformers import LlamaConfig
 from transformers.utils import logging
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.inputs import INPUT_REGISTRY, LLMInputs, token_inputs
+from vllm.inputs import INPUT_REGISTRY, token_inputs
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -338,8 +337,8 @@ def __init__(
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
-        # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
 
+        # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
         self.do_not_compile = True
 
         self.config = config
@@ -692,7 +691,12 @@ def input_processor(ctx, llm_inputs):
         repeat_count=image_feature_sizes,
     )
 
-    return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, multi_modal_placeholders={"image": ranges})
+    return token_inputs(
+        prompt_token_ids=new_token_ids,
+        prompt=new_prompt,
+        multi_modal_data=multi_modal_data,
+        multi_modal_placeholders={"image": ranges},
+    )
 
 
 # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
@@ -711,10 +715,6 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         prefix: str = "",
-        # config: AriaConfig,
-        # multimodal_config: MultiModalConfig,
-        # cache_config: Optional[CacheConfig] = None,
-        # quant_config: Optional[QuantizationConfig] = None,
    ):
         super().__init__()
         config = vllm_config.model_config.hf_config
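
Aside: the first and last hunks above track vLLM's late-2024 refactor in which model constructors stopped taking config, multimodal_config, cache_config, and quant_config as separate arguments and instead receive a single VllmConfig, unpacking what they need from it (the diff itself shows config = vllm_config.model_config.hf_config). A minimal sketch of that pattern, assuming VllmConfig exposes cache_config and quant_config as it did in vLLM at the time; ExampleModel is a hypothetical class, not part of this commit:

    from torch import nn
    from vllm.config import VllmConfig

    class ExampleModel(nn.Module):  # hypothetical model, for illustration only
        def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            # Unpack the sub-configs that the old signature passed individually.
            self.config = vllm_config.model_config.hf_config  # HF model config
            self.cache_config = vllm_config.cache_config      # KV-cache settings (assumed attribute)
            self.quant_config = vllm_config.quant_config      # quantization settings (assumed attribute)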
