diff --git a/scripts/get_next_logprobs.py b/scripts/get_next_logprobs.py index 1ba0365c..5cf0d26e 100755 --- a/scripts/get_next_logprobs.py +++ b/scripts/get_next_logprobs.py @@ -16,7 +16,7 @@ def main( in_model_repo_id: str, - revisions: Iterable[str], + branches: Iterable[str], in_dataset_repo_id: str, split: str, feature: str, @@ -31,11 +31,9 @@ def main( in_dataset_repo_id, split, feature ) in_dataset_split.set_format("torch") - for revision in revisions: - print(f"Loading model={in_model_repo_id}, {revision=}") - model = AutoModelForCausalLM.from_pretrained( - in_model_repo_id, revision=revision - ) + for branch in branches: + print(f"Loading model='{in_model_repo_id}', {branch=}") + model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch) logprobs_dataset = get_logprobs_single_model( model=model, dataset=in_dataset_split, @@ -45,7 +43,7 @@ def main( logprobs_dataset.push_to_hub( repo_id=out_repo_id, split=utils.hf_split_to_split_name(split), - revision=revision, + revision=branch, ) @@ -80,9 +78,8 @@ def get_logprobs_single_model( help="The model", ) parser.add_argument( - "--revisions", - "-r", - help="comma separated revisions of the model to use or 'ALL_BRANCHES' to use all branches", + "--branches", + help="comma separated branches of the model to use or 'ALL' to use all branches", type=str, default="main", required=False, @@ -133,15 +130,15 @@ def get_logprobs_single_model( # ) args = parser.parse_args() - revisions = ( - args.revisions.split(",") - if args.revisions != "ALL_BRANCHES" + branches = ( + args.branches.split(",") + if args.branches != "ALL" else utils.get_all_hf_branch_names(args.in_model_repo_id) ) main( in_model_repo_id=args.in_model_repo_id, - revisions=revisions, + branches=branches, in_dataset_repo_id=args.in_dataset_repo_id, split=args.split, feature=args.feature,