Merge pull request #467 from ACEsuit/develop
change stress normalization + setup.cfg with numpy<2.0
ilyes319 authored Jun 17, 2024
2 parents dee204f + 36eccd6 commit a555f04
Showing 5 changed files with 100 additions and 92 deletions.
99 changes: 54 additions & 45 deletions mace/cli/preprocess_data.py
@@ -8,6 +8,7 @@
import multiprocessing as mp
import os
import random
+from functools import partial
from glob import glob
from typing import List, Tuple

Expand Down Expand Up @@ -92,6 +93,27 @@ def get_prime_factors(n: int):
    return factors


+# Define task for multiprocessing
+def multi_train_hdf5(process, args, split_train, drop_last):
+    with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_train[process], process, f)
+
+
+def multi_valid_hdf5(process, args, split_valid, drop_last):
+    with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_valid[process], process, f)
+
+
+def multi_test_hdf5(process, name, args, split_test, drop_last):
+    with h5py.File(
+        args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
+    ) as f:
+        f.attrs["drop_last"] = drop_last
+        save_configurations_as_HDF5(split_test[process], process, f)


def main() -> None:
    """
    This script loads an xyz dataset and prepares
@@ -172,47 +194,42 @@ def run(args: argparse.Namespace):
    if len(collections.train) % 2 == 1:
        drop_last = True

-    # Define Task for Multiprocessiing
-    def multi_train_hdf5(process):
-        with h5py.File(args.h5_prefix + "train/train_" + str(process)+".h5", "w") as f:
-            f.attrs["drop_last"] = drop_last
-            save_configurations_as_HDF5(split_train[process], process, f)
-
+    multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
    processes = []
    for i in range(args.num_process):
-        p = mp.Process(target=multi_train_hdf5, args=[i])
+        p = mp.Process(target=multi_train_hdf5_, args=[i])
        p.start()
        processes.append(p)

    for i in processes:
        i.join()

-    logging.info("Computing statistics")
-    if len(atomic_energies_dict) == 0:
-        atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
-    atomic_energies: np.ndarray = np.array(
-        [atomic_energies_dict[z] for z in z_table.zs]
-    )
-    logging.info(f"Atomic energies: {atomic_energies.tolist()}")
-    _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
-    avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
-    logging.info(f"Average number of neighbors: {avg_num_neighbors}")
-    logging.info(f"Mean: {mean}")
-    logging.info(f"Standard deviation: {std}")
-
-    # save the statistics as a json
-    statistics = {
-        "atomic_energies": str(atomic_energies_dict),
-        "avg_num_neighbors": avg_num_neighbors,
-        "mean": mean,
-        "std": std,
-        "atomic_numbers": str(z_table.zs),
-        "r_max": args.r_max,
-    }
-
-    with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
-        json.dump(statistics, f)
+    if args.compute_statistics:
+        logging.info("Computing statistics")
+        if len(atomic_energies_dict) == 0:
+            atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
+        atomic_energies: np.ndarray = np.array(
+            [atomic_energies_dict[z] for z in z_table.zs]
+        )
+        logging.info(f"Atomic energies: {atomic_energies.tolist()}")
+        _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
+        avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
+        logging.info(f"Average number of neighbors: {avg_num_neighbors}")
+        logging.info(f"Mean: {mean}")
+        logging.info(f"Standard deviation: {std}")
+
+        # save the statistics as a json
+        statistics = {
+            "atomic_energies": str(atomic_energies_dict),
+            "avg_num_neighbors": avg_num_neighbors,
+            "mean": mean,
+            "std": std,
+            "atomic_numbers": str(z_table.zs),
+            "r_max": args.r_max,
+        }
+
+        with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
+            json.dump(statistics, f)

    logging.info("Preparing validation set")
    if args.shuffle:
@@ -222,36 +239,28 @@ def multi_train_hdf5(process):
    if len(collections.valid) % 2 == 1:
        drop_last = True

-    def multi_valid_hdf5(process):
-        with h5py.File(args.h5_prefix + "val/val_" + str(process)+".h5", "w") as f:
-            f.attrs["drop_last"] = drop_last
-            save_configurations_as_HDF5(split_valid[process], process, f)
-
+    multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
    processes = []
    for i in range(args.num_process):
-        p = mp.Process(target=multi_valid_hdf5, args=[i])
+        p = mp.Process(target=multi_valid_hdf5_, args=[i])
        p.start()
        processes.append(p)

    for i in processes:
        i.join()

    if args.test_file is not None:
-        def multi_test_hdf5(process, name):
-            with h5py.File(args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w") as f:
-                f.attrs["drop_last"] = drop_last
-                save_configurations_as_HDF5(split_test[process], process, f)
-
        logging.info("Preparing test sets")
        for name, subset in collections.tests:
            drop_last = False
            if len(subset) % 2 == 1:
                drop_last = True
            split_test = np.array_split(subset, args.num_process)
+            multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)

            processes = []
            for i in range(args.num_process):
-                p = mp.Process(target=multi_test_hdf5, args=[i, name])
+                p = mp.Process(target=multi_test_hdf5_, args=[i, name])
                p.start()
                processes.append(p)
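
The point of hoisting multi_train_hdf5, multi_valid_hdf5, and multi_test_hdf5 to module level and binding their extra state with functools.partial is picklability: with the spawn start method (the default on macOS and Windows), mp.Process must pickle its target, and functions defined inside run() cannot be pickled. A minimal sketch of the pattern; write_shard, prefix, and splits are hypothetical stand-ins, not MACE code:

import multiprocessing as mp
from functools import partial


def write_shard(process, prefix, splits):
    # Stand-in for multi_train_hdf5: each process writes its own shard.
    print(f"{prefix}_{process}.h5 <- {splits[process]}")


if __name__ == "__main__":
    splits = [[1, 2], [3, 4]]
    # partial binds the per-run state; the resulting object is picklable
    # because write_shard is a module-level function.
    write_shard_ = partial(write_shard, prefix="train/train", splits=splits)
    processes = [mp.Process(target=write_shard_, args=[i]) for i in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
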
4 changes: 2 additions & 2 deletions mace/cli/run_train.py
@@ -105,12 +105,12 @@ def run(args: argparse.Namespace) -> None:
            logging.info(
                f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license."
            )
-            model_foundation = mace_off(
+            calc = mace_off(
                model=model_type,
                device=args.device,
                default_dtype=args.default_dtype,
-                return_raw_model=True,
            )
+            model_foundation = calc.models[0]
        else:
            model_foundation = torch.load(args.foundation_model, map_location=device)
            logging.info(
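
As the diff implies, mace_off is now called without return_raw_model=True, so it returns an ASE calculator, and the raw torch module is taken from the calculator's models list. A hedged sketch of the new loading path; the model name and device below are placeholder values:

# Sketch of the new foundation-model loading path, as implied by the diff.
from mace.calculators import mace_off  # assumes the mace-torch package is installed

calc = mace_off(model="small", device="cpu", default_dtype="float64")
model_foundation = calc.models[0]  # torch.nn.Module used as the initial checkpoint
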
3 changes: 1 addition & 2 deletions mace/modules/loss.py
@@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
    # energy: [n_graphs, ]
    configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, ]
    configs_stress_weight = ref.stress_weight.view(-1, 1, 1)  # [n_graphs, ]
-    num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)  # [n_graphs,]
    return torch.mean(
        configs_weight
        * configs_stress_weight
-        * torch.square((ref["stress"] - pred["stress"]) / num_atoms)
+        * torch.square(ref["stress"] - pred["stress"])
    )  # []


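This is the stress-normalization change named in the commit title: the per-atom division is dropped, presumably because stress is already an intensive, per-volume quantity, so dividing the residual by the atom count suppressed the stress term for large cells. A toy version of the corrected loss, not the MACE module itself:

import torch

def stress_loss(ref_stress, pred_stress, weight, stress_weight):
    # weight, stress_weight: per-configuration weights of shape [n_graphs, 1, 1];
    # no division by the number of atoms.
    return torch.mean(weight * stress_weight * torch.square(ref_stress - pred_stress))

ref = torch.zeros(4, 3, 3)           # reference stress tensors
pred = 0.1 * torch.ones(4, 3, 3)     # predictions off by 0.1 everywhere
w = torch.ones(4, 1, 1)
print(stress_loss(ref, pred, w, w))  # tensor(0.0100)
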
2 changes: 1 addition & 1 deletion setup.cfg
@@ -16,7 +16,7 @@ python_requires = >=3.7
install_requires =
    torch>=1.12
    e3nn==0.4.4
-    numpy
+    numpy<2.0
    opt_einsum
    ase
    torch-ema
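
The numpy<2.0 pin likely guards against NumPy 2.0, whose ABI break can make extensions compiled against NumPy 1.x fail to import. An illustrative check that an existing environment satisfies the new constraint; this snippet is not part of the repo:

import numpy as np

major = int(np.__version__.split(".")[0])
assert major < 2, f"numpy {np.__version__} does not satisfy numpy<2.0"
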
84 changes: 42 additions & 42 deletions tests/test_run_train.py
@@ -107,30 +107,30 @@ def test_run_train(tmp_path, fitting_configs):
        Es.append(at.get_potential_energy())

    print("Es", Es)
-    # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
    ref_Es = [
        0.0,
        0.0,
-        -0.03911274694160493,
-        -0.0913651377675312,
-        -0.14973695873658766,
-        -0.0664839502025434,
-        -0.09968814898703926,
-        0.1248460531971883,
-        -0.0647495831154953,
-        -0.14589298347245963,
-        0.12918668431788108,
-        -0.13996496272772996,
-        -0.053211348522482806,
-        0.07845141245421094,
-        -0.08901520083723416,
-        -0.15467129065263446,
-        0.007727727865546765,
-        -0.04502061132025605,
-        -0.035848783030374,
-        -0.24410687104937906,
-        -0.0839034724949955,
-        -0.14756571357354326,
+        -0.039181344585828524,
+        -0.0915223395136733,
+        -0.14953484236456582,
+        -0.06662480820063998,
+        -0.09983737353050133,
+        0.12477442296789745,
+        -0.06486086271762856,
+        -0.1460607988519944,
+        0.12886334908465508,
+        -0.14000990081920373,
+        -0.05319886578958313,
+        0.07780520158391,
+        -0.08895480281886901,
+        -0.15474719614734422,
+        0.007756765146527644,
+        -0.044879267197498685,
+        -0.036065736712447574,
+        -0.24413743841886623,
+        -0.0838104612106429,
+        -0.14751978636626545
    ]

    assert np.allclose(Es, ref_Es)
@@ -178,30 +178,30 @@ def test_run_train_missing_data(tmp_path, fitting_configs):
        Es.append(at.get_potential_energy())

    print("Es", Es)
-    # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
    ref_Es = [
        0.0,
        0.0,
-        -0.05449966431966507,
-        -0.11237663925685797,
-        0.03914539466246801,
-        -0.07500800414261456,
-        -0.13471106701173396,
-        0.02937255038020199,
-        -0.0652196693921633,
-        -0.14946129637190012,
-        0.19412338220281133,
-        -0.13546947741234333,
-        -0.05235148626886153,
-        -0.04957190959243316,
-        -0.07081384032242896,
-        -0.24575839901841345,
-        -0.0020512332640394916,
-        -0.038630330106902526,
-        -0.13621347044601181,
-        -0.2338465954158298,
-        -0.11777474787291177,
-        -0.14895508008918812,
+        -0.05464025113696155,
+        -0.11272131295940478,
+        0.039200919331076826,
+        -0.07517990972827505,
+        -0.13504202474582666,
+        0.0292022872055344,
+        -0.06541099574579018,
+        -0.1497824717832886,
+        0.19397709360828813,
+        -0.13587609467143014,
+        -0.05242956276828463,
+        -0.0504862057364953,
+        -0.07095795959430119,
+        -0.2463753796753703,
+        -0.002031543147676121,
+        -0.03864918790300681,
+        -0.13680153117705554,
+        -0.23418951968636786,
+        -0.11790833839379238,
+        -0.14930562311066484
    ]
    assert np.allclose(Es, ref_Es)
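
Both tests follow the same regression pattern: retrain, re-evaluate the fitting configurations, and compare against energies recorded from a known-good run, which is why the reference lists (and the commented source commit) change whenever training behavior changes. In miniature, with illustrative values:

import numpy as np

Es = [0.0, -0.0392, 0.1248]      # energies from the current run
ref_Es = [0.0, -0.0392, 0.1248]  # energies recorded from a reference run
assert np.allclose(Es, ref_Es)   # any drift in training behavior fails loudly
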
