From d7687b6395f6c974006df23b81671784a2e8012c Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Thu, 28 Nov 2024 15:24:52 +0800 Subject: [PATCH] add set/get_output_embeddings --- aria/model/modeling_aria.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aria/model/modeling_aria.py b/aria/model/modeling_aria.py index 87d4f10..9809816 100644 --- a/aria/model/modeling_aria.py +++ b/aria/model/modeling_aria.py @@ -164,6 +164,14 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): """Set the input embeddings for the language model.""" self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + """Retrieve the output embeddings from the language model.""" + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, value): + """Set the output embeddings for the language model.""" + self.language_model.set_output_embeddings(value) def set_moe_z_loss_coeff(self, value): """