diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a6d7ef5..20e069bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.5 + rev: v0.1.6 hooks: - id: ruff args: [--fix] @@ -46,7 +46,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v8.53.0 + rev: v8.54.0 hooks: - id: eslint types: [file] diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 83e9397a..e337c9d7 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -687,6 +687,9 @@ def load(cls, model_name="0.3.0"): return cls.from_file( os.path.join(module_dir, checkpoint_path), + # mlp_out_bias=True is set for backward-compatible behavior but in rare + # cases causes unphysical jumps in bonding energy. See + # https://github.com/CederGroupHub/chgnet/issues/79 mlp_out_bias=model_name == "0.2.0", version=model_name, ) @@ -753,7 +756,7 @@ def from_graphs( compute_stress (bool): whether to compute stress. Default = False Returns: - assembled batch_graph that is ready for batched forward pass in CHGNet + BatchedGraph: assembled graphs ready for batched CHGNet forward pass """ atomic_numbers, atom_positions = [], [] strains, volumes = [], [] @@ -772,7 +775,9 @@ # Lattice if compute_stress: strain = graph.lattice.new_zeros([3, 3], requires_grad=True) - lattice = graph.lattice @ (torch.eye(3).to(strain.device) + strain) + lattice = graph.lattice @ ( + torch.eye(3, dtype=datatype).to(strain.device) + strain + ) else: strain = None lattice = graph.lattice