Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 1, 2024
1 parent f6124f2 commit 6a79015
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
18 changes: 14 additions & 4 deletions mace/cli/convert_dev.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from argparse import ArgumentParser

import torch


def main():
parser = ArgumentParser()
parser.add_argument("--target_device", "-t",
help="device to convert to, usually 'cpu' or 'cuda'", default="cpu")
parser.add_argument("--output_file", "-o",
help="name for output model, defaults to model_file.target_device")
parser.add_argument(
"--target_device",
"-t",
help="device to convert to, usually 'cpu' or 'cuda'",
default="cpu",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

Expand All @@ -17,5 +26,6 @@ def main():
model.to(args.target_device)
torch.save(model, args.output_file)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def parse_args() -> argparse.Namespace:
help="Model head used for evaluation",
type=str,
required=False,
default=None
default=None,
)
return parser.parse_args()

Expand Down Expand Up @@ -94,7 +94,7 @@ def run(args: argparse.Namespace) -> None:
heads = model.heads
except AttributeError:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_run_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import subprocess
import sys
Expand All @@ -7,7 +8,6 @@
import numpy as np
import pytest
from ase.atoms import Atoms
import json

from mace.calculators.mace import MACECalculator

Expand Down Expand Up @@ -608,7 +608,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs):

if i in (0, 1):
continue # skip isolated atoms, as energies specified by json files below
elif i % 2 == 0:
if i % 2 == 0:
c.info["head"] = "DFT"
fitting_configs_dft.append(c)
else:
Expand All @@ -619,9 +619,9 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs):

# write E0s to json files
E0s = {1: 0.0, 8: 0.0}
with open(tmp_path / "fit_multihead_dft.json", "w") as f:
with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f:
json.dump(E0s, f)
with open(tmp_path / "fit_multihead_mp2.json", "w") as f:
with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f:
json.dump(E0s, f)

heads = {
Expand Down

0 comments on commit 6a79015

Please sign in to comment.