Skip to content

Commit

Permalink
fix: update tf UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 14, 2024
1 parent 9f3b47d commit 71f24ff
Showing 1 changed file with 45 additions and 206 deletions.
251 changes: 45 additions & 206 deletions source/tests/tf/test_model_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,214 +123,53 @@ def test_model(self):
[pred_dos, pred_atom_dos] = sess.run([dos, atom_dos], feed_dict=feed_dict_test)

ref_dos = np.array(
[
-2.98834333,
-0.63166985,
-3.37199568,
-1.88397887,
0.87560992,
4.85426159,
-1.22677731,
-0.60918118,
8.80472675,
-1.12006829,
-3.72653765,
-3.03698828,
3.50906891,
5.55140795,
-3.34920924,
-4.43507641,
-6.1729281,
-8.34865917,
0.14371788,
-4.38078479,
-6.43141133,
4.07791938,
7.14102837,
-0.52347718,
0.82663796,
-1.64225631,
-4.63088421,
3.3910594,
-9.09682274,
1.61104204,
4.45900773,
-2.44688559,
-2.83298183,
-2.00733658,
7.33444256,
7.09187373,
-1.97065392,
0.01623084,
-7.48861264,
-1.17790161,
2.77126775,
-2.55552037,
3.3518257,
-0.09316856,
-1.94521413,
0.50089251,
-2.75763233,
-1.94382637,
1.30562041,
5.08351043,
-1.90604837,
-0.80030045,
-4.87093267,
4.18009666,
-2.9011435,
2.58497143,
4.47495176,
-0.9639419,
8.15692179,
0.48758731,
-0.62264663,
-1.70677258,
-5.51641378,
3.98621565,
0.57749944,
2.9658081,
-4.10467591,
-7.14827888,
0.02838605,
-2.48630333,
-4.82178216,
-0.7444178,
2.48224802,
-1.54683936,
0.46969412,
-0.0960347,
-2.08290541,
6.357031,
-3.49716615,
3.28959028,
7.83932727,
1.51457023,
-4.14575033,
0.02007839,
4.20953773,
3.66456664,
-4.67441496,
-0.13296372,
-3.77145766,
1.49368976,
-2.53627817,
-3.14188618,
0.24991722,
0.8770123,
0.16635733,
-3.15391098,
-3.7733242,
-2.25134676,
1.00975552,
1.38717682,
]
)
[ -1.98049388, -4.58033899, -6.95508968, -0.79619016,
15.58478599, 2.7636959 , -2.99147438, -6.94430794,
-1.77877141, -4.5000298 , -3.12026893, -8.42191319,
3.8991195 , 4.85271854, 8.30541908, -1.0435944 ,
-4.42713079, 19.70011955, -6.53945284, 0.85064846,
4.36868488, 4.77303801, 3.00829128, 0.70043584,
-7.69047143, -0.0647043 , 4.56830405, -8.67154404,
-4.64015279, -7.62202078, -8.97078455, -5.19685985,
-1.66080276, -6.03225716, -4.06780949, -0.53046979,
8.3543131 , -1.84893576, 2.42669245, -4.26357086,
-11.33995527, 10.98529887, -10.70000829, -4.50179402,
-1.34978505, -8.83091676, -11.85324773, -3.6305035 ,
2.89933807, 4.65750153, 1.25464578, -5.06196944,
10.05305042, -1.83868447, -11.57017913, -2.03900316,
-3.37235187, -1.37010554, -2.93769471, 0.11905709,
6.99367431, 3.48640865, -4.16242817, 4.44778342,
-0.98405367, 1.81581506, -5.31481686, 8.72426364,
4.78954098, 7.67879332, -5.00417706, 0.79717914,
-3.20581567, -2.96034568, 6.31165294, 2.9891188 ,
-12.2013139 , -13.67496037, 4.77102881, 2.71353286,
6.83849229, -3.50400312, 1.3839428 , -5.07550528,
-8.5623218 , 17.64081151, 6.46051807, 2.89067584,
14.23057359, 17.85941763, -6.46129295, -3.43602528,
-3.13520203, 4.45313732, -5.23012576, -2.65929557,
-0.66191939, 4.47530191, 9.33992973, -6.29808733])

ref_ados_1 = np.array(
[
-0.33019322,
-0.76332506,
-0.32665648,
-0.76601747,
-1.16441856,
-0.13627609,
-1.15916671,
-0.13280604,
2.60139518,
0.44470952,
-0.48316771,
-1.15926141,
2.59680457,
0.46049936,
-0.29459777,
-0.76433726,
-0.52091744,
-1.39903065,
-0.49890317,
-1.15747878,
0.66585524,
0.81804842,
1.38592217,
-0.18025826,
-0.2964021,
-0.74953328,
-0.7427461,
3.27935087,
-1.09340192,
0.1462458,
-0.51982728,
-1.40236941,
0.73902497,
0.79969456,
0.50726592,
0.11403234,
0.64964525,
0.8084967,
-1.27543102,
-0.00571457,
0.7748912,
-1.42492251,
1.38371838,
-0.17366078,
-0.76119888,
-1.26083707,
-1.48263244,
-0.85698727,
-0.7374573,
3.28274006,
-0.27029769,
-1.00478711,
-0.67481511,
-0.07978058,
-1.09001574,
0.14173437,
1.4092343,
-0.31785424,
0.40551362,
-0.71900495,
0.7269307,
0.79545851,
-1.88407155,
1.83983772,
-1.78413438,
-0.74852344,
0.50059876,
0.1165872,
-0.2139368,
-1.44989426,
-1.96651281,
-0.6031689,
-1.28106632,
-0.01107711,
0.48796663,
0.76500912,
0.21308153,
-0.85297893,
0.76139868,
-1.44547292,
1.68105021,
-0.30655702,
-1.93123,
-0.34294737,
-0.77352498,
-1.26982082,
-0.5562998,
-0.22048683,
-0.48641512,
0.01124872,
-1.49597963,
-0.86647985,
1.17310075,
0.59402879,
-0.705076,
0.72991794,
-0.27728806,
-1.00542829,
-0.16289102,
0.29464248,
]
)
[-0.33019322, -0.76332506, -1.15916671, -0.13280604, 2.59680457,
0.46049936, -0.49890317, -1.15747878, -0.2964021 , -0.74953328,
-0.51982728, -1.40236941, 0.64964525, 0.8084967 , 1.38371838,
-0.17366078, -0.7374573 , 3.28274006, -1.09001574, 0.14173437,
0.7269307 , 0.79545851, 0.50059876, 0.1165872 , -1.28106632,
-0.01107711, 0.76139868, -1.44547292, -0.77352498, -1.26982082,
-1.49597963, -0.86647985, -0.27728806, -1.00542829, -0.67794229,
-0.08898442, 1.39205396, -0.30789099, 0.40393006, -0.70982912,
-1.88961087, 1.830906 , -1.78326071, -0.75013615, -0.22537904,
-1.47257916, -1.9756803 , -0.60493323, 0.48350014, 0.77676571,
0.20885468, -0.84351691, 1.67501205, -0.30662021, -1.92884376,
-0.34021625, -0.56212664, -0.22884438, -0.4891038 , 0.0199886 ,
1.16506594, 0.58068956, -0.69376438, 0.74156043, -0.16360848,
0.30303168, -0.88639571, 1.453683 , 0.79818052, 1.2796414 ,
-0.8335433 , 0.13359098, -0.53425462, -0.4939294 , 1.05247266,
0.49770575, -2.03320073, -2.27918678, 0.79462598, 0.45187804,
1.13925239, -0.58410808, 0.23092918, -0.84611213, -1.42726499,
2.93985879, 1.07635712, 0.48092082, 2.37197063, 2.97647126,
-1.07670667, -0.57300341, -0.52316403, 0.74274268, -0.87188274,
-0.44279998, -0.11060956, 0.74619435, 1.55646754, -1.05043903])

places = 4
np.testing.assert_almost_equal(pred_dos, ref_dos, places)
Expand Down

0 comments on commit 71f24ff

Please sign in to comment.