Skip to content

Commit

Permalink
Fix default outpout_file in select_head, and add argument to list heads
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Jan 7, 2025
1 parent fece538 commit 5d6e60d
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions mace/cli/select_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@

def main():
parser = ArgumentParser()
parser.add_argument(
grp = parser.add_mutually_exclusive_group()
grp.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
grp.add_argument(
"--list_heads",
"-l",
action="store_true",
help="list names of the heads",
)
parser.add_argument(
"--output_file",
"-o",
Expand All @@ -21,12 +28,15 @@ def main():
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)
if args.list_heads:
print("Available heads:")
print("\n".join([" " + h for h in model.heads]))
else:
if args.output_file is None:
args.output_file = args.model_file + "." + args.head_name + "." + str(next(model.parameters()).device)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)


if __name__ == "__main__":
Expand Down

0 comments on commit 5d6e60d

Please sign in to comment.