Skip to content

Commit

Permalink
Merge pull request #84 from rhymes-ai/output_embeddings
Browse files Browse the repository at this point in the history
add set/get_output_embeddings
  • Loading branch information
xffxff authored Nov 28, 2024
2 parents 7cbf499 + d7687b6 commit 485a6f6
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions aria/model/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
"""Set the input embeddings for the language model."""
self.language_model.set_input_embeddings(value)

def get_output_embeddings(self):
"""Retrieve the output embeddings from the language model."""
return self.language_model.get_output_embeddings()

def set_output_embeddings(self, value):
"""Set the output embeddings for the language model."""
self.language_model.set_output_embeddings(value)

def set_moe_z_loss_coeff(self, value):
"""
Expand Down

0 comments on commit 485a6f6

Please sign in to comment.