diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index 8f69168b5a..b8d59d790a 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include @@ -335,7 +334,7 @@ class DeepPotPT : public DeepPotBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/include/DeepSpinPT.h b/source/api_cc/include/DeepSpinPT.h index 643557eb07..462bc783d7 100644 --- a/source/api_cc/include/DeepSpinPT.h +++ b/source/api_cc/include/DeepSpinPT.h @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include @@ -257,7 +256,7 @@ class DeepSpinPT : public DeepSpinBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index ce104b0f8e..79494f7ed6 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepPotPT.h" @@ -171,7 +170,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +233,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, fparam_tensor, diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 3ae0eb3bb7..1b28274d8b 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepSpinPT.h" @@ -179,7 +178,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +233,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, spin_wrapped_Tensor, firstneigh_tensor,