From 165837071089100caeabb364abe81085d976bc72 Mon Sep 17 00:00:00 2001 From: TJian Date: Wed, 16 Oct 2024 08:34:26 -0700 Subject: [PATCH] [Model] [BUG] Fix code path logic to load mllama model (#234) * fix code path logic to load mllama model * fix lint error * fix lint error --------- Co-authored-by: tjtanaa --- vllm/attention/backends/utils.py | 57 +++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e451cd5522d18..f3e5670b7b110 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -7,7 +7,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import async_tensor_h2d, is_hip, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner_base import ModelRunnerBase @@ -334,11 +334,19 @@ def graph_capture_get_metadata_for_batch( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) + if is_hip(): + assert ( + self.runner.attn_backend.get_name() == "rocm-flash-attn" + ), (f"Expected attn_backend name to be 'rocm-flash-attn', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + else: + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) return attn_metadata @@ -354,11 +362,19 @@ def get_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._add_additonal_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) + if is_hip(): + assert ( + self.runner.attn_backend.get_name() == "rocm-flash-attn" + ), (f"Expected attn_backend name to be 'rocm-flash-attn', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + else: + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers def prepare_graph_input_buffers( @@ -373,11 +389,20 @@ def prepare_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ - f" got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) + + if is_hip(): + assert ( + self.runner.attn_backend.get_name() == "rocm-flash-attn" + ), (f"Expected attn_backend name to be 'rocm-flash-attn', but " + f" got '{self.runner.attn_backend.get_name()}'") + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + else: + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return