diff --git a/source/api_cc/tests/test_deeppot_jax.cc b/source/api_cc/tests/test_deeppot_jax.cc index 0ca04b15ee..6af65e865b 100644 --- a/source/api_cc/tests/test_deeppot_jax.cc +++ b/source/api_cc/tests/test_deeppot_jax.cc @@ -35,37 +35,33 @@ class TestInferDeepPotAJAX : public ::testing::Test { 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; std::vector atype = {0, 1, 1, 0, 1, 1}; std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + // the data in this file is just copied from PT std::vector expected_e = { - -94.81857726831211, -186.84152093802385, -186.74187382823973, - -95.14795884356523, -186.63980359613632, -186.7289280556596}; + -93.016873944029, -185.923296645958, -185.927096544970, + -93.019371018039, -185.926179995548, -185.924351901852}; std::vector expected_f = { - 0.0477148189212569, 1.4043255476951888, -0.4721020550457472, - 1.0407593579417451, 0.0648923873554379, -0.9625104541845879, - -1.498756714777078, -1.444909055525867, 1.4310462548243115, - 0.8220835498036487, 0.5057834515828189, -0.837961862727987, - 0.2605714955680978, -0.4099725867614261, 0.7070170534172308, - -0.6723725074576706, -0.1201197443461519, 0.1345110637167796}; + 0.006277522211, -0.001117962774, 0.000618580445, 0.009928999655, + 0.003026035654, -0.006941982227, 0.000667853212, -0.002449963843, + 0.006506463508, -0.007284129115, 0.000530662205, -0.000028806821, + 0.000068097781, 0.006121331983, -0.009019754602, -0.009658343745, + -0.006110103225, 0.008865499697}; std::vector expected_v = { - 0.3012732868676679, -0.111008111035248, 0.0607697438093234, - -0.1793710087436107, -0.4864743206172897, 0.198442096042078, - 0.1075875125278831, 0.202225050899223, -0.0292692216990346, - -0.8422437284528095, 0.8027921763701951, 0.6727780383521256, - -0.0498584648777359, 0.2840564878114856, 0.0022488426944385, - 0.8445217006092454, -0.7310051540006901, -0.6127740488262745, - -0.1346854853232727, -1.383173435157834, 0.5826625774172514, - -0.4798704193996531, -0.876873954037455, 0.5713457567298014, - 0.3936811857301074, 1.3042473872385445, -0.6164758064148346, - 0.2109867642597265, -0.0494829864962893, -0.118518922813471, - -0.0494765998087958, 0.319656735579515, -0.5152039050885358, - -0.1006642968068137, -0.5169139218651899, 0.8296380211652136, - 1.0631848697680133, -0.0866767331819718, 0.3115852842045283, - 0.5588465564114341, -0.2540276865795372, 0.4516708669974016, - -0.7221708332192804, 0.3830520421137477, -0.6759267741157751, - -1.1226313508171666, -0.0839482805543516, 0.0450669462011663, - -0.7117674336371344, 0.0994128124898613, -0.2345635147240297, - 1.0313883983297838, -0.1676652617344831, 0.3738750402753665}; + -0.000155238009, 0.000116605516, -0.007869862476, 0.000465578340, + 0.008182547185, -0.002398713212, -0.008112887338, -0.002423738425, + 0.007210716605, -0.019203504012, 0.001724938709, 0.009909211091, + 0.001153857542, -0.001600015103, -0.000560024090, 0.010727836276, + -0.001034836404, -0.007973454377, -0.021517399106, -0.004064359664, + 0.004866398692, -0.003360038617, -0.007241406162, 0.005920941051, + 0.004899151657, 0.006290788591, -0.006478820311, 0.001921504710, + 0.001313470921, -0.000304091236, 0.001684345981, 0.004124109256, + -0.006396084465, -0.000701095618, -0.006356507032, 0.009818550859, + -0.015230664587, -0.000110244376, 0.000690319396, 0.000045953023, + -0.005726548770, 0.008769818495, -0.000572380210, 0.008860603423, + -0.013819348050, -0.021227082558, -0.004977781343, 0.006646239696, + -0.005987066507, -0.002767831232, 0.003746502525, 0.007697590397, + 0.003746130152, -0.005172634748}; int natoms; double expected_tot_e; std::vector expected_tot_v;