From 2c723b4d66c3142318f398d5be39cb183c543205 Mon Sep 17 00:00:00 2001
From: Janosh Riebesell
Date: Sat, 18 Nov 2023 13:49:22 -0800
Subject: [PATCH] Fix `BatchedGraph.from_graphs` RuntimeError from
 mismatching dtypes (#95)

* add comment with link to GH issue why mlp_out_bias=model_name == "0.2.0"
  in chgnet.load()

* fix BatchedGraph.from_graphs RuntimeError: expected m1 and m2 to have the
  same dtype, but got: float != double

  ase/optimize/fire.py", line 54, in __init__
    Optimizer.__init__(self, atoms, restart, logfile, trajectory,
  ase/optimize/optimize.py", line 234, in __init__
    self.set_force_consistent()
  ase/optimize/optimize.py", line 325, in set_force_consistent
    self.atoms.get_potential_energy(force_consistent=True)
  ase/atoms.py", line 728, in get_potential_energy
    energy = self._calc.get_potential_energy(
  ...
  lattice = graph.lattice @ (torch.eye(3).to(strain.device) + strain)
            ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

* Update pre-commit hooks
---
 .pre-commit-config.yaml | 4 ++--
 chgnet/model/model.py   | 9 +++++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0a6d7ef5..20e069bf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.5
+    rev: v0.1.6
     hooks:
       - id: ruff
         args: [--fix]
@@ -46,7 +46,7 @@ repos:
           - svelte
 
   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v8.53.0
+    rev: v8.54.0
     hooks:
       - id: eslint
         types: [file]

diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index 83e9397a..e337c9d7 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -687,6 +687,9 @@ def load(cls, model_name="0.3.0"):
 
         return cls.from_file(
             os.path.join(module_dir, checkpoint_path),
+            # mlp_out_bias=True is set for backward compatible behavior but in rare
+            # cases causes unphysical jumps in bonding energy. see
+            # https://github.com/CederGroupHub/chgnet/issues/79
             mlp_out_bias=model_name == "0.2.0",
             version=model_name,
         )
@@ -753,7 +756,7 @@ def from_graphs(
             compute_stress (bool): whether to compute stress. Default = False
 
         Returns:
-            assembled batch_graph that is ready for batched forward pass in CHGNet
+            BatchedGraph: assembled graphs ready for batched CHGNet forward pass
         """
         atomic_numbers, atom_positions = [], []
         strains, volumes = [], []
@@ -772,7 +775,9 @@ def from_graphs(
             # Lattice
             if compute_stress:
                 strain = graph.lattice.new_zeros([3, 3], requires_grad=True)
-                lattice = graph.lattice @ (torch.eye(3).to(strain.device) + strain)
+                lattice = graph.lattice @ (
+                    torch.eye(3, dtype=datatype).to(strain.device) + strain
+                )
             else:
                 strain = None
                 lattice = graph.lattice
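
Context for the model.py change: torch.eye called without an explicit dtype
uses PyTorch's global default dtype, while strain (created via
graph.lattice.new_zeros) inherits the dtype of graph.lattice. When the two
disagree, the matrix product mixes float32 and float64 and raises the
RuntimeError quoted in the commit message; passing dtype=datatype pins the
identity matrix to the graph dtype. Below is a minimal standalone sketch of
the failure and the fix, not CHGNet's actual code: lattice stands in for
graph.lattice, datatype is assumed to mirror a module-level torch.float32
constant in chgnet/model/model.py, and a changed global default dtype is
used as one plausible trigger.

    import torch

    datatype = torch.float32  # assumed stand-in for chgnet's module-level constant

    # Stand-in for graph.lattice, stored as float32.
    lattice = torch.randn(3, 3, dtype=datatype)
    # Same construction as in from_graphs: strain inherits lattice's dtype.
    strain = lattice.new_zeros([3, 3], requires_grad=True)

    # Plausible trigger: a caller switched PyTorch's global default dtype,
    # so torch.eye(3) now returns a float64 tensor.
    torch.set_default_dtype(torch.float64)

    try:
        # eye (float64) + strain (float32) promotes to float64, but matmul
        # does not type-promote, so float32 @ float64 raises
        # "RuntimeError: expected m1 and m2 to have the same dtype ..."
        lattice @ (torch.eye(3).to(strain.device) + strain)
    except RuntimeError as exc:
        print(exc)

    # The fix: pin the identity matrix to the graph dtype so both matmul
    # operands are float32 regardless of the global default.
    fixed = lattice @ (torch.eye(3, dtype=datatype).to(strain.device) + strain)
    print(fixed.dtype)  # torch.float32

Either explicit dtype works here (dtype=datatype or dtype=graph.lattice.dtype);
the essential point of the patch is that both operands of the @ product end up
with the same dtype no matter what the global default is.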