Skip to content

Commit

Permalink
refactor __main__.py for finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhainebula committed Nov 15, 2024
1 parent 974ce9a commit a08f352
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 97 deletions.
35 changes: 20 additions & 15 deletions FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@
)


parser = HfArgumentParser((
DecoderOnlyEmbedderModelArguments,
DecoderOnlyEmbedderDataArguments,
DecoderOnlyEmbedderTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: DecoderOnlyEmbedderModelArguments
data_args: DecoderOnlyEmbedderDataArguments
training_args: DecoderOnlyEmbedderTrainingArguments
def main():
parser = HfArgumentParser((
DecoderOnlyEmbedderModelArguments,
DecoderOnlyEmbedderDataArguments,
DecoderOnlyEmbedderTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: DecoderOnlyEmbedderModelArguments
data_args: DecoderOnlyEmbedderDataArguments
training_args: DecoderOnlyEmbedderTrainingArguments

runner = DecoderOnlyEmbedderRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
runner = DecoderOnlyEmbedderRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
35 changes: 20 additions & 15 deletions FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@
)


parser = HfArgumentParser((
DecoderOnlyEmbedderICLModelArguments,
DecoderOnlyEmbedderICLDataArguments,
DecoderOnlyEmbedderICLTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: DecoderOnlyEmbedderICLModelArguments
data_args: DecoderOnlyEmbedderICLDataArguments
training_args: DecoderOnlyEmbedderICLTrainingArguments
def main():
parser = HfArgumentParser((
DecoderOnlyEmbedderICLModelArguments,
DecoderOnlyEmbedderICLDataArguments,
DecoderOnlyEmbedderICLTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: DecoderOnlyEmbedderICLModelArguments
data_args: DecoderOnlyEmbedderICLDataArguments
training_args: DecoderOnlyEmbedderICLTrainingArguments

runner = DecoderOnlyEmbedderICLRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
runner = DecoderOnlyEmbedderICLRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
35 changes: 20 additions & 15 deletions FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@
)


parser = HfArgumentParser((
EncoderOnlyEmbedderModelArguments,
EncoderOnlyEmbedderDataArguments,
EncoderOnlyEmbedderTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: EncoderOnlyEmbedderModelArguments
data_args: EncoderOnlyEmbedderDataArguments
training_args: EncoderOnlyEmbedderTrainingArguments
def main():
parser = HfArgumentParser((
EncoderOnlyEmbedderModelArguments,
EncoderOnlyEmbedderDataArguments,
EncoderOnlyEmbedderTrainingArguments
))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: EncoderOnlyEmbedderModelArguments
data_args: EncoderOnlyEmbedderDataArguments
training_args: EncoderOnlyEmbedderTrainingArguments

runner = EncoderOnlyEmbedderRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
runner = EncoderOnlyEmbedderRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
27 changes: 16 additions & 11 deletions FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
)


parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: EncoderOnlyEmbedderM3ModelArguments
data_args: EncoderOnlyEmbedderM3DataArguments
training_args: EncoderOnlyEmbedderM3TrainingArguments
def main():
parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: EncoderOnlyEmbedderM3ModelArguments
data_args: EncoderOnlyEmbedderM3DataArguments
training_args: EncoderOnlyEmbedderM3TrainingArguments

runner = EncoderOnlyEmbedderM3Runner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
runner = EncoderOnlyEmbedderM3Runner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
__all__ = [
"CrossDecoderModel",
"DecoderOnlyRerankerRunner",
"DecoderOnlyRerankerTrainer"
"DecoderOnlyRerankerTrainer",
"RerankerModelArguments",
]
34 changes: 21 additions & 13 deletions FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@
AbsRerankerTrainingArguments
)

from FlagEmbedding.finetune.reranker.decoder_only.base.runner import DecoderOnlyRerankerRunner
from FlagEmbedding.finetune.reranker.decoder_only.base.arguments import RerankerModelArguments
from FlagEmbedding.finetune.reranker.decoder_only.base import (
DecoderOnlyRerankerRunner,
RerankerModelArguments
)

parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: RerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments

runner = DecoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
def main():
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: RerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments

runner = DecoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
__all__ = [
"CrossDecoderModel",
"DecoderOnlyRerankerRunner",
"DecoderOnlyRerankerTrainer"
"DecoderOnlyRerankerTrainer",
"RerankerModelArguments",
]
35 changes: 21 additions & 14 deletions FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from transformers import HfArgumentParser

from FlagEmbedding.abc.finetune.reranker import (
AbsRerankerModelArguments,
AbsRerankerDataArguments,
AbsRerankerTrainingArguments
)

from FlagEmbedding.finetune.reranker.decoder_only.layerwise.runner import DecoderOnlyRerankerRunner
from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments
from FlagEmbedding.finetune.reranker.decoder_only.layerwise import (
DecoderOnlyRerankerRunner,
RerankerModelArguments
)

parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: RerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments

runner = DecoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
def main():
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: RerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments

runner = DecoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()
29 changes: 17 additions & 12 deletions FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
AbsRerankerDataArguments,
AbsRerankerTrainingArguments
)
from FlagEmbedding.finetune.reranker.encoder_only.base.runner import EncoderOnlyRerankerRunner
from FlagEmbedding.finetune.reranker.encoder_only.base import EncoderOnlyRerankerRunner


parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: AbsRerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments
def main():
parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: AbsRerankerModelArguments
data_args: AbsRerankerDataArguments
training_args: AbsRerankerTrainingArguments

runner = EncoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()
runner = EncoderOnlyRerankerRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args
)
runner.run()


if __name__ == "__main__":
main()

0 comments on commit a08f352

Please sign in to comment.