diff --git a/.github/workflows/test_openmm_dmff_plugin.yml b/.github/workflows/test_openmm_dmff_plugin.yml index f6da47259..a8b7a3c86 100644 --- a/.github/workflows/test_openmm_dmff_plugin.yml +++ b/.github/workflows/test_openmm_dmff_plugin.yml @@ -41,10 +41,11 @@ jobs: export OPENMM_INSTALLED_DIR=$CONDA_PREFIX export CPPFLOW_INSTALLED_DIR=$CONDA_PREFIX export LIBTENSORFLOW_INSTALLED_DIR=$CONDA_PREFIX - cmake .. -DOPENMM_DIR=${OPENMM_INSTALLED_DIR} -DCPPFLOW_DIR=${CPPFLOW_INSTALLED_DIR} -DTENSORFLOW_DIR=${LIBTENSORFLOW_INSTALLED_DIR} -DUSE_HIGH_PRECISION=ON + cmake .. -DOPENMM_DIR=${OPENMM_INSTALLED_DIR} -DCPPFLOW_DIR=${CPPFLOW_INSTALLED_DIR} -DTENSORFLOW_DIR=${LIBTENSORFLOW_INSTALLED_DIR} -DUSE_HIGH_PRECISION=OFF make && make install - make PythonInstall + make PythonInstall - name: Run Tests run: | source $CONDA/bin/activate dmff_omm - python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nve -n 100 + cd ${GITHUB_WORKSPACE}/backend/ + python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nve -n 100 --pdb ../examples/water_fullpol/water_dimer.pdb --model ./openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux --has_aux True diff --git a/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h b/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h index cfacf83c0..4cd428fc8 100644 --- a/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h +++ b/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h @@ -50,7 +50,7 @@ typedef double ENERGYTYPE; #else typedef float FORCETYPE; typedef float COORDTYPE; -typedef float ENERGYTYPE; +typedef double ENERGYTYPE; #endif namespace DMFFPlugin { @@ -71,6 +71,12 @@ class OPENMM_EXPORT_DMFF DMFFForce : public OpenMM::Force { * @param energyCoefficient : the energy transform coefficient. */ void setUnitTransformCoefficients(const double coordCoefficient, const double forceCoefficient, const double energyCoefficient); + /** + * @brief Set the has_aux flag when model was saved with auxilary input. + * + * @param hasAux : true if model was saved with auxilary input. + */ + void setHasAux(const bool hasAux); const std::string& getDMFFGraphFile() const; /** @@ -97,6 +103,13 @@ class OPENMM_EXPORT_DMFF DMFFForce : public OpenMM::Force { * @return double */ double getCutoff() const; + /** + * @brief Get the Has Aux object + * + * @return true + * @return false + */ + bool getHasAux() const; void updateParametersInContext(OpenMM::Context& context); bool usesPeriodicBoundaryConditions() const { return use_pbc; @@ -106,6 +119,7 @@ class OPENMM_EXPORT_DMFF DMFFForce : public OpenMM::Force { private: string graph_file; bool use_pbc = true; + bool has_aux = false; double cutoff = 1.2; double coordCoeff, forceCoeff, energyCoeff; diff --git a/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp b/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp index 093f8dc29..5a48abf00 100644 --- a/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp +++ b/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp @@ -54,12 +54,18 @@ void DMFFForce::setUnitTransformCoefficients(const double coordCoefficient, cons energyCoeff = energyCoefficient; } +void DMFFForce::setHasAux(const bool hasAux){ + this->has_aux = hasAux; +} + double DMFFForce::getCoordUnitCoefficient() const {return coordCoeff;} double DMFFForce::getForceUnitCoefficient() const {return forceCoeff;} double DMFFForce::getEnergyUnitCoefficient() const {return energyCoeff;} double DMFFForce::getCutoff() const {return cutoff;} +bool DMFFForce::getHasAux() const {return has_aux;} + const string& DMFFForce::getDMFFGraphFile() const{return graph_file;} diff --git a/backend/openmm_dmff_plugin/platforms/cuda/include/CudaDMFFKernels.h b/backend/openmm_dmff_plugin/platforms/cuda/include/CudaDMFFKernels.h index 2cdf81ceb..976f874f4 100644 --- a/backend/openmm_dmff_plugin/platforms/cuda/include/CudaDMFFKernels.h +++ b/backend/openmm_dmff_plugin/platforms/cuda/include/CudaDMFFKernels.h @@ -58,19 +58,23 @@ class CudaCalcDMFFForceKernel : public CalcDMFFForceKernel{ std::string graph_file; cppflow::model jax_model; vector coord_shape = vector(2); + vector U_ind_shape = vector(2); vector box_shape{3, 3}; vector pair_shape = vector(2); vector pairs_v; - cppflow::tensor coord_tensor, box_tensor, pair_tensor; + cppflow::tensor coord_tensor, box_tensor, pair_tensor, U_ind_tensor; vector output_tensors; + vector last_U_ind; vector operations; vector input_node_names = vector(3); + vector output_node_names = vector(2); OpenMM::NeighborList neighborList; vector> exclusions; int natoms; double cutoff; + bool has_aux; ENERGYTYPE dener; vector dforce; vector dcoord; diff --git a/backend/openmm_dmff_plugin/platforms/cuda/src/CudaDMFFKernels.cpp b/backend/openmm_dmff_plugin/platforms/cuda/src/CudaDMFFKernels.cpp index 7867d1f41..15857a77f 100644 --- a/backend/openmm_dmff_plugin/platforms/cuda/src/CudaDMFFKernels.cpp +++ b/backend/openmm_dmff_plugin/platforms/cuda/src/CudaDMFFKernels.cpp @@ -50,25 +50,50 @@ void CudaCalcDMFFForceKernel::initialize(const System& system, const DMFFForce& energyUnitCoeff = force.getEnergyUnitCoefficient(); coordUnitCoeff = force.getCoordUnitCoefficient(); cutoff = force.getCutoff(); + this->has_aux = force.getHasAux(); natoms = system.getNumParticles(); coord_shape[0] = natoms; coord_shape[1] = 3; exclusions.resize(natoms); + if (this->has_aux){ + U_ind_shape[0] = natoms; + U_ind_shape[1] = 3; + // Initialize the last_U_ind. + for(int ii = 0; ii < natoms * 3; ii ++){ + last_U_ind.push_back(0.0); + } + } + // Load the ordinary graph firstly. jax_model.init(graph_file); operations = jax_model.get_operations(); for (int ii = 0; ii < operations.size(); ii++){ - if (operations[ii].find("serving")!= std::string::npos){ - if (operations[ii].find("0")!= std::string::npos){ + if (operations[ii].find("serving") != std::string::npos){ + if (operations[ii].find("0") != std::string::npos){ input_node_names[0] = operations[ii] + ":0"; } else if (operations[ii].find("1") != std::string::npos){ input_node_names[1] = operations[ii] + ":0"; } else if (operations[ii].find("2") != std::string::npos){ input_node_names[2] = operations[ii] + ":0"; } + // Set up the auxilary input node name. For U_ind + if(this->has_aux){ + if (operations[ii].find("3") != std::string::npos){ + input_node_names.push_back(operations[ii] + ":0"); + } + } + } + // Set up the output names. + if (operations[ii].find("PartitionedCall") != std::string::npos){ + output_node_names[0] = operations[ii] + ":0"; + output_node_names[1] = operations[ii] + ":1"; + if(this->has_aux){ + output_node_names.push_back(operations[ii] + ":2"); + } + break; } } @@ -123,6 +148,9 @@ double CudaCalcDMFFForceKernel::execute(ContextImpl& context, bool includeForces } coord_tensor = cppflow::tensor(dcoord, coord_shape); + // Set input U_ind + U_ind_tensor = cppflow::tensor(last_U_ind, U_ind_shape); + // Fetch the neighbor list for input pairs tensor. computeNeighborListVoxelHash( neighborList, @@ -145,10 +173,27 @@ double CudaCalcDMFFForceKernel::execute(ContextImpl& context, bool includeForces pair_tensor = cppflow::tensor(pairs_v, pair_shape); // Calculate the energy and forces. - output_tensors = jax_model({{input_node_names[0], coord_tensor}, {input_node_names[1], box_tensor}, {input_node_names[2], pair_tensor}}, {"PartitionedCall:0", "PartitionedCall:1"}); - - dener = output_tensors[0].get_data()[0]; - dforce = output_tensors[1].get_data(); + if (!this->has_aux){ + output_tensors = jax_model({ + {input_node_names[0], coord_tensor}, + {input_node_names[1], box_tensor}, + {input_node_names[2], pair_tensor}}, + {output_node_names[0], output_node_names[1]}); + dener = output_tensors[0].get_data()[0]; + dforce = output_tensors[1].get_data(); + } else { + output_tensors = jax_model({ + {input_node_names[0], coord_tensor}, + {input_node_names[1], box_tensor}, + {input_node_names[2], U_ind_tensor}, + {input_node_names[3], pair_tensor}}, + {output_node_names[0], output_node_names[1], output_node_names[2]}); + + dener = output_tensors[0].get_data()[0]; + dforce = output_tensors[1].get_data(); + // Save last U_ind for next step usage. + last_U_ind = output_tensors[2].get_data(); + } // Transform the unit from eV/A to KJ/(mol*nm) diff --git a/backend/openmm_dmff_plugin/platforms/reference/include/ReferenceDMFFKernels.h b/backend/openmm_dmff_plugin/platforms/reference/include/ReferenceDMFFKernels.h index a4d5195cc..8f5850d42 100644 --- a/backend/openmm_dmff_plugin/platforms/reference/include/ReferenceDMFFKernels.h +++ b/backend/openmm_dmff_plugin/platforms/reference/include/ReferenceDMFFKernels.h @@ -75,21 +75,25 @@ class ReferenceCalcDMFFForceKernel : public CalcDMFFForceKernel { cppflow::model jax_model; std::vector coord_shape = vector(2); - std::vector coord_shape_1 = vector(2); - std::vector coord_shape_2 = vector(2); + vector U_ind_shape = vector(2); + std::vector box_shape{3, 3}; std::vector pair_shape = vector(2); std::vector output; - cppflow::tensor coord_tensor, box_tensor, pair_tensor; + cppflow::tensor coord_tensor, box_tensor, pair_tensor, U_ind_tensor; + vector output_tensors; vector operations; + vector last_U_ind; vector input_node_names = vector(3); + vector output_node_names = vector(2); OpenMM::NeighborList neighborList; std::vector> exclusions; int natoms; double cutoff; + bool has_aux; ENERGYTYPE dener; vector dforce; vector dcoord; diff --git a/backend/openmm_dmff_plugin/platforms/reference/src/ReferenceDMFFKernels.cpp b/backend/openmm_dmff_plugin/platforms/reference/src/ReferenceDMFFKernels.cpp index f73979fc8..5b3c8fd42 100644 --- a/backend/openmm_dmff_plugin/platforms/reference/src/ReferenceDMFFKernels.cpp +++ b/backend/openmm_dmff_plugin/platforms/reference/src/ReferenceDMFFKernels.cpp @@ -70,12 +70,22 @@ void ReferenceCalcDMFFForceKernel::initialize(const System& system, const DMFFFo energyUnitCoeff = force.getEnergyUnitCoefficient(); coordUnitCoeff = force.getCoordUnitCoefficient(); cutoff = force.getCutoff(); + this->has_aux = force.getHasAux(); natoms = system.getNumParticles(); coord_shape[0] = natoms; coord_shape[1] = 3; exclusions.resize(natoms); + if (this->has_aux){ + U_ind_shape[0] = natoms; + U_ind_shape[1] = 3; + // Initialize the last_U_ind. + for(int ii = 0; ii < natoms * 3; ii ++){ + last_U_ind.push_back(0.0); + } + } + // Load the ordinary graph firstly. jax_model.init(graph_file); @@ -88,9 +98,22 @@ void ReferenceCalcDMFFForceKernel::initialize(const System& system, const DMFFFo input_node_names[1] = operations[ii]+":0"; } else if (operations[ii].find("2") != std::string::npos){ input_node_names[2] = operations[ii]+":0"; - } else { - std::cout << "Warning: Unknown input node name: " << operations[ii] << std::endl; + } + // Set up the auxilary input node name. For U_ind + if(this->has_aux){ + if (operations[ii].find("3") != std::string::npos){ + input_node_names.push_back(operations[ii] + ":0"); + } + } + } + // Set up the output names. + if (operations[ii].find("PartitionedCall") != std::string::npos){ + output_node_names[0] = operations[ii] + ":0"; + output_node_names[1] = operations[ii] + ":1"; + if(this->has_aux){ + output_node_names.push_back(operations[ii] + ":2"); } + break; } } @@ -134,6 +157,9 @@ double ReferenceCalcDMFFForceKernel::execute(ContextImpl& context, bool includeF } coord_tensor = cppflow::tensor(dcoord, coord_shape); + // Set input U_ind. + U_ind_tensor = cppflow::tensor(last_U_ind, U_ind_shape); + // Set input pairs. computeNeighborListVoxelHash( neighborList, @@ -156,12 +182,30 @@ double ReferenceCalcDMFFForceKernel::execute(ContextImpl& context, bool includeF } pair_shape[0] = totpairs; pair_shape[1] = 2; - cppflow::tensor pair_tensor = cppflow::tensor(dpairs, pair_shape); - - output = jax_model({{input_node_names[0], coord_tensor}, {input_node_names[1], box_tensor}, {input_node_names[2], pair_tensor}}, {"PartitionedCall:0", "PartitionedCall:1"}); + pair_tensor = cppflow::tensor(dpairs, pair_shape); + + if (!this->has_aux){ + output_tensors = jax_model({ + {input_node_names[0], coord_tensor}, + {input_node_names[1], box_tensor}, + {input_node_names[2], pair_tensor}}, + {output_node_names[0], output_node_names[1]}); + dener = output_tensors[0].get_data()[0]; + dforce = output_tensors[1].get_data(); + } else { + output_tensors = jax_model({ + {input_node_names[0], coord_tensor}, + {input_node_names[1], box_tensor}, + {input_node_names[2], U_ind_tensor}, + {input_node_names[3], pair_tensor}}, + {output_node_names[0], output_node_names[1], output_node_names[2]}); + + dener = output_tensors[0].get_data()[0]; + dforce = output_tensors[1].get_data(); + // Save last U_ind for next step usage. + last_U_ind = output_tensors[2].get_data(); + } - dener = output[0].get_data()[0]; - dforce = output[1].get_data(); // Transform the unit from output units to KJ/(mol*nm) for(int ii = 0; ii < natoms; ii ++){ diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i index 84f3ab021..6bd7bccb5 100644 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i @@ -47,6 +47,7 @@ public: DMFFForce(const string& GraphFile); void setUnitTransformCoefficients(const double coordCoefficient, const double forceCoefficient, const double energyCoefficient); + void setHasAux(const bool hasAux); /* * Add methods for casting a Force to a DMFFForce. diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/fingerprint.pb b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/fingerprint.pb new file mode 100644 index 000000000..ec62a12f0 --- /dev/null +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/fingerprint.pb @@ -0,0 +1 @@ +‘Þжµ¾“¯ÓãÞúØ“žÂæ÷ûÔ¬Ö°Ô÷® ³Ž¢”šŽÿ¶-(°—¶úÛ„ÿƒ2 \ No newline at end of file diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/saved_model.pb b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/saved_model.pb new file mode 100644 index 000000000..7826cccab Binary files /dev/null and b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/saved_model.pb differ diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/variables/variables.data-00000-of-00001 b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/variables/variables.data-00000-of-00001 similarity index 100% rename from backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/variables/variables.data-00000-of-00001 rename to backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/variables/variables.data-00000-of-00001 diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/variables/variables.index b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/variables/variables.index similarity index 100% rename from backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/variables/variables.index rename to backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/admp_water_dimer_aux/variables/variables.index diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid.pdb b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid.pdb deleted file mode 100644 index bdbca18eb..000000000 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid.pdb +++ /dev/null @@ -1,204 +0,0 @@ -REMARK 1 CREATED WITH OPENMM 7.7, 2022-11-03 -CRYST1 24.413 24.413 24.413 90.00 90.00 90.00 P 1 1 -HETATM 1 ATM LJP A 1 14.869 14.417 18.370 1.00 0.00 O -HETATM 2 ATM LJP A 2 4.282 12.164 23.527 1.00 0.00 O -HETATM 3 ATM LJP A 3 2.450 14.031 8.695 1.00 0.00 O -HETATM 4 ATM LJP A 4 16.580 4.821 10.031 1.00 0.00 O -HETATM 5 ATM LJP A 5 11.704 14.292 0.907 1.00 0.00 O -HETATM 6 ATM LJP A 6 23.001 12.364 7.593 1.00 0.00 O -HETATM 7 ATM LJP A 7 23.508 15.238 15.395 1.00 0.00 O -HETATM 8 ATM LJP A 8 6.515 12.797 11.300 1.00 0.00 O -HETATM 9 ATM LJP A 9 13.238 10.387 6.028 1.00 0.00 O -HETATM 10 ATM LJP A 10 10.313 23.941 11.909 1.00 0.00 O -HETATM 11 ATM LJP A 11 13.783 22.266 15.460 1.00 0.00 O -HETATM 12 ATM LJP A 12 21.957 10.599 14.550 1.00 0.00 O -HETATM 13 ATM LJP A 13 1.153 14.707 21.146 1.00 0.00 O -HETATM 14 ATM LJP A 14 3.065 5.774 18.188 1.00 0.00 O -HETATM 15 ATM LJP A 15 -0.494 18.320 23.497 1.00 0.00 O -HETATM 16 ATM LJP A 16 18.179 11.692 20.008 1.00 0.00 O -HETATM 17 ATM LJP A 17 18.654 3.430 18.725 1.00 0.00 O -HETATM 18 ATM LJP A 18 4.292 15.221 14.368 1.00 0.00 O -HETATM 19 ATM LJP A 19 13.570 20.981 11.030 1.00 0.00 O -HETATM 20 ATM LJP A 20 14.098 2.258 3.263 1.00 0.00 O -HETATM 21 ATM LJP A 21 2.123 9.242 21.351 1.00 0.00 O -HETATM 22 ATM LJP A 22 12.520 6.546 15.450 1.00 0.00 O -HETATM 23 ATM LJP A 23 24.359 14.137 0.991 1.00 0.00 O -HETATM 24 ATM LJP A 24 6.620 13.891 2.446 1.00 0.00 O -HETATM 25 ATM LJP A 25 21.807 10.765 3.568 1.00 0.00 O -HETATM 26 ATM LJP A 26 10.917 7.743 22.414 1.00 0.00 O -HETATM 27 ATM LJP A 27 4.378 21.796 9.539 1.00 0.00 O -HETATM 28 ATM LJP A 28 12.263 16.801 11.360 1.00 0.00 O -HETATM 29 ATM LJP A 29 23.466 15.591 11.030 1.00 0.00 O -HETATM 30 ATM LJP A 30 12.570 11.877 21.640 1.00 0.00 O -HETATM 31 ATM LJP A 31 4.957 16.580 5.198 1.00 0.00 O -HETATM 32 ATM LJP A 32 13.144 14.976 7.143 1.00 0.00 O -HETATM 33 ATM LJP A 33 10.516 0.978 16.482 1.00 0.00 O -HETATM 34 ATM LJP A 34 21.533 22.135 21.415 1.00 0.00 O -HETATM 35 ATM LJP A 35 19.163 15.897 12.458 1.00 0.00 O -HETATM 36 ATM LJP A 36 18.677 8.567 17.155 1.00 0.00 O -HETATM 37 ATM LJP A 37 21.512 5.445 15.760 1.00 0.00 O -HETATM 38 ATM LJP A 38 15.814 19.201 17.932 1.00 0.00 O -HETATM 39 ATM LJP A 39 19.875 7.042 21.085 1.00 0.00 O -HETATM 40 ATM LJP A 40 18.557 18.430 21.122 1.00 0.00 O -HETATM 41 ATM LJP A 41 19.743 20.838 17.328 1.00 0.00 O -HETATM 42 ATM LJP A 42 14.769 20.688 23.225 1.00 0.00 O -HETATM 43 ATM LJP A 43 1.588 18.634 13.100 1.00 0.00 O -HETATM 44 ATM LJP A 44 19.523 14.241 2.902 1.00 0.00 O -HETATM 45 ATM LJP A 45 17.763 12.461 15.118 1.00 0.00 O -HETATM 46 ATM LJP A 46 22.309 6.424 4.232 1.00 0.00 O -HETATM 47 ATM LJP A 47 20.509 1.972 22.418 1.00 0.00 O -HETATM 48 ATM LJP A 48 7.959 22.298 18.865 1.00 0.00 O -HETATM 49 ATM LJP A 49 6.643 24.145 14.313 1.00 0.00 O -HETATM 50 ATM LJP A 50 9.792 12.498 5.067 1.00 0.00 O -HETATM 51 ATM LJP A 51 11.904 17.758 16.664 1.00 0.00 O -HETATM 52 ATM LJP A 52 2.970 4.565 22.786 1.00 0.00 O -HETATM 53 ATM LJP A 53 9.821 18.804 20.610 1.00 0.00 O -HETATM 54 ATM LJP A 54 2.198 12.162 17.406 1.00 0.00 O -HETATM 55 ATM LJP A 55 1.378 1.044 9.499 1.00 0.00 O -HETATM 56 ATM LJP A 56 2.039 5.397 10.388 1.00 0.00 O -HETATM 57 ATM LJP A 57 18.989 16.082 17.350 1.00 0.00 O -HETATM 58 ATM LJP A 58 1.860 18.321 8.019 1.00 0.00 O -HETATM 59 ATM LJP A 59 8.502 3.188 24.162 1.00 0.00 O -HETATM 60 ATM LJP A 60 20.214 8.935 7.367 1.00 0.00 O -HETATM 61 ATM LJP A 61 21.347 23.260 13.818 1.00 0.00 O -HETATM 62 ATM LJP A 62 9.940 2.096 4.845 1.00 0.00 O -HETATM 63 ATM LJP A 63 2.175 23.638 13.552 1.00 0.00 O -HETATM 64 ATM LJP A 64 10.178 0.875 21.046 1.00 0.00 O -HETATM 65 ATM LJP A 65 2.683 1.509 2.312 1.00 0.00 O -HETATM 66 ATM LJP A 66 4.980 19.023 21.448 1.00 0.00 O -HETATM 67 ATM LJP A 67 12.019 6.935 10.732 1.00 0.00 O -HETATM 68 ATM LJP A 68 24.222 21.601 10.460 1.00 0.00 O -HETATM 69 ATM LJP A 69 15.106 9.357 13.374 1.00 0.00 O -HETATM 70 ATM LJP A 70 17.486 0.001 15.913 1.00 0.00 O -HETATM 71 ATM LJP A 71 13.398 16.791 21.634 1.00 0.00 O -HETATM 72 ATM LJP A 72 3.709 9.591 9.917 1.00 0.00 O -HETATM 73 ATM LJP A 73 19.379 7.608 12.121 1.00 0.00 O -HETATM 74 ATM LJP A 74 2.507 12.747 4.288 1.00 0.00 O -HETATM 75 ATM LJP A 75 23.371 8.711 18.224 1.00 0.00 O -HETATM 76 ATM LJP A 76 11.850 19.267 7.294 1.00 0.00 O -HETATM 77 ATM LJP A 77 7.635 15.939 23.087 1.00 0.00 O -HETATM 78 ATM LJP A 78 5.569 8.128 23.936 1.00 0.00 O -HETATM 79 ATM LJP A 79 15.107 6.210 18.996 1.00 0.00 O -HETATM 80 ATM LJP A 80 2.611 17.603 17.589 1.00 0.00 O -HETATM 81 ATM LJP A 81 8.151 8.802 9.716 1.00 0.00 O -HETATM 82 ATM LJP A 82 10.201 21.419 4.099 1.00 0.00 O -HETATM 83 ATM LJP A 83 16.098 13.719 23.480 1.00 0.00 O -HETATM 84 ATM LJP A 84 8.929 15.590 8.311 1.00 0.00 O -HETATM 85 ATM LJP A 85 13.937 2.387 14.025 1.00 0.00 O -HETATM 86 ATM LJP A 86 12.885 2.555 9.979 1.00 0.00 O -HETATM 87 ATM LJP A 87 5.137 9.361 3.974 1.00 0.00 O -HETATM 88 ATM LJP A 88 1.281 8.620 6.312 1.00 0.00 O -HETATM 89 ATM LJP A 89 17.554 12.711 6.960 1.00 0.00 O -HETATM 90 ATM LJP A 90 15.184 22.293 3.469 1.00 0.00 O -HETATM 91 ATM LJP A 91 23.319 4.435 19.790 1.00 0.00 O -HETATM 92 ATM LJP A 92 10.995 5.888 2.383 1.00 0.00 O -HETATM 93 ATM LJP A 93 0.459 0.884 22.377 1.00 0.00 O -HETATM 94 ATM LJP A 94 7.851 22.165 23.288 1.00 0.00 O -HETATM 95 ATM LJP A 95 16.031 9.092 9.029 1.00 0.00 O -HETATM 96 ATM LJP A 96 15.514 13.294 10.917 1.00 0.00 O -HETATM 97 ATM LJP A 97 8.621 16.037 13.610 1.00 0.00 O -HETATM 98 ATM LJP A 98 13.277 5.452 6.697 1.00 0.00 O -HETATM 99 ATM LJP A 99 7.398 12.445 15.919 1.00 0.00 O -HETATM 100 ATM LJP A 100 1.233 9.811 1.521 1.00 0.00 O -HETATM 101 ATM LJP A 101 17.182 9.617 4.050 1.00 0.00 O -HETATM 102 ATM LJP A 102 23.810 15.860 5.104 1.00 0.00 O -HETATM 103 ATM LJP A 103 6.341 19.363 1.958 1.00 0.00 O -HETATM 104 ATM LJP A 104 4.815 9.375 14.548 1.00 0.00 O -HETATM 105 ATM LJP A 105 6.653 5.055 3.047 1.00 0.00 O -HETATM 106 ATM LJP A 106 20.997 18.672 2.168 1.00 0.00 O -HETATM 107 ATM LJP A 107 19.650 11.833 10.909 1.00 0.00 O -HETATM 108 ATM LJP A 108 13.763 18.358 3.334 1.00 0.00 O -HETATM 109 ATM LJP A 109 6.167 5.750 12.102 1.00 0.00 O -HETATM 110 ATM LJP A 110 21.996 13.109 19.009 1.00 0.00 O -HETATM 111 ATM LJP A 111 5.614 1.795 10.621 1.00 0.00 O -HETATM 112 ATM LJP A 112 15.168 6.135 1.697 1.00 0.00 O -HETATM 113 ATM LJP A 113 9.818 3.721 13.101 1.00 0.00 O -HETATM 114 ATM LJP A 114 16.586 21.465 7.700 1.00 0.00 O -HETATM 115 ATM LJP A 115 12.604 22.049 19.687 1.00 0.00 O -HETATM 116 ATM LJP A 116 19.338 6.234 0.946 1.00 0.00 O -HETATM 117 ATM LJP A 117 21.932 5.114 9.361 1.00 0.00 O -HETATM 118 ATM LJP A 118 5.063 0.862 21.897 1.00 0.00 O -HETATM 119 ATM LJP A 119 4.957 2.078 6.381 1.00 0.00 O -HETATM 120 ATM LJP A 120 2.329 5.797 2.647 1.00 0.00 O -HETATM 121 ATM LJP A 121 1.591 12.245 12.844 1.00 0.00 O -HETATM 122 ATM LJP A 122 16.597 24.419 11.129 1.00 0.00 O -HETATM 123 ATM LJP A 123 12.369 13.273 14.671 1.00 0.00 O -HETATM 124 ATM LJP A 124 5.469 19.987 13.801 1.00 0.00 O -HETATM 125 ATM LJP A 125 19.790 22.589 4.699 1.00 0.00 O -HETATM 126 ATM LJP A 126 18.303 2.103 2.180 1.00 0.00 O -HETATM 127 ATM LJP A 127 9.397 16.920 4.039 1.00 0.00 O -HETATM 128 ATM LJP A 128 14.287 1.847 18.767 1.00 0.00 O -HETATM 129 ATM LJP A 129 15.318 14.136 3.267 1.00 0.00 O -HETATM 130 ATM LJP A 130 3.819 22.102 5.225 1.00 0.00 O -HETATM 131 ATM LJP A 131 9.017 9.943 1.343 1.00 0.00 O -HETATM 132 ATM LJP A 132 0.959 7.195 14.226 1.00 0.00 O -HETATM 133 ATM LJP A 133 18.988 10.612 0.110 1.00 0.00 O -HETATM 134 ATM LJP A 134 5.231 16.952 10.156 1.00 0.00 O -HETATM 135 ATM LJP A 135 1.581 1.397 17.762 1.00 0.00 O -HETATM 136 ATM LJP A 136 13.682 10.290 1.707 1.00 0.00 O -HETATM 137 ATM LJP A 137 5.629 9.455 18.879 1.00 0.00 O -HETATM 138 ATM LJP A 138 8.945 19.775 10.880 1.00 0.00 O -HETATM 139 ATM LJP A 139 15.278 16.450 14.302 1.00 0.00 O -HETATM 140 ATM LJP A 140 11.055 4.757 19.054 1.00 0.00 O -HETATM 141 ATM LJP A 141 6.253 2.027 17.903 1.00 0.00 O -HETATM 142 ATM LJP A 142 7.812 6.088 16.250 1.00 0.00 O -HETATM 143 ATM LJP A 143 16.886 5.251 15.066 1.00 0.00 O -HETATM 144 ATM LJP A 144 20.992 24.409 9.244 1.00 0.00 O -HETATM 145 ATM LJP A 145 12.841 23.567 6.889 1.00 0.00 O -HETATM 146 ATM LJP A 146 9.853 9.608 13.817 1.00 0.00 O -HETATM 147 ATM LJP A 147 6.080 12.030 6.994 1.00 0.00 O -HETATM 148 ATM LJP A 148 16.666 1.590 6.957 1.00 0.00 O -HETATM 149 ATM LJP A 149 5.502 14.368 19.487 1.00 0.00 O -HETATM 150 ATM LJP A 150 17.292 20.610 13.492 1.00 0.00 O -HETATM 151 ATM LJP A 151 12.589 3.868 23.253 1.00 0.00 O -HETATM 152 ATM LJP A 152 17.936 18.176 5.237 1.00 0.00 O -HETATM 153 ATM LJP A 153 3.990 21.877 17.484 1.00 0.00 O -HETATM 154 ATM LJP A 154 17.109 17.693 0.653 1.00 0.00 O -HETATM 155 ATM LJP A 155 1.459 21.296 20.946 1.00 0.00 O -HETATM 156 ATM LJP A 156 3.396 22.007 0.400 1.00 0.00 O -HETATM 157 ATM LJP A 157 9.349 7.859 5.703 1.00 0.00 O -HETATM 158 ATM LJP A 158 22.681 10.914 22.751 1.00 0.00 O -HETATM 159 ATM LJP A 159 22.311 19.684 6.532 1.00 0.00 O -HETATM 160 ATM LJP A 160 7.358 19.657 6.830 1.00 0.00 O -HETATM 161 ATM LJP A 161 8.435 23.908 7.913 1.00 0.00 O -HETATM 162 ATM LJP A 162 22.823 18.032 19.242 1.00 0.00 O -HETATM 163 ATM LJP A 163 6.987 24.374 2.751 1.00 0.00 O -HETATM 164 ATM LJP A 164 1.238 19.288 3.152 1.00 0.00 O -HETATM 165 ATM LJP A 165 3.766 3.412 14.505 1.00 0.00 O -HETATM 166 ATM LJP A 166 17.479 23.245 20.271 1.00 0.00 O -HETATM 167 ATM LJP A 167 8.131 11.490 21.734 1.00 0.00 O -HETATM 168 ATM LJP A 168 10.833 18.968 0.392 1.00 0.00 O -HETATM 169 ATM LJP A 169 23.826 2.848 13.407 1.00 0.00 O -HETATM 170 ATM LJP A 170 20.369 15.881 8.055 1.00 0.00 O -HETATM 171 ATM LJP A 171 20.751 2.470 5.736 1.00 0.00 O -HETATM 172 ATM LJP A 172 11.290 24.050 1.080 1.00 0.00 O -HETATM 173 ATM LJP A 173 21.737 0.532 18.006 1.00 0.00 O -HETATM 174 ATM LJP A 174 15.141 0.407 23.854 1.00 0.00 O -HETATM 175 ATM LJP A 175 24.202 21.236 16.792 1.00 0.00 O -HETATM 176 ATM LJP A 176 9.624 20.857 15.029 1.00 0.00 O -HETATM 177 ATM LJP A 177 10.552 14.344 18.936 1.00 0.00 O -HETATM 178 ATM LJP A 178 23.468 6.714 23.607 1.00 0.00 O -HETATM 179 ATM LJP A 179 16.348 17.348 9.441 1.00 0.00 O -HETATM 180 ATM LJP A 180 19.555 2.774 12.697 1.00 0.00 O -HETATM 181 ATM LJP A 181 23.304 22.686 1.339 1.00 0.00 O -HETATM 182 ATM LJP A 182 15.751 8.855 22.264 1.00 0.00 O -HETATM 183 ATM LJP A 183 17.631 5.546 5.404 1.00 0.00 O -HETATM 184 ATM LJP A 184 23.537 8.901 10.565 1.00 0.00 O -HETATM 185 ATM LJP A 185 14.367 10.164 17.737 1.00 0.00 O -HETATM 186 ATM LJP A 186 10.200 9.386 18.329 1.00 0.00 O -HETATM 187 ATM LJP A 187 8.833 4.201 8.357 1.00 0.00 O -HETATM 188 ATM LJP A 188 24.159 23.893 5.675 1.00 0.00 O -HETATM 189 ATM LJP A 189 21.608 18.838 14.112 1.00 0.00 O -HETATM 190 ATM LJP A 190 7.187 17.854 17.442 1.00 0.00 O -HETATM 191 ATM LJP A 191 3.460 16.451 0.423 1.00 0.00 O -HETATM 192 ATM LJP A 192 18.826 22.220 0.523 1.00 0.00 O -HETATM 193 ATM LJP A 193 22.824 2.641 1.749 1.00 0.00 O -HETATM 194 ATM LJP A 194 16.754 4.199 22.403 1.00 0.00 O -HETATM 195 ATM LJP A 195 0.853 3.877 6.189 1.00 0.00 O -HETATM 196 ATM LJP A 196 7.160 5.518 20.741 1.00 0.00 O -HETATM 197 ATM LJP A 197 20.256 15.051 22.686 1.00 0.00 O -HETATM 198 ATM LJP A 198 4.934 6.245 7.081 1.00 0.00 O -HETATM 199 ATM LJP A 199 20.069 20.260 10.139 1.00 0.00 O -HETATM 200 ATM LJP A 200 11.265 12.039 10.134 1.00 0.00 O -TER 201 LJP A 200 -END diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/fingerprint.pb b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/fingerprint.pb deleted file mode 100644 index a71b9d584..000000000 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/fingerprint.pb +++ /dev/null @@ -1 +0,0 @@ -€ð„ç­éûœ ÁÓŠ¡Ý–²³Ž¤çÉãÛ†Óq ¿Ñ™ùî¿”(°—¶úÛ„ÿƒ2 \ No newline at end of file diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/saved_model.pb b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/saved_model.pb deleted file mode 100644 index c1a96e44a..000000000 Binary files a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/data/lj_fluid_gpu/saved_model.pb and /dev/null differ diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tests/test_dmff_plugin_nve.py b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tests/test_dmff_plugin_nve.py index dd6dd53e5..d370efc62 100644 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tests/test_dmff_plugin_nve.py +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tests/test_dmff_plugin_nve.py @@ -15,24 +15,20 @@ from OpenMMDMFFPlugin import DMFFModel -def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", output_temp_dir = "/tmp/openmm_dmff_plugin_test_nve_output", energy_std_tol = 0.0005 ): +def test_dmff_nve(nsteps = 1000, time_step = 0.2, pdb_file = None, model_dir = None, platform_name = "Reference", output_temp_dir = "/tmp/openmm_dmff_plugin_test_nve_output", energy_std_tol = 0.005, has_aux = False ): if not os.path.exists(output_temp_dir): os.mkdir(output_temp_dir) - pdb_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid.pdb") - if platform_name == "Reference": - dmff_model_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") - elif platform_name == "CUDA": - dmff_model_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") - - output_dcd = os.path.join(output_temp_dir, "lj_fluid_test.nve.dcd") - output_log = os.path.join(output_temp_dir, "lj_fluid_test.nve.log") + dmff_model_file = model_dir + + output_dcd = os.path.join(output_temp_dir, "test.nve.dcd") + output_log = os.path.join(output_temp_dir, "test.nve.log") # Set up the simulation parameters. nsteps = nsteps time_step = time_step # unit is femtosecond. - report_frequency = 10 - box = [24.413, 0, 0, 0, 24.413, 0, 0, 0, 24.413] + report_frequency = 1 + box = [31.289, 0, 0, 0, 31.289, 0, 0, 0, 31.289] box = [mm.Vec3(box[0], box[1], box[2]), mm.Vec3(box[3], box[4], box[5]), mm.Vec3(box[6], box[7], box[8])] * u.angstroms liquid_water = PDBFile(pdb_file) @@ -43,6 +39,8 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o # Set up the dmff_system with the dmff_model. dmff_model = DMFFModel(dmff_model_file) dmff_model.setUnitTransformCoefficients(1, 1, 1) + if has_aux: + dmff_model.setHasAux(True) dmff_system = dmff_model.createSystem(topology) integrator = mm.VerletIntegrator(time_step*u.femtoseconds) @@ -83,13 +81,16 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o # Check the total energy fluctuations over # of atoms is smaller than energy_std_tol, unit in kJ/mol. print("Total energy std: %.4f kJ/mol"%(np.std(total_energy))) print("Mean total energy: %.4f kJ/mol"%(np.mean(total_energy))) - #assert(np.std(total_energy) / num_atoms < energy_std_tol) + assert(np.std(total_energy) / num_atoms < energy_std_tol) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-n', '--nsteps', type = int, dest='nsteps', help='Number of steps', default=100) parser.add_argument('--dt', type = float, dest='timestep', help='Time step for simulation, unit is femtosecond', default=0.2) + parser.add_argument('--pdb', type = str, dest='pdb', help='PDB file for simulation.', default=None) + parser.add_argument('--model', type = str, dest='model', help='DMFF model dir for simulation. Saved by backend/save_dmff2tf.py.', default=None) parser.add_argument('--platform', type = str, dest='platform', help='Platform for simulation.', default="Reference") + parser.add_argument('--has_aux', type = bool, dest='has_aux', help='Whether the model has aux output.', default=False) args = parser.parse_args() @@ -97,5 +98,13 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o time_step = args.timestep platform_name = args.platform - test_dmff_nve(nsteps=nsteps, time_step=time_step, platform_name=platform_name) + pdb = args.pdb + model_dir = args.model + + if pdb is None: + pdb = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid.pdb") + if model_dir is None: + model_dir = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") + + test_dmff_nve(nsteps=nsteps, time_step=time_step, pdb_file=pdb, model_dir=model_dir, platform_name=platform_name, has_aux=args.has_aux) diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py index d86804eaf..7726b3210 100644 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py @@ -70,6 +70,15 @@ def setUnitTransformCoefficients(self, coordinatesCoefficient, forceCoefficient, self.dmff_force.setUnitTransformCoefficients(coordinatesCoefficient, forceCoefficient, energyCoefficient) return + def setHasAux(self, has_aux = False): + """Set whether the DMFF model has auxilary output. + Used when model was saved with has_aux = True. + + Args: + has_aux (bool, optional): Defaults to False. + """ + self.dmff_force.setHasAux(has_aux) + def createSystem(self, topology): """Create the OpenMM System object for the DMFF model. diff --git a/backend/openmm_dmff_plugin/python/tests/test_dmff_plugin_nve.py b/backend/openmm_dmff_plugin/python/tests/test_dmff_plugin_nve.py index 21a16ca79..630114c76 100644 --- a/backend/openmm_dmff_plugin/python/tests/test_dmff_plugin_nve.py +++ b/backend/openmm_dmff_plugin/python/tests/test_dmff_plugin_nve.py @@ -15,18 +15,14 @@ from OpenMMDMFFPlugin import DMFFModel -def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", output_temp_dir = "/tmp/openmm_dmff_plugin_test_nve_output", energy_std_tol = 0.005 ): +def test_dmff_nve(nsteps = 1000, time_step = 0.2, pdb_file = None, model_dir = None, platform_name = "Reference", output_temp_dir = "/tmp/openmm_dmff_plugin_test_nve_output", energy_std_tol = 0.005, has_aux = False ): if not os.path.exists(output_temp_dir): os.mkdir(output_temp_dir) - pdb_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid.pdb") - if platform_name == "Reference": - dmff_model_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") - elif platform_name == "CUDA": - dmff_model_file = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") - - output_dcd = os.path.join(output_temp_dir, "lj_fluid_test.nve.dcd") - output_log = os.path.join(output_temp_dir, "lj_fluid_test.nve.log") + dmff_model_file = model_dir + + output_dcd = os.path.join(output_temp_dir, "test.nve.dcd") + output_log = os.path.join(output_temp_dir, "test.nve.log") # Set up the simulation parameters. nsteps = nsteps @@ -43,6 +39,8 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o # Set up the dmff_system with the dmff_model. dmff_model = DMFFModel(dmff_model_file) dmff_model.setUnitTransformCoefficients(1, 1, 1) + if has_aux: + dmff_model.setHasAux() dmff_system = dmff_model.createSystem(topology) integrator = mm.VerletIntegrator(time_step*u.femtoseconds) @@ -89,7 +87,10 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o parser = argparse.ArgumentParser() parser.add_argument('-n', '--nsteps', type = int, dest='nsteps', help='Number of steps', default=100) parser.add_argument('--dt', type = float, dest='timestep', help='Time step for simulation, unit is femtosecond', default=0.2) + parser.add_argument('--pdb', type = str, dest='pdb', help='PDB file for simulation.', default=None) + parser.add_argument('--model', type = str, dest='model', help='DMFF model dir for simulation. Saved by backend/save_dmff2tf.py.', default=None) parser.add_argument('--platform', type = str, dest='platform', help='Platform for simulation.', default="Reference") + parser.add_argument('--has_aux', type = bool, dest='has_aux', help='Whether the model has aux output.', default=False) args = parser.parse_args() @@ -97,5 +98,13 @@ def test_dmff_nve(nsteps = 1000, time_step = 0.2, platform_name = "Reference", o time_step = args.timestep platform_name = args.platform - test_dmff_nve(nsteps=nsteps, time_step=time_step, platform_name=platform_name) + pdb = args.pdb + model_dir = args.model + + if pdb is None: + pdb = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid.pdb") + if model_dir is None: + model_dir = os.path.join(os.path.dirname(__file__), "../data", "lj_fluid_gpu") + + test_dmff_nve(nsteps=nsteps, time_step=time_step, pdb_file=pdb, model_dir=model_dir, platform_name=platform_name) diff --git a/backend/save_dmff2tf.py b/backend/save_dmff2tf.py index fac1741d3..2afea0fa3 100644 --- a/backend/save_dmff2tf.py +++ b/backend/save_dmff2tf.py @@ -17,29 +17,55 @@ for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) -def create_dmff_potential(input_pdb_file, ff_xml_files): +def create_dmff_potential(input_pdb_file, ff_xml_files, bond_definitions_xml = None, has_aux = False): pdb = app.PDBFile(input_pdb_file) h = dmff.Hamiltonian(*ff_xml_files) - pot = h.createPotential(pdb.topology, + if bond_definitions_xml is not None: + app.Topology.loadBondDefinitions(bond_definitions_xml) + + if has_aux:# Used when using ADMP with DMFF. + pot = h.createPotential(pdb.topology, + nonbondedMethod=app.PME, + ethresh=5e-4, step_pol=10, + nonbondedCutoff=1.2 * + unit.nanometer, + has_aux=True) + else: + pot = h.createPotential(pdb.topology, nonbondedMethod=app.PME, nonbondedCutoff=1.2 * unit.nanometer) + pot_func = pot.getPotentialFunc() a, b, c = pdb.topology.getPeriodicBoxVectors() a = a.value_in_unit(unit.nanometer) b = b.value_in_unit(unit.nanometer) c = c.value_in_unit(unit.nanometer) - engrad = jax.value_and_grad(pot_func, 0) + if has_aux: + engrad = jax.value_and_grad(pot_func, 0, has_aux=True) + else: + engrad = jax.value_and_grad(pot_func, 0) - covalent_map = h.getGenerators()[-1].covalent_map + covalent_map = pot.meta["cov_map"] + aux = dict() - def potential_engrad(positions, box, pairs): - if jnp.shape(pairs)[-1] == 2: - nbond = covalent_map[pairs[:, 0], pairs[:, 1]] - pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1) - - return engrad(positions, box, pairs, h.paramtree) + if has_aux: + def potential_engrad(positions, box, U_ind, pairs): + if jnp.shape(pairs)[-1] == 2: + nbond = covalent_map[pairs[:, 0], pairs[:, 1]] + pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1) + aux['U_ind'] = U_ind + ener_and_aux, ener_grad = engrad(positions, box, pairs, h.getParameters(), aux) + # Return energy, gradient (forces), and U_ind + return ener_and_aux[0], ener_grad, ener_and_aux[1]['U_ind'] + else: + def potential_engrad(positions, box, pairs): + if jnp.shape(pairs)[-1] == 2: + nbond = covalent_map[pairs[:, 0], pairs[:, 1]] + pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1) + + return engrad(positions, box, pairs, h.getParameters()) return pdb, potential_engrad, covalent_map, pot, h @@ -49,26 +75,49 @@ def potential_engrad(positions, box, pairs): parser.add_argument("--input_pdb", dest="input_pdb", help="input pdb file. Box information is required in the pdb file.") parser.add_argument("--xml_files", dest="xml_files", nargs="+", help=".xml files with parameters are derived from DMFF.") parser.add_argument("--output", dest="output", help="output directory") + parser.add_argument("--bond_definitions_xml", dest="bond_definitions_xml", help=".xml file that contains bond definitions. Optional", default=None) + parser.add_argument("--has_aux", dest="has_aux", default=False, help="Enable aux output in the model. Used when using ADMP with DMFF, and the output would be U_ind.") args = parser.parse_args() input_pdb = args.input_pdb ff_xml_files = args.xml_files output_dir = args.output + has_aux = args.has_aux + bond_definitions_xml = args.bond_definitions_xml + if output_dir[-1] == "/": output_dir = output_dir[:-1] if not os.path.exists(output_dir): os.mkdir(output_dir) - pdb, pot_grad, covalent_map, pot, h = create_dmff_potential(input_pdb, ff_xml_files) + pdb, pot_grad, covalent_map, pot, h = create_dmff_potential(input_pdb, ff_xml_files, bond_definitions_xml=bond_definitions_xml, has_aux=has_aux) natoms = pdb.getTopology().getNumAtoms() - f_tf = jax2tf.convert( - jax.jit(pot_grad), - polymorphic_shapes=["("+str(natoms)+", 3)", "(3, 3)", "(b, 2)"] - ) + if has_aux: + f_tf = jax2tf.convert( + jax.jit(pot_grad), + polymorphic_shapes=["("+str(natoms)+", 3)", "(3, 3)", "("+str(natoms)+", 3)", "(b, 2)"] + ) + else: + f_tf = jax2tf.convert( + jax.jit(pot_grad), + polymorphic_shapes=["("+str(natoms)+", 3)", "(3, 3)", "(b, 2)"] + ) + dmff_model = tf.Module() - dmff_model.f = tf.function(f_tf, autograph=False, - input_signature=[tf.TensorSpec(shape=[natoms,3], dtype=tf.float32), tf.TensorSpec(shape=[3,3], dtype=tf.float32), tf.TensorSpec(shape=tf.TensorShape([None, 2]), dtype=tf.int32)]) + if has_aux: + dmff_model.f = tf.function(f_tf, autograph=False, + input_signature=[ + tf.TensorSpec(shape=[natoms,3], dtype=tf.float32), + tf.TensorSpec(shape=[3,3], dtype=tf.float32), + tf.TensorSpec(shape=[natoms,3], dtype=tf.float64), + tf.TensorSpec(shape=tf.TensorShape([None, 2]), dtype=tf.int32)]) + else: + dmff_model.f = tf.function(f_tf, autograph=False, + input_signature=[ + tf.TensorSpec(shape=[natoms,3], dtype=tf.float32), tf.TensorSpec(shape=[3,3], dtype=tf.float32), + tf.TensorSpec(shape=tf.TensorShape([None, 2]), dtype=tf.int32) + ]) tf.saved_model.save(dmff_model, output_dir, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))