diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index d985eb4951..d0430d39d0 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -37,8 +37,13 @@ void DeepPot::init(const std::string& model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { + DPBackend backend = deepmd::DPBackend::PyTorch; + } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { + DPBackend backend = deepmd::DPBackend::TensorFlow; + else { + throw deepmd::deepmd_exception("Unsupported model file format"); + } if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content);