Skip to content

Commit

Permalink
Merge pull request #85 from rhymes-ai/fix_output
Browse files: browse the repository at this point in the history
fix: disable tie_embedding
  • Loading branch information
xffxff authored Nov 28, 2024
2 parents 485a6f6 + 5b61ee8 commit 17cb159
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion aria/model/configuration_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ def __init__(
},
ignore_index=-100,
image_token_index=32000,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
self.ignore_index = ignore_index
self.image_token_index = image_token_index

self.tie_word_embeddings = tie_word_embeddings
attn_implementation = kwargs.pop("attn_implementation", None)

# Set the default attention implementation to flash_attention_2 if not specified
Expand Down
4 changes: 2 additions & 2 deletions aria/model/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
    """Set the input embeddings for the language model.

    This is a thin delegate: ``value`` is forwarded unchanged to
    ``self.language_model.set_input_embeddings``.

    Args:
        value: The new input-embedding object. Presumably an ``nn.Module``
            to mirror ``get_input_embeddings`` -- confirm against the
            wrapped language model.

    Returns:
        None. Any return value of the wrapped setter is discarded.
    """
    self.language_model.set_input_embeddings(value)

def get_output_embeddings(self):
    """Retrieve the output embeddings from the language model.

    This is a thin delegate: the call is forwarded to
    ``self.language_model.get_output_embeddings`` and its result is
    returned as-is.

    Returns:
        Whatever the wrapped language model reports as its output
        embeddings.
    """
    return self.language_model.get_output_embeddings()

def set_output_embeddings(self, value):
    """Set the output embeddings for the language model.

    This is a thin delegate: ``value`` is forwarded unchanged to
    ``self.language_model.set_output_embeddings``.

    Args:
        value: The new output-embedding object to install on the wrapped
            language model.

    Returns:
        None. Any return value of the wrapped setter is discarded.
    """
    self.language_model.set_output_embeddings(value)
Expand Down

0 comments on commit 17cb159

Please sign in to comment.