Skip to content

Commit

Permalink
Merge pull request #348 from torchmd/fix_torch_warnings
Browse files Browse the repository at this point in the history
Fixed pytorch deprecations warnings
  • Loading branch information
stefdoerr authored Dec 3, 2024
2 parents f6c0c16 + 535c5ae commit 0b94d88
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
17 changes: 13 additions & 4 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@

import pytest
from pytest import mark
import torch as pt
from torchmdnet.models.model import create_model
from torchmdnet.optimize import optimize
from torchmdnet.models.utils import dtype_mapping

try:
import NNPOps

nnpops_available = True
except ImportError:
nnpops_available = False


@pytest.mark.skipif(not nnpops_available, reason="NNPOps not available")
@mark.parametrize("device", ["cpu", "cuda"])
@mark.parametrize("num_atoms", [10, 100])
def test_gn(device, num_atoms):
import torch as pt
from torchmdnet.models.model import create_model
from torchmdnet.optimize import optimize
from torchmdnet.models.utils import dtype_mapping

if not pt.cuda.is_available() and device == "cuda":
pytest.skip("No GPU")
Expand Down
6 changes: 3 additions & 3 deletions torchmdnet/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ def get_neighbor_pairs_fwd_meta(


if torch.__version__ >= "2.2.0":
from torch.library import impl_abstract
from torch.library import register_fake

impl_abstract(
register_fake(
"torchmdnet_extensions::get_neighbor_pairs_bkwd", get_neighbor_pairs_bkwd_meta
)
impl_abstract(
register_fake(
"torchmdnet_extensions::get_neighbor_pairs_fwd", get_neighbor_pairs_fwd_meta
)
elif torch.__version__ < "2.2.0" and torch.__version__ >= "2.0.0":
Expand Down
4 changes: 2 additions & 2 deletions torchmdnet/extensions/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using std::tuple;
using torch::arange;
using torch::div;
using torch::frobenius_norm;
using torch::linalg_vector_norm;
using torch::full;
using torch::hstack;
using torch::index_select;
Expand Down Expand Up @@ -99,7 +99,7 @@ forward_impl(const std::string& strategy, const Tensor& positions, const Tensor&
deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
scale1 * box_vectors.index({pair_batch, 0, 0}));
}
distances = frobenius_norm(deltas, 1);
distances = linalg_vector_norm(deltas, 2, 1);
mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
neighbors = neighbors.index({Slice(), mask});
deltas = deltas.index({mask, Slice()});
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
filepath, args=args, device=device, return_std=return_std, **kwargs
)
assert isinstance(filepath, str)
ckpt = torch.load(filepath, map_location="cpu")
ckpt = torch.load(filepath, map_location="cpu", weights_only=False)
if args is None:
args = ckpt["hyper_parameters"]

Expand Down

0 comments on commit 0b94d88

Please sign in to comment.