From 38dd2ac13d10ae7ac635c862b96cfdcea74b8f54 Mon Sep 17 00:00:00 2001
From: xffxff <1247714429@qq.com>
Date: Tue, 19 Nov 2024 08:14:30 +0000
Subject: [PATCH] make format happy

---
 aria/vllm/aria.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/aria/vllm/aria.py b/aria/vllm/aria.py
index 7c0c31e..6bc666b 100644
--- a/aria/vllm/aria.py
+++ b/aria/vllm/aria.py
@@ -26,15 +26,14 @@ from transformers import LlamaConfig
 from transformers.utils import logging
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, LoRAConfig, VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.inputs import INPUT_REGISTRY, LLMInputs, token_inputs
+from vllm.inputs import INPUT_REGISTRY, token_inputs
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -338,8 +337,8 @@ def __init__(
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
-
-        # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
+
+        # FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
         self.do_not_compile = True
 
         self.config = config
@@ -692,7 +691,12 @@ def input_processor(ctx, llm_inputs):
         repeat_count=image_feature_sizes,
     )
 
-    return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, multi_modal_placeholders={"image": ranges})
+    return token_inputs(
+        prompt_token_ids=new_token_ids,
+        prompt=new_prompt,
+        multi_modal_data=multi_modal_data,
+        multi_modal_placeholders={"image": ranges},
+    )
 
 
 # adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
@@ -711,10 +715,6 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         prefix: str = "",
-        # config: AriaConfig,
-        # multimodal_config: MultiModalConfig,
-        # cache_config: Optional[CacheConfig] = None,
-        # quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config