Skip to content

Commit

Permalink
fix(pt): use eval mode in the C++ interface (#4402)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
	- Enhanced error handling for improved clarity during execution.
- Added evaluation mode setting for models to ensure correct inference
behavior.
- **Improvements**
- Streamlined device management in computation methods for better
resource utilization.
	- Improved flexibility in handling tensor creation and data transfer.
- Updated computation methods to support new logic for message passing
and mapping tensors.
- Introduced new methods for mixed-type computations, enhancing
functionality.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Nov 23, 2024
1 parent 5d589da commit cb7a0d3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void DeepPotPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
module.eval();
do_message_passing = module.run_method("has_message_passing").toBool();
torch::jit::FusionStrategy strategy;
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};
Expand Down
1 change: 1 addition & 0 deletions source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void DeepSpinPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
module.eval();
do_message_passing = module.run_method("has_message_passing").toBool();
torch::jit::FusionStrategy strategy;
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};
Expand Down

0 comments on commit cb7a0d3

Please sign in to comment.