diff --git a/aria/model/configuration_aria.py b/aria/model/configuration_aria.py index 2b7bb7e..7bfe8fa 100644 --- a/aria/model/configuration_aria.py +++ b/aria/model/configuration_aria.py @@ -73,7 +73,11 @@ def __init__( self.image_token_index = image_token_index attn_implementation = kwargs.pop("attn_implementation", None) - self._attn_implementation = attn_implementation + + # Set the default attention implementation to flash_attention_2 if not specified + self._attn_implementation = ( + "flash_attention_2" if attn_implementation is None else attn_implementation + ) # Convert the keys and values of projector_patch_to_query_dict to integers # This ensures consistency even if they were provided as strings