Skip to content

Commit

Permalink
Add mace_convert_dev cli tool to convert between devices
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Oct 30, 2024
1 parent c1bb3b2 commit f6124f2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
21 changes: 21 additions & 0 deletions mace/cli/convert_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from argparse import ArgumentParser
import torch

def main():
parser = ArgumentParser()
parser.add_argument("--target_device", "-t",
help="device to convert to, usually 'cpu' or 'cuda'", default="cpu")
parser.add_argument("--output_file", "-o",
help="name for output model, defaults to model_file.target_device")
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model.to(args.target_device)
torch.save(model, args.output_file)

if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ console_scripts =
mace_run_train = mace.cli.run_train:main
mace_prepare_data = mace.cli.preprocess_data:main
mace_finetuning = mace.cli.fine_tuning_select:main
mace_convert_dev = mace.cli.convert_dev:main

[options.extras_require]
wandb = wandb
Expand All @@ -54,4 +55,4 @@ dev =
pytest
pytest-benchmark
pylint
schedulefree = schedulefree
schedulefree = schedulefree

0 comments on commit f6124f2

Please sign in to comment.