diff --git a/aria/model/configuration_aria.py b/aria/model/configuration_aria.py
index 7bfe8fa..20d9704 100644
--- a/aria/model/configuration_aria.py
+++ b/aria/model/configuration_aria.py
@@ -66,12 +66,13 @@ def __init__(
         },
         ignore_index=-100,
         image_token_index=32000,
+        tie_word_embeddings=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.ignore_index = ignore_index
         self.image_token_index = image_token_index
-
+        self.tie_word_embeddings = tie_word_embeddings
         attn_implementation = kwargs.pop("attn_implementation", None)
 
         # Set the default attention implementation to flash_attention_2 if not specified
diff --git a/aria/model/modeling_aria.py b/aria/model/modeling_aria.py
index 9809816..3d89a05 100644
--- a/aria/model/modeling_aria.py
+++ b/aria/model/modeling_aria.py
@@ -164,11 +164,11 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         """Set the input embeddings for the language model."""
         self.language_model.set_input_embeddings(value)
-
+
     def get_output_embeddings(self):
         """Retrieve the output embeddings from the language model."""
         return self.language_model.get_output_embeddings()
-
+
     def set_output_embeddings(self, value):
         """Set the output embeddings for the language model."""
         self.language_model.set_output_embeddings(value)
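
For reference, a minimal usage sketch of the new flag. The class name AriaConfig is assumed from this repo's configuration module and does not appear in the hunks above; only the module path and the tie_word_embeddings parameter come from the diff.

    # Minimal sketch, not part of the diff: exercises the new tie_word_embeddings flag.
    # AriaConfig is assumed to be the config class defined in configuration_aria.py.
    from aria.model.configuration_aria import AriaConfig

    config = AriaConfig(tie_word_embeddings=False)
    # The flag is stored on the config and can later be read by the model to decide
    # whether input and output embeddings share weights.
    print(config.tie_word_embeddings)  # -> False

With tying disabled by default, the language model keeps separate input and output embedding matrices, which is what the get/set output-embedding accessors exposed in modeling_aria.py delegate to.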