Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support eager attention #30

Merged
merged 1 commit on Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aria/model/configuration_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__(
self.ignore_index = ignore_index
self.image_token_index = image_token_index

attn_implementation = kwargs.pop("attn_implementation", None)

# Convert the keys and values of projector_patch_to_query_dict to integers
# This ensures consistency even if they were provided as strings
self.projector_patch_to_query_dict = {
Expand All @@ -76,10 +78,20 @@ def __init__(

if isinstance(vision_config, dict) and "model_type" in vision_config:
vision_config = AriaVisionConfig(**vision_config)
vision_attn_implementation = (
"flash_attention_2"
if attn_implementation is None
else attn_implementation
)
vision_config._attn_implementation = vision_attn_implementation

self.vision_config = vision_config

if isinstance(text_config, dict) and "model_type" in text_config:
text_attn_implementation = (
"sdpa" if attn_implementation is None else attn_implementation
)
text_config = AriaMoELMConfig(**text_config)
text_config._attn_implementation = text_attn_implementation

self.text_config = text_config
2 changes: 1 addition & 1 deletion aria/model/vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self._attn_implementation = "flash_attention_2"


class IdentityOp(torch.nn.Module):
Expand Down Expand Up @@ -83,6 +82,7 @@ class AriaVisionModel(SiglipVisionModel):

config_class = AriaVisionConfig
main_input_name = "pixel_values"
_supports_sdpa = False

def __init__(self, config: AriaVisionConfig):
super().__init__(config)
Expand Down
Loading