Skip to content

Commit

Permalink
revisions -> branches, tested
Browse files — browse the repository at this point in the history
  • Loading branch information
jettjaniak committed Apr 27, 2024
1 parent 3799ae5 commit 79ce78c
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions scripts/get_next_logprobs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@

def main(
in_model_repo_id: str,
revisions: Iterable[str],
branches: Iterable[str],
in_dataset_repo_id: str,
split: str,
feature: str,
Expand All @@ -31,11 +31,9 @@ def main(
in_dataset_repo_id, split, feature
)
in_dataset_split.set_format("torch")
for revision in revisions:
print(f"Loading model={in_model_repo_id}, {revision=}")
model = AutoModelForCausalLM.from_pretrained(
in_model_repo_id, revision=revision
)
for branch in branches:
print(f"Loading model='{in_model_repo_id}', {branch=}")
model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
logprobs_dataset = get_logprobs_single_model(
model=model,
dataset=in_dataset_split,
Expand All @@ -45,7 +43,7 @@ def main(
logprobs_dataset.push_to_hub(
repo_id=out_repo_id,
split=utils.hf_split_to_split_name(split),
revision=revision,
revision=branch,
)


Expand Down Expand Up @@ -80,9 +78,8 @@ def get_logprobs_single_model(
help="The model",
)
parser.add_argument(
"--revisions",
"-r",
help="comma separated revisions of the model to use or 'ALL_BRANCHES' to use all branches",
"--branches",
help="comma separated branches of the model to use or 'ALL' to use all branches",
type=str,
default="main",
required=False,
Expand Down Expand Up @@ -133,15 +130,15 @@ def get_logprobs_single_model(
# )
args = parser.parse_args()

revisions = (
args.revisions.split(",")
if args.revisions != "ALL_BRANCHES"
branches = (
args.branches.split(",")
if args.branches != "ALL"
else utils.get_all_hf_branch_names(args.in_model_repo_id)
)

main(
in_model_repo_id=args.in_model_repo_id,
revisions=revisions,
branches=branches,
in_dataset_repo_id=args.in_dataset_repo_id,
split=args.split,
feature=args.feature,
Expand Down

0 comments on commit 79ce78c

Please sign in to comment.