Fix BatchedGraph.from_graphs RuntimeError from mismatching dtypes (#95)

* Add comment with link to the GH issue explaining why mlp_out_bias=model_name == "0.2.0" in CHGNet.load()

* Fix BatchedGraph.from_graphs RuntimeError: expected m1 and m2 to have the same dtype, but got: float != double. Traceback excerpt:

ase/optimize/fire.py", line 54, in __init__
    Optimizer.__init__(self, atoms, restart, logfile, trajectory,
ase/optimize/optimize.py", line 234, in __init__
    self.set_force_consistent()
ase/optimize/optimize.py", line 325, in set_force_consistent
    self.atoms.get_potential_energy(force_consistent=True)
ase/atoms.py", line 728, in get_potential_energy
    energy = self._calc.get_potential_energy(
...
    lattice = graph.lattice @ (torch.eye(3).to(strain.device) + strain)
              ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
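In isolation, the failure mode looks like this: torch.eye(3) takes torch's current default dtype, so if anything in the process changes the default (to float64, say) while the graph tensors stay float32, the matmul raises. A minimal standalone sketch; variable names mirror from_graphs but the snippet does not use CHGNet:

import torch

# Illustrative trigger: something upstream changed torch's default dtype.
torch.set_default_dtype(torch.float64)

lattice = torch.randn(3, 3, dtype=torch.float32)        # plays the role of graph.lattice
strain = lattice.new_zeros([3, 3], requires_grad=True)  # float32, same dtype as lattice

try:
    # torch.eye(3) is float64 here; (eye + strain) promotes to float64,
    # and float32 @ float64 raises the RuntimeError quoted above.
    lattice @ (torch.eye(3).to(strain.device) + strain)
except RuntimeError as exc:
    print(exc)

# The fix pins the identity matrix to the graph dtype explicitly:
datatype = lattice.dtype
lattice = lattice @ (torch.eye(3, dtype=datatype).to(strain.device) + strain)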

* Update pre-commit hooks
janosh authored Nov 18, 2023
1 parent 36c2060 commit 2c723b4
Showing 2 changed files with 9 additions and 4 deletions.
.pre-commit-config.yaml (4 changes: 2 additions & 2 deletions)

@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.5
+    rev: v0.1.6
     hooks:
       - id: ruff
         args: [--fix]
@@ -46,7 +46,7 @@ repos:
           - svelte

   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v8.53.0
+    rev: v8.54.0
     hooks:
       - id: eslint
         types: [file]
chgnet/model/model.py (9 changes: 7 additions & 2 deletions)

@@ -687,6 +687,9 @@ def load(cls, model_name="0.3.0"):

         return cls.from_file(
             os.path.join(module_dir, checkpoint_path),
+            # mlp_out_bias=True is set for backward compatible behavior but in rare
+            # cases causes unphysical jumps in bonding energy. see
+            # https://github.com/CederGroupHub/chgnet/issues/79
             mlp_out_bias=model_name == "0.2.0",
             version=model_name,
         )
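For reference, only the legacy 0.2.0 weights flip this flag on; a short usage sketch (comments paraphrase the intent and are not program output):

from chgnet.model import CHGNet

model = CHGNet.load()                     # model_name="0.3.0" default, mlp_out_bias=False
legacy = CHGNet.load(model_name="0.2.0")  # mlp_out_bias=True for backward compatibility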
@@ -753,7 +756,7 @@ def from_graphs(
             compute_stress (bool): whether to compute stress. Default = False

         Returns:
-            assembled batch_graph that is ready for batched forward pass in CHGNet
+            BatchedGraph: assembled graphs ready for batched CHGNet forward pass
         """
         atomic_numbers, atom_positions = [], []
         strains, volumes = [], []
@@ -772,7 +775,9 @@ def from_graphs(
             # Lattice
             if compute_stress:
                 strain = graph.lattice.new_zeros([3, 3], requires_grad=True)
-                lattice = graph.lattice @ (torch.eye(3).to(strain.device) + strain)
+                lattice = graph.lattice @ (
+                    torch.eye(3, dtype=datatype).to(strain.device) + strain
+                )
             else:
                 strain = None
                 lattice = graph.lattice
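Background on the strain tensor that caused the trouble: from_graphs multiplies the lattice by (identity + strain), where strain is a zero matrix with requires_grad=True, so that autograd can later differentiate the predicted energy with respect to strain to obtain the stress. A generic sketch of that trick with a toy placeholder energy (not CHGNet's actual energy head):

import torch

lattice = torch.randn(3, 3)
strain = lattice.new_zeros([3, 3], requires_grad=True)
# The identity's dtype matches the lattice dtype, mirroring the fix above.
lattice = lattice @ (torch.eye(3, dtype=lattice.dtype) + strain)

# Anything computed from the strained lattice is now differentiable w.r.t. strain.
energy = torch.linalg.det(lattice).abs()  # toy stand-in for the model's energy
(dE_dstrain,) = torch.autograd.grad(energy, strain)
print(dE_dstrain)  # dE/d(strain); scaled by cell volume this yields the stress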
