diff --git a/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py b/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py index a87428a2..bfef14d8 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - DecoderOnlyEmbedderModelArguments, - DecoderOnlyEmbedderDataArguments, - DecoderOnlyEmbedderTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: DecoderOnlyEmbedderModelArguments -data_args: DecoderOnlyEmbedderDataArguments -training_args: DecoderOnlyEmbedderTrainingArguments +def main(): + parser = HfArgumentParser(( + DecoderOnlyEmbedderModelArguments, + DecoderOnlyEmbedderDataArguments, + DecoderOnlyEmbedderTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: DecoderOnlyEmbedderModelArguments + data_args: DecoderOnlyEmbedderDataArguments + training_args: DecoderOnlyEmbedderTrainingArguments -runner = DecoderOnlyEmbedderRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = DecoderOnlyEmbedderRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py b/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py index 4354f440..58fe46ae 100644 --- a/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py +++ b/FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - DecoderOnlyEmbedderICLModelArguments, - DecoderOnlyEmbedderICLDataArguments, - DecoderOnlyEmbedderICLTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: DecoderOnlyEmbedderICLModelArguments -data_args: DecoderOnlyEmbedderICLDataArguments -training_args: DecoderOnlyEmbedderICLTrainingArguments +def main(): + parser = HfArgumentParser(( + DecoderOnlyEmbedderICLModelArguments, + DecoderOnlyEmbedderICLDataArguments, + DecoderOnlyEmbedderICLTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: DecoderOnlyEmbedderICLModelArguments + data_args: DecoderOnlyEmbedderICLDataArguments + training_args: DecoderOnlyEmbedderICLTrainingArguments -runner = DecoderOnlyEmbedderICLRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = DecoderOnlyEmbedderICLRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py b/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py index a31a8ef8..5c43abfe 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py @@ -8,19 +8,24 @@ ) -parser = HfArgumentParser(( - EncoderOnlyEmbedderModelArguments, - EncoderOnlyEmbedderDataArguments, - EncoderOnlyEmbedderTrainingArguments -)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: EncoderOnlyEmbedderModelArguments -data_args: EncoderOnlyEmbedderDataArguments -training_args: EncoderOnlyEmbedderTrainingArguments +def main(): + parser = HfArgumentParser(( + EncoderOnlyEmbedderModelArguments, + EncoderOnlyEmbedderDataArguments, + EncoderOnlyEmbedderTrainingArguments + )) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: EncoderOnlyEmbedderModelArguments + data_args: EncoderOnlyEmbedderDataArguments + training_args: EncoderOnlyEmbedderTrainingArguments -runner = EncoderOnlyEmbedderRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyEmbedderRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py index cd556573..5a8cba7a 100644 --- a/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py +++ b/FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py @@ -8,15 +8,20 @@ ) -parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: EncoderOnlyEmbedderM3ModelArguments -data_args: EncoderOnlyEmbedderM3DataArguments -training_args: EncoderOnlyEmbedderM3TrainingArguments +def main(): + parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: EncoderOnlyEmbedderM3ModelArguments + data_args: EncoderOnlyEmbedderM3DataArguments + training_args: EncoderOnlyEmbedderM3TrainingArguments -runner = EncoderOnlyEmbedderM3Runner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyEmbedderM3Runner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py b/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py index e3a4a063..3243ecb6 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py @@ -6,5 +6,6 @@ __all__ = [ "CrossDecoderModel", "DecoderOnlyRerankerRunner", - "DecoderOnlyRerankerTrainer" + "DecoderOnlyRerankerTrainer", + "RerankerModelArguments", ] diff --git a/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py b/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py index f0cf400e..447e6dc7 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py @@ -5,18 +5,26 @@ AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.decoder_only.base.runner import DecoderOnlyRerankerRunner -from FlagEmbedding.finetune.reranker.decoder_only.base.arguments import RerankerModelArguments +from FlagEmbedding.finetune.reranker.decoder_only.base import ( + DecoderOnlyRerankerRunner, + RerankerModelArguments +) -parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: RerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments -runner = DecoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() +def main(): + parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: RerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments + + runner = DecoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py index e3a4a063..3243ecb6 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py @@ -6,5 +6,6 @@ __all__ = [ "CrossDecoderModel", "DecoderOnlyRerankerRunner", - "DecoderOnlyRerankerTrainer" + "DecoderOnlyRerankerTrainer", + "RerankerModelArguments", ] diff --git a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py index d1077fbb..64774fc0 100644 --- a/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py +++ b/FlagEmbedding/finetune/reranker/decoder_only/layerwise/__main__.py @@ -1,23 +1,30 @@ from transformers import HfArgumentParser from FlagEmbedding.abc.finetune.reranker import ( - AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.decoder_only.layerwise.runner import DecoderOnlyRerankerRunner -from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments +from FlagEmbedding.finetune.reranker.decoder_only.layerwise import ( + DecoderOnlyRerankerRunner, + RerankerModelArguments +) -parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: RerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments -runner = DecoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() +def main(): + parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: RerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments + + runner = DecoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py b/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py index 0cccfe93..76778d65 100644 --- a/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py +++ b/FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py @@ -5,18 +5,23 @@ AbsRerankerDataArguments, AbsRerankerTrainingArguments ) -from FlagEmbedding.finetune.reranker.encoder_only.base.runner import EncoderOnlyRerankerRunner +from FlagEmbedding.finetune.reranker.encoder_only.base import EncoderOnlyRerankerRunner -parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) -model_args, data_args, training_args = parser.parse_args_into_dataclasses() -model_args: AbsRerankerModelArguments -data_args: AbsRerankerDataArguments -training_args: AbsRerankerTrainingArguments +def main(): + parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: AbsRerankerModelArguments + data_args: AbsRerankerDataArguments + training_args: AbsRerankerTrainingArguments -runner = EncoderOnlyRerankerRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args -) -runner.run() + runner = EncoderOnlyRerankerRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args + ) + runner.run() + + +if __name__ == "__main__": + main()