diff --git a/.pytest.ini b/.pytest.ini index 2b7ebd6d..075b61b4 100644 --- a/.pytest.ini +++ b/.pytest.ini @@ -1,2 +1,8 @@ [pytest] addopts = -s -v --durations=10 +filterwarnings = + error + ignore::DeprecationWarning +markers = + ipu: marks tests specific to IPU (deselect with '-m "not ipu"') + serial diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index b50e3add..522ee5fb 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -138,7 +138,13 @@ def get_JK(density_matrix, ERI, dense_ERI, backend): return diff_JK -def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol): +def _nanoDFT(state, ERI, grid_AO, grid_weights, profile_performance, opts, mol): + + if profile_performance is not None and opts.backend == "ipu": + print("[INFO] Running nanoDFT with performance profiling.") + grid_weights, start = utils.get_ipu_cycles(grid_weights) + + # Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel. #if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv) #else: pass # Compute on CPU. @@ -156,6 +162,10 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol): # Perform DFT iterations. log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices state.E_nuc, state.mask, ERI, grid_weights, grid_AO, state.diis_history, log])[-1] + if profile_performance is not None and opts.backend == "ipu": + log["energy"], end = utils.get_ipu_cycles(log["energy"]) + + return log["matrices"], H_core, log["energy"], (start.array, end.array) return log["matrices"], H_core, log["energy"] @@ -256,7 +266,7 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9): return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO -def nanoDFT(mol, opts): +def nanoDFT(mol, opts, profile_performance=None): # Init DFT tensors on CPU using PySCF. state, n_electrons_half, E_nuc, N, L_inv, _grid_weights, grid_coords, grid_AO = init_dft_tensors_cpu(mol, opts) @@ -314,17 +324,22 @@ def nanoDFT(mol, opts): ERI = [nonzero_distinct_ERI, nonzero_indices] eri_in_axes = [0,0] - #jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend) + # jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend) jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend, - in_axes=(None, eri_in_axes, 0, 0), + in_axes=(None, eri_in_axes, 0, 0, None), axis_name="p") - print(grid_AO.shape, grid_weights.shape) - vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights) - logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU + + vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights, profile_performance) + + if profile_performance is not None and opts.backend == "ipu": + logged_matrices, H_core, logged_energies, _ = [np.asarray(a[0]).astype(np.float64) for a in vals] + ipu_cycles_stamps = vals[3] + else: + logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU # It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices. logged_E_xc = logged_energies[:, 3].copy() - print(logged_energies[:, 0] * HARTREE_TO_EV) + # print(logged_energies[:, 0] * HARTREE_TO_EV) density_matrices, diff_JKs, H = [logged_matrices[:, i] for i in range(3)] energies, hlgaps = np.zeros((opts.its, 5)), np.zeros(opts.its) for i in range(opts.its): @@ -333,6 +348,10 @@ def nanoDFT(mol, opts): energies, logged_energies, hlgaps = [a * HARTREE_TO_EV for a in [energies, logged_energies, hlgaps]] mo_energy, mo_coeff = np.linalg.eigh(L_inv @ H[-1] @ L_inv.T) mo_coeff = L_inv.T @ mo_coeff + + if profile_performance is not None and opts.backend == "ipu": + return energies, (logged_energies, hlgaps, mo_energy, mo_coeff, grid_coords, _grid_weights), ipu_cycles_stamps + return energies, (logged_energies, hlgaps, mo_energy, mo_coeff, grid_coords, _grid_weights) def DIIS(i, H, density_matrix, O, diis_history, opts): @@ -610,8 +629,8 @@ def nanoDFT_options( from pyscf_ipu.experimental.device import has_ipu import os - if has_ipu() and "JAX_IPU_USE_MODEL" in os.environ: - args.dense_ERI = True + # if has_ipu() and "JAX_IPU_USE_MODEL" in os.environ: + # args.dense_ERI = True args = namedtuple('DFTOptionsImmutable',vars(args).keys())(**vars(args)) # make immutable if not args.float32: jax.config.update('jax_enable_x64', not float32) diff --git a/pyscf_ipu/nanoDFT/utils.py b/pyscf_ipu/nanoDFT/utils.py index f9e1197e..396b51af 100644 --- a/pyscf_ipu/nanoDFT/utils.py +++ b/pyscf_ipu/nanoDFT/utils.py @@ -132,6 +132,19 @@ def process_mol_str(mol_str: str): ["H", (1, 0, 0)], ["H", (1, 1, 1)] ] + elif mol_str == "bench1": mol_str = [['C', (-0.07551087, 1.68127663, -0.10745193)], ['O', (1.33621755, 1.87147409, -0.39326987)], ['C', (1.67074668, 2.95729545, 0.49387976)], ['C', (0.41740763, 3.77281969, 0.78495878)], ['C', (-0.6048148, 3.07572636, 0.28906224)], ['H', (-0.19316298, 1.01922455, 0.72486113)], ['O', (0.35092043, 5.03413298, 1.45545728)], ['H', (0.42961487, 5.74279041, 0.81264173)], ['O', (-1.9533175, 3.53349874, 0.15912025)], ['H', (-2.55333895, 2.78846397, 0.23972698)], ['O', (2.81976302, 3.20110148, 0.94542226)], ['C', (-0.81772499, 1.09230218, -1.32146482)], ['H', (-0.70955636, 1.74951833, -2.15888136)], ['C', (-2.31163857, 0.93420736, -0.98260166)], ['H', (-2.72575463, 1.89080093, -0.74107186)], ['H', (-2.41980721, 0.2769912, -0.14518512)], ['O', (-0.26428017, -0.18613595, -1.64425697)], ['H', (-0.7269591, -0.55328886, -2.40104423)], ['O', (-3.00083741, 0.38730252, -2.10989934)], ['H', (-3.93210821, 0.2887499, -1.89865997)]] + elif mol_str == "bench2": mol_str = [['C', (-0.84925689, 1.19426748, 0.0)], ['O', (0.58111411, 1.19426748, 0.0)], ['C', (1.08310511, 2.53365748, 0.0)], ['C', (-0.09344989, 3.50382448, -1e-06)], ['C', (-1.34480089, 2.63646648, -1e-06)], ['H', (-1.17341789, 0.63528848, 0.915132)], ['H', (1.71994411, 2.64115648, 0.91536)], ['H', (-0.06206089, 4.16301548, -0.901905)], ['H', (-1.97307989, 2.83841948, -0.901905)], ['O', (-0.05328805, 4.34730633, 1.1540464)], ['H', (-0.02922706, 5.26656976, 0.87840928)], ['O', (-2.14872719, 2.89488055, 1.15404608)], ['H', (-3.07612353, 2.89074011, 0.90601607)], ['C', (-1.29487091, 0.42585647, -1.25800344)], ['H', (-2.36057357, 0.33005726, -1.25860833)], ['H', (-0.98433839, 0.96058082, -2.13123848)], ['O', (-0.6994091, -0.87426857, -1.25830689)], ['H', (-1.15722348, -1.43626839, -1.88772107)], ['N', (1.95854804, 2.68143304, -1.25831716)], ['C', (3.40750745, 2.83869936, -1.2483172)], ['C', (1.52513756, 2.69033382, -2.59758791)], ['C', (3.83177447, 2.94266783, -2.62652798)], ['H', (0.49017157, 2.5912999, -2.9235557)], ['C', (5.16750068, 3.10114669, -2.93617186)], ['C', (5.68858554, 3.05790761, -0.55874621)], ['H', (6.43570855, 3.10532282, 0.24755786)], ['N', (2.63327508, 2.84671184, -3.4506498)], ['N', (4.33322006, 2.89644875, -0.2260484)], ['N', (6.09645965, 3.15797838, -1.88374496)], ['H', (6.74393404, 2.41331683, -2.04578224)], ['O', (5.6077051, 3.20620138, -4.29266866)]] + elif mol_str == "bench3": mol_str = [['C', (0.17302217, -0.0536787, 0.0206023)], ['C', (0.68636442, 0.67471439, 1.27659722)], ['H', (0.33130807, 1.68408369, 1.27366248)], ['H', (0.32809533, 0.17314097, 2.15122134)], ['H', (1.75636261, 0.67300961, 1.27757926)], ['C', (0.68636436, 0.66983803, -1.23820799)], ['H', (1.75636429, 0.67015162, -1.23801986)], ['H', (0.32999896, 0.16352848, -2.1108788)], ['H', (0.32940246, 1.678537, -1.24035286)], ['C', (-1.36697783, -0.05365972, 0.0206023)], ['C', (-2.06451416, -1.26192666, 0.02294888)], ['C', (-2.06446856, 1.15425163, 0.01836764)], ['C', (-3.45922608, -1.26218837, 0.02237872)], ['H', (-1.51433285, -2.2140513, 0.02393307)], ['C', (-3.45960644, 1.15416501, 0.01879292)], ['H', (-1.51485954, 2.10664659, 0.0166979)], ['C', (-4.15704263, -0.05377385, 0.02065983)], ['H', (-4.00910069, -2.21451865, 0.02359421)], ['H', (-4.0092797, 2.10670462, 0.01744603)], ['C', (0.68633788, -1.50560812, 0.02341767)], ['C', (0.91946288, -2.16536088, -1.18358293)], ['C', (0.91831096, -2.16106624, 1.23259561)], ['C', (1.38381298, -3.48050177, -1.18131091)], ['H', (0.7357492, -1.6485928, -2.13670508)], ['C', (1.38374179, -3.4762765, 1.23504147)], ['H', (0.73479448, -1.64116737, 2.18399242)], ['C', (1.6163816, -4.13609396, 0.02836974)], ['H', (1.5669898, -4.00080229, -2.13264163)], ['H', (1.56694472, -3.99268549, 2.1885784)], ['O', (2.09255288, -5.48448444, 0.03037417)], ['H', (3.05255142, -5.48531759, 0.03182891)], ['O', (-5.58704242, -0.05436711, 0.02015267)], ['H', (-5.9079705, 0.43788494, 0.77929237)]] + elif mol_str == "bench4": mol_str = [['C', (-1.41478807, -1.38673987, -0.27546855)], ['C', (0.10179343, -1.53458053, -0.07889229)], ['C', (0.53408642, -0.37025944, 0.4514519)], ['C', (-0.67562941, 0.57796327, 0.46559768)], ['H', (0.6950762, -2.39648436, -0.30089545)], ['H', (1.52737697, -0.14599081, 0.78086417)], ['C', (-0.6237082, 1.96362793, 0.78635139)], ['H', (0.21323483, 2.35384028, 1.32735933)], ['C', (-2.29970743, -2.44324831, -0.6370283)], ['H', (-1.91393373, -3.32969289, -1.09683478)], ['C', (-1.64052102, 2.84486117, 0.32186518)], ['C', (-3.52932141, 3.58449372, -0.59511281)], ['C', (-2.59214132, 4.79754521, -0.48499553)], ['H', (-2.78843501, 5.79301664, -0.82737225)], ['C', (-1.48137242, 4.36442701, 0.14972837)], ['N', (-2.86612663, 2.50908907, -0.12553285)], ['N', (-1.74741954, -0.12018577, 0.0412691)], ['H', (-0.64679525, 4.9592148, 0.45645615)], ['C', (-3.68393396, -2.35600549, -0.31223368)], ['C', (-4.63005393, -3.55629325, -0.15598101)], ['N', (-4.37818634, -1.2430198, -0.00220836)], ['C', (-5.78557755, -3.06857539, 0.34505536)], ['H', (-4.41027665, -4.57910031, -0.38223928)], ['C', (-5.63589939, -1.53901637, 0.37901842)], ['H', (-6.64304239, -3.63441062, 0.64304078)], ['C', (-6.68305631, -0.62113443, 0.67405446)], ['H', (-7.56136499, -0.96043354, 1.18328735)], ['C', (-6.6012237, 0.72696467, 0.22471451)], ['C', (-7.80292122, 1.66378985, 0.0226908)], ['N', (-5.49689043, 1.38542695, -0.17958107)], ['C', (-7.32644297, 2.77323824, -0.58189005)], ['H', (-8.81986259, 1.46790799, 0.29231039)], ['C', (-5.79921052, 2.61414176, -0.64211669)], ['H', (-7.8979571, 3.60669127, -0.93279659)], ['C', (-4.88433877, 3.63473378, -1.03220086)], ['H', (-5.22991618, 4.47158073, -1.60163868)], ['Mg', (-3.6204912, 0.63120563, -0.10251958)]] + elif mol_str == "bench5": mol_str = [['C', (-0.77452336, -0.08132633, 1.74172506)], ['C', (-0.26411504, -0.11129929, 0.4369236)], ['C', (-0.3539735, 1.02601929, -0.37692764)], ['C', (-0.95424489, 2.19330919, 0.11402082)], ['C', (-1.46465579, 2.22328125, 1.41882131)], ['C', (-1.37479443, 1.0859637, 2.23267366)], ['H', (-0.70591476, -0.94969444, 2.3631184)], ['H', (0.03573621, 1.00313485, -1.37317189)], ['H', (-1.02285582, 3.06167644, -0.50737346)], ['H', (-1.92297618, 3.11453305, 1.79367143)], ['H', (-1.76450608, 1.10884747, 3.22891718)], ['O', (0.34840202, -1.30241314, -0.06404605)], ['C', (1.35270587, -0.9556787, -1.02115052)], ['H', (2.09459791, -0.34266532, -0.55348523)], ['H', (0.90396901, -0.41799437, -1.8301195)], ['C', (2.01233961, -2.23841673, -1.5606563)], ['O', (2.09302887, -3.2596824, -0.82985109)], ['N', (2.54773286, -2.26985691, -2.92932924)], ['H', (3.50527213, -1.98172537, -2.91939174)], ['C', (2.45289144, -3.63693794, -3.46130165)], ['C', (3.06213638, -3.9815682, -4.81429695)], ['C', (3.2077577, -4.83278916, -2.92168779)], ['H', (1.41330079, -3.88850264, -3.43186598)], ['H', (2.2125826, -3.95040754, -5.46405646)], ['O', (3.42944279, -5.30612203, -1.7769685)], ['N', (3.67522246, -5.14821806, -4.31096996)], ['C', (3.66579173, -6.10484804, -5.36714016)], ['H', (4.49655524, -6.75828572, -5.20055325)], ['C', (3.84517881, -5.06783193, -6.57540908)], ['C', (4.72191908, -5.6417924, -7.70390186)], ['H', (4.27386857, -6.53609107, -8.08387642)], ['H', (4.80563701, -4.92235784, -8.49149645)], ['H', (5.69541534, -5.86571901, -7.32041536)], ['C', (2.45462245, -4.72770184, -7.14309246)], ['H', (2.00895589, -5.61186346, -7.54872716)], ['H', (1.83452668, -4.34238508, -6.36084427)], ['H', (2.55421902, -3.99253283, -7.91413529)], ['S', (4.42031248, -3.49330876, -5.80985222)], ['C', (2.41628932, -6.97906376, -5.58182741)], ['O', (1.24772273, -6.54739905, -5.20566321)], ['O', (2.52902641, -8.14944067, -6.13940253)], ['H', (3.39110571, -8.4678895, -6.41690875)]] + elif mol_str == "bench6": mol_str = [['C', (-2.81999319, 2.42808848, 0.54098799)], ['C', (-1.48506359, 2.12574154, 0.24053959)], ['C', (-0.57560888, 3.15808633, -0.02458108)], ['C', (-1.00054575, 4.49301108, 0.01119766)], ['C', (-2.33541234, 4.79564647, 0.31151005)], ['C', (-3.24524285, 3.76302115, 0.57613058)], ['H', (-3.51427777, 1.63974701, 0.74369099)], ['H', (0.44352769, 2.92714554, -0.25417935)], ['H', (-0.30609255, 5.28106112, -0.1905295)], ['H', (-2.66014056, 5.81490526, 0.33906034)], ['H', (-4.26458005, 3.99396043, 0.80476437)], ['C', (-1.01732442, 0.65950815, 0.2018529)], ['H', (-0.67533111, 0.36853591, 1.17285734)], ['H', (-1.83249438, 0.03203052, -0.09154374)], ['C', (0.13235379, 0.51505962, -0.81252046)], ['H', (-0.21019462, 0.80566498, -1.78314105)], ['C', (1.30596367, 1.41901641, -0.39114863)], ['C', (1.67510534, -1.02185697, -1.82127087)], ['N', (0.57986576, -0.88420978, -0.85050744)], ['H', (-0.18284568, -1.47067933, -1.12333537)], ['C', (2.75580261, -2.09988547, -1.61482923)], ['C', (2.69427195, -2.95177434, -0.50654721)], ['C', (3.7995502, -2.23324053, -2.5379049)], ['C', (3.70091458, -3.90483765, -0.3018299)], ['H', (1.88147807, -2.87383893, 0.18522781)], ['C', (4.78600353, -3.21184574, -2.34716707)], ['H', (3.84167861, -1.59151423, -3.39308111)], ['C', (4.77253365, -4.00413187, -1.20130033)], ['C', (5.85248621, -3.4258967, -3.42542028)], ['C', (5.90561132, -5.00024671, -0.90103115)], ['C', (6.7765891, -5.27533663, -2.15038566)], ['H', (6.54653968, -4.59705454, -0.14509272)], ['O', (2.17311788, 1.97320056, -1.38357913)], ['H', (2.5148354, 2.81560975, -1.07582854)], ['O', (1.5016239, 1.66968264, 0.82605442)], ['O', (1.72973935, -0.25904981, -2.82056879)], ['C', (7.73469993, -6.47923452, -2.09043925)], ['H', (7.25757496, -7.33559502, -2.51759025)], ['H', (8.62396559, -6.25453714, -2.64131193)], ['H', (7.98803284, -6.68347394, -1.0713362)], ['Cl', (3.62524241, -4.97417462, 1.09396391)], ['O', (5.79027817, -2.75121797, -4.48594121)], ['O', (6.90897033, -4.37928239, -3.27178983)], ['H', (5.46196268, -5.90980388, -0.55344219)], ['H', (7.47765424, -4.60305065, -1.70382804)]] + elif mol_str == "bench7": mol_str = [['C', (-1.31610737, 2.63679048, -0.00084644)], ['C', (0.07905007, 2.63934168, -4.682e-05)], ['C', (0.77437809, 3.84836613, 0.00070992)], ['C', (0.07451596, 5.05559768, -0.00053176)], ['C', (-1.32030606, 5.05296923, -0.00181022)], ['C', (-2.01569681, 3.84348937, -0.00157112)], ['H', (-1.86412404, 1.68346969, -0.00099299)], ['H', (1.87405556, 3.85045682, 0.00197421)], ['H', (0.62297373, 6.00874516, 5.98e-06)], ['H', (-1.87216801, 6.00424292, -0.00279707)], ['H', (-3.11529902, 3.84166168, -0.0023813)], ['C', (0.85103942, 1.30681443, 0.00184158)], ['H', (1.79337581, 1.4398572, -0.48724099)], ['C', (0.41505987, -0.48736387, 1.68199576)], ['C', (2.57473964, 0.77351617, 1.72731358)], ['C', (0.69110348, -1.01300628, 3.07592768)], ['H', (0.78959033, -1.22526334, 0.92463509)], ['H', (-0.691016, -0.38321987, 1.53238838)], ['C', (2.85074861, 0.24863316, 3.12155937)], ['H', (3.05729485, 0.09861578, 0.97233225)], ['H', (3.03204876, 1.79032016, 1.611471)], ['H', (0.2349028, -2.03043373, 3.19085261)], ['H', (0.20695056, -0.33927459, 3.83095099)], ['H', (3.95676021, 0.14403431, 3.27103621)], ['H', (2.47681583, 0.98744831, 3.87839265)], ['N', (1.08617979, 0.85184124, 1.45418741)], ['N', (2.17897791, -1.08989652, 3.35069584)], ['C', (0.03162829, 0.23777883, -0.74470266)], ['C', (0.56867017, -0.3958751, -1.86565767)], ['C', (-1.24736307, -0.09709064, -0.30018981)], ['C', (-0.17344898, -1.36368445, -2.54228606)], ['H', (1.57672684, -0.13109474, -2.21630694)], ['C', (-1.98945087, -1.06588037, -0.9763274)], ['H', (-1.67071899, 0.40210131, 0.58338598)], ['C', (-1.45278973, -1.69908441, -2.09730218)], ['H', (0.24953934, -1.86273063, -3.42621495)], ['H', (-2.99769071, -1.33000002, -0.62537518)], ['Cl', (-2.38893911, -2.92048839, -2.95134941)], ['C', (2.78094574, -2.08721982, 2.45410174)], ['H', (2.02278564, -2.75673563, 2.10504323)], ['H', (3.23055355, -1.59117358, 1.61942159)], ['C', (3.85566305, -2.8823677, 3.21851166)], ['H', (3.99860784, -3.83253439, 2.74771441)], ['H', (4.77651636, -2.33753632, 3.2091328)], ['C', (4.2680292, -4.18276315, 5.33113825)], ['H', (3.99105398, -4.28769779, 6.35932766)], ['H', (5.29898799, -3.90383432, 5.2661561)], ['C', (4.05157947, -5.52181188, 4.60196201)], ['O', (4.49897618, -6.59036968, 5.09348105)], ['O', (3.33974189, -5.55104438, 3.36206979)], ['H', (3.28517852, -6.45557385, 3.04513244)], ['O', (3.39930599, -3.09413546, 4.67401592)]] + elif mol_str == "bench8": mol_str = [['C', (-1.42666665, 1.35988349, 0.01780185)], ['C', (-0.75139234, 2.53486079, 0.01780185)], ['C', (-2.96666665, 1.35988349, 0.01780185)], ['C', (-3.66418809, 0.15160568, 0.01780185)], ['C', (-3.66417225, 2.56778831, 0.01791304)], ['C', (-5.05890001, 0.15132789, 0.01723115)], ['H', (-3.11399504, -0.8005123, 0.01693694)], ['C', (-5.05931013, 2.56768367, 0.01833813)], ['H', (-3.11457497, 3.52019148, 0.01809296)], ['C', (-5.75673144, 1.35973487, 0.01785909)], ['H', (-5.60876287, -0.80100973, 0.01659711)], ['H', (-5.60899513, 3.52021733, 0.01884114)], ['H', (-6.85641138, 1.35926586, 0.01746817)], ['C', (-1.51874951, 3.87006226, 0.01780185)], ['C', (-1.63823871, 4.60590036, -1.16149287)], ['C', (-2.09440347, 4.34371845, 1.19670832)], ['C', (-2.3326658, 5.81544273, -1.16163975)], ['H', (-1.18363273, 4.23258432, -2.090584)], ['C', (-2.78991814, 5.55312706, 1.19651365)], ['H', (-2.00047584, 3.76380313, 2.12622693)], ['C', (-2.90901419, 6.28907563, 0.01764434)], ['H', (-2.42635385, 6.39580205, -2.09099551)], ['H', (-3.2440432, 5.92613353, 2.12608927)], ['C', (0.78860766, 2.53486079, 0.01780185)], ['C', (1.4861291, 3.74313859, 0.01780185)], ['C', (1.48611327, 1.32695597, 0.01791304)], ['C', (2.88084102, 3.74341639, 0.01723115)], ['H', (0.93593606, 4.69525658, 0.01693694)], ['C', (2.88125115, 1.3270606, 0.01833813)], ['H', (0.93651599, 0.37455279, 0.01809296)], ['C', (3.57867246, 2.5350094, 0.01785909)], ['H', (3.43070389, 4.695754, 0.01659711)], ['H', (3.43093615, 0.37452694, 0.01884114)], ['H', (4.6783524, 2.53547842, 0.01746817)], ['C', (-0.65930948, 0.02468201, 0.01780185)], ['H', (-0.04466478, -0.03344716, -0.85611628)], ['H', (-0.04386363, -0.03298673, 0.89118649)], ['C', (-1.66236338, -1.14385651, 0.01856968)], ['H', (-2.27713573, -1.08561745, 0.89239069)], ['H', (-2.27768159, -1.08629703, -0.8549121)], ['H', (-1.12919956, -2.07156136, 0.01876393)], ['O', (-3.62101473, 7.52921876, 0.01715974)], ['C', (-2.69982994, 8.60858726, 0.19402752)], ['H', (-2.03011871, 8.64615667, -0.63962434)], ['H', (-2.14108178, 8.456809, 1.09384076)], ['C', (-3.47584819, 9.93535894, 0.28927757)], ['H', (-4.0545645, 10.07469158, -0.59986462)], ['H', (-4.1269469, 9.90759901, 1.13792346)], ['C', (-1.65137806, 10.90285045, 1.72438609)], ['H', (-2.24764703, 10.40869908, 2.46274761)], ['H', (-0.7911044, 10.306338, 1.50302183)], ['H', (-1.33836538, 11.85545774, 2.09783276)], ['C', (-3.25771829, 12.42866058, 0.53449492)], ['H', (-2.5661118, 13.24181825, 0.60767325)], ['H', (-3.86037095, 12.55070987, -0.3411841)], ['H', (-3.88574784, 12.41553739, 1.40069735)], ['N', (-2.48185199, 11.10154878, 0.44281205)]] + elif mol_str == "bench9": mol_str = [['C', (-1.42958937, -1.01658249, 2.44831265)], ['C', (-1.54149463, -0.27943595, 1.12935913)], ['C', (-2.95243912, 1.47129501, 2.2249255)], ['C', (-2.84134324, 0.7340028, 3.5438881)], ['C', (-1.55797242, -0.06742235, 3.62210684)], ['H', (-0.6649782, 0.41085733, 1.01283439)], ['H', (-0.44314689, -1.54644194, 2.49994231)], ['H', (-3.93813646, 2.00251835, 2.17350554)], ['H', (-3.71865156, 0.04459988, 3.66018043)], ['H', (-0.68133448, 0.63238265, 3.63327796)], ['O', (-1.92499204, 2.46279929, 2.14636482)], ['C', (-1.46076301, 4.67520648, 3.08258818)], ['C', (-1.65192906, 4.60593616, 4.58432062)], ['H', (-0.74117334, 5.43388332, 2.85565168)], ['C', (-3.37658473, 6.27667868, 2.91318631)], ['C', (-2.26541279, 5.88367074, 5.11970508)], ['H', (-2.3174731, 3.73866848, 4.83580579)], ['C', (-3.5663604, 6.20921206, 4.41472333)], ['H', (-4.36600055, 6.45426127, 2.41719619)], ['H', (-1.54350069, 6.73075622, 4.97913735)], ['H', (-4.32958469, 5.42429404, 4.65913754)], ['O', (-2.52952388, 7.38269576, 2.59051909)], ['C', (-3.31900286, 8.44866624, 2.0563563)], ['C', (-3.26458455, 9.66301263, 2.97202724)], ['C', (-2.80523815, 10.85434814, 2.14390783)], ['H', (-4.26911658, 9.86120175, 3.4226732)], ['C', (-2.45049436, 10.34868118, 0.7529107)], ['H', (-1.92149828, 11.34764146, 2.62047963)], ['H', (-3.01626404, 10.91739387, -0.02697788)], ['C', (-4.76356005, 7.92754855, 1.94107026)], ['H', (-4.78266313, 7.06714393, 1.30527236)], ['H', (-5.12531018, 7.66145994, 2.91227245)], ['O', (-2.82580898, 0.520717, 1.05150767)], ['C', (-0.93849427, 3.3173276, 2.57761263)], ['H', (-0.77727115, 3.36776157, 1.52103158)], ['H', (-0.01653199, 3.08491476, 3.06839754)], ['C', (-1.49224313, -1.28985314, -0.03177454)], ['H', (-0.58145107, -1.84899344, 0.02037852)], ['H', (-2.32526415, -1.95746693, 0.04080516)], ['C', (-0.95163774, 10.56175833, 0.4707282)], ['H', (-0.66603955, 9.99216789, -0.38886496)], ['H', (-0.38034325, 10.24135666, 1.31681648)], ['O', (-1.55044662, -0.58860344, -1.2766686)], ['H', (-1.40118565, -1.2038562, -1.99832308)], ['O', (-2.45542951, -2.00981905, 2.525986)], ['H', (-2.07814582, -2.83668451, 2.83509848)], ['O', (-1.5306674, -0.81010098, 4.84381994)], ['H', (-0.61998842, -0.9842467, 5.0926915)], ['O', (-2.88646987, 1.67226139, 4.622097)], ['H', (-3.42291955, 1.31644411, 5.33428785)], ['O', (-0.39050984, 4.37790095, 5.21813565)], ['H', (-0.43907755, 4.65272412, 6.13667438)], ['O', (-2.5020679, 5.74947404, 6.52358741)], ['H', (-2.24542728, 6.5593166, 6.97067871)], ['O', (-4.06906224, 7.4604827, 4.89065678)], ['H', (-4.76357327, 7.30283533, 5.53439721)], ['O', (-2.76235824, 4.99928466, 2.37789525)], ['O', (-5.59594331, 8.9491541, 1.38576835)], ['H', (-6.48998576, 8.61436213, 1.28476127)], ['O', (-2.79571817, 8.86766013, 0.68976834)], ['O', (-2.35757225, 9.42976003, 4.05268533)], ['H', (-2.2282828, 10.24429373, 4.54402949)], ['O', (-3.84292711, 11.83519815, 2.06622328)], ['H', (-3.45759742, 12.71343457, 2.10890596)], ['O', (-0.70503591, 11.94904429, 0.22675035)], ['H', (0.04258943, 12.04182451, -0.36826913)]] + elif mol_str == "bench10": mol_str = [['C', (-7.49315956, -1.43773119, -5.92299552)], ['H', (-7.94182211, -2.40858579, -5.95529496)], ['H', (-7.95636393, -0.85617992, -5.15348929)], ['H', (-7.62668554, -0.95153649, -6.86675689)], ['C', (-5.79642037, -2.27443704, -4.26611112)], ['H', (-6.24577738, -3.24503611, -4.29637197)], ['H', (-4.75102405, -2.37166315, -4.05972888)], ['H', (-6.25894988, -1.69106781, -3.49757564)], ['C', (-5.32241415, -2.41652735, -6.73175378)], ['H', (-5.4561432, -1.93219937, -7.67644576)], ['H', (-4.27697469, -2.51320533, -6.52533261)], ['H', (-5.77127293, -3.38735843, -6.7619667)], ['C', (-5.34230471, -0.18003416, -5.58198519)], ['H', (-5.80528977, 0.40321281, -4.81363123)], ['H', (-5.47564419, 0.30434483, -6.52670609)], ['C', (-3.83777346, -0.31946968, -5.28442266)], ['H', (-3.37486869, -0.90308305, -6.05254674)], ['H', (-3.70443723, -0.80345768, -4.33950093)], ['O', (-3.19138412, 1.07764258, -5.24131661)], ['O', (-1.08666761, 2.23553569, -4.90355336)], ['O', (-1.06813532, 0.15822874, -5.97009968)], ['O', (-1.50879854, 0.29141327, -3.68074512)], ['N', (-5.98857475, -1.57718235, -5.62571128)], ['C', (0.31042939, 2.10678634, -4.62706904)], ['H', (0.77368465, 1.52318382, -5.39499004)], ['H', (0.44389576, 1.62311344, -3.68200436)], ['C', (0.95616173, 3.50421186, -4.58427124)], ['H', (0.49289614, 4.08781851, -3.81635961)], ['N', (0.7728183, 4.16868977, -5.88263725)], ['C', (1.54573854, 5.24091768, -6.23123101)], ['H', (0.07949519, 3.81992495, -6.49108736)], ['O', (2.42345554, 5.69243391, -5.47559455)], ['C', (1.24369185, 5.83589104, -7.61916994)], ['H', (1.38780673, 5.08532476, -8.36802317)], ['H', (0.23032024, 6.17818782, -7.6476461)], ['H', (1.90308508, 6.65755098, -7.80618734)], ['C', (2.46072393, 3.36555987, -4.28649918)], ['H', (2.81093231, 4.25132556, -3.79898177)], ['H', (2.62197169, 2.52066581, -3.65006128)], ['C', (3.22917105, 3.16808327, -5.60638252)], ['H', (2.71745824, 3.26576104, -6.54099958)], ['C', (4.55128586, 2.87101548, -5.5882109)], ['H', (5.06299867, 2.77333771, -4.65359384)], ['P', (-1.68688011, 0.93821529, -4.94361253)], ['C', (5.31973298, 2.67353888, -6.90809424)], ['H', (5.50187069, 3.62584722, -7.36068089)], ['H', (4.73852105, 2.06839074, -7.57208852)], ['C', (6.66254822, 1.97641122, -6.62093532)], ['H', (7.23864011, 2.57569669, -5.94723283)], ['H', (6.48011089, 1.01887673, -6.17963685)], ['C', (7.43879459, 1.79489451, -7.93854423)], ['H', (7.60827107, 2.75133158, -8.38732686)], ['H', (6.86923187, 1.18328757, -8.60670234)], ['C', (8.79086565, 1.11708026, -7.64859306)], ['H', (9.33478428, 1.0007021, -8.56265469)], ['H', (8.62122899, 0.15610604, -7.20967171)], ['C', (9.60424993, 1.98984817, -6.67479524)], ['H', (9.31147739, 1.77019553, -5.66934163)], ['H', (9.41944065, 3.02310864, -6.88244478)], ['C', (11.10499458, 1.6916288, -6.84923923)], ['H', (11.41365563, 1.97843059, -7.83279049)], ['H', (11.27894526, 0.64483722, -6.71187639)], ['C', (11.91113852, 2.48722658, -5.80580446)], ['H', (11.72501028, 3.53319241, -5.93312987)], ['H', (11.61354107, 2.18914178, -4.82219823)], ['C', (13.41348542, 2.20643447, -5.99477425)], ['H', (13.59898733, 1.15986871, -5.87152995)], ['H', (13.71197645, 2.50833327, -6.97694542)], ['C', (14.21924372, 2.99740615, -4.94753151)], ['H', (14.03182008, 4.04381613, -5.0691808)], ['H', (13.92251458, 2.69373112, -3.96537404)], ['C', (15.72181966, 2.71939448, -5.13877756)], ['H', (15.90913955, 1.67288842, -5.01779686)], ['H', (16.01869462, 3.02369321, -6.12069789)], ['C', (16.5275164, 3.50960658, -4.09091421)], ['H', (16.33925543, 4.55603385, -4.21111315)], ['H', (16.23150563, 3.20443884, -3.10900267)], ['C', (18.03020202, 3.23295881, -4.28327314)], ['H', (18.21842532, 2.18649741, -4.1633126)], ['H', (18.32626477, 3.53834869, -5.26509994)], ['C', (18.83587717, 4.02290009, -3.23518901)], ['H', (19.87999516, 3.83124105, -3.369295)], ['H', (18.64710066, 5.06931441, -3.35469007)], ['H', (18.54032269, 3.71699987, -2.25336795)]] + elif mol_str == "bench11": mol_str = [['C', (0.240284, -0.968546, 0.057358)], ['C', (1.499558, -0.389994, 0.799765)], ['C', (1.844059, 1.113099, 0.526127)], ['C', (0.611152, 2.069949, 0.410275)], ['C', (-0.387189, 1.449098, -0.582889)], ['C', (-0.811981, 0.113677, 0.014032)], ['H', (1.344645, -0.483368, 1.89667)], ['H', (0.908155, 3.094741, 0.109552)], ['H', (0.071465, 1.400302, -1.594573)], ['H', (-1.08538, 0.339368, 1.098414)], ['O', (-0.032343, 2.140515, 1.697564)], ['H', (0.438322, 2.767394, 2.276379)], ['O', (-1.643456, 2.155986, -0.775276)], ['C', (-2.749358, 1.179186, -0.753555)], ['H', (-3.337709, 1.418582, 0.14457)], ['H', (-3.318202, 1.397448, -1.666498)], ['C', (-2.110589, -0.2299, -0.719944)], ['C', (2.729982, 1.327484, -0.704832)], ['H', (2.813168, 2.384445, -0.977584)], ['H', (3.749604, 0.958567, -0.53283)], ['H', (2.352007, 0.78104, -1.58051)], ['C', (2.6014, -1.343864, 0.30659)], ['C', (0.846782, -1.406136, -1.29617)], ['H', (0.882748, -0.593196, -2.039512)], ['H', (0.388152, -2.301374, -1.740346)], ['O', (2.225476, -1.781686, -1.029468)], ['C', (-0.422908, -2.193631, 0.752774)], ['H', (-0.320129, -3.083535, 0.102361)], ['C', (-1.914007, -2.007635, 1.112375)], ['H', (-2.334209, -2.995278, 1.383792)], ['H', (-1.980931, -1.388666, 2.031062)], ['C', (-2.813538, -1.370551, 0.027198)], ['H', (-3.1202, -2.147139, -0.69849)], ['C', (-1.82661295, -0.68751599, -2.16270012)], ['O', (-1.03585236, -0.24261727, -2.99355789)], ['O', (-2.59156054, -1.74766325, -2.52650357)], ['C', (-2.29916153, -2.14198817, -3.86960099)], ['H', (-2.96290254, -2.9282896, -4.16299137)], ['H', (-2.42740743, -1.30452633, -4.52313804)], ['H', (-1.28838658, -2.48820275, -3.92764814)], ['O', (-4.01986539, -0.90962471, 0.64138134)], ['C', (-4.89301012, -1.93494775, 0.80793745)], ['O', (-4.541531, -3.05110585, 0.4281805)], ['C', (-6.20834727, -1.48087047, 1.46771166)], ['H', (-6.70958996, -0.78829922, 0.82428269)], ['H', (-6.83594045, -2.33131805, 1.63434212)], ['H', (-5.99341406, -1.00749899, 2.40292455)], ['O', (0.29104226, -2.52037085, 1.94793763)], ['C', (0.31248536, -3.86361432, 2.13937213)], ['O', (-0.25336168, -4.56806573, 1.30443072)], ['C', (1.07328546, -4.25938123, 3.41849362)], ['C', (1.18469713, -5.56341278, 3.77014145)], ['H', (0.75137836, -6.32562659, 3.15681858)], ['C', (1.70966559, -3.16955354, 4.30104443)], ['H', (2.52793619, -2.72004059, 3.77829081)], ['H', (0.97813456, -2.42251044, 4.52839648)], ['H', (2.06508607, -3.60889199, 5.20964665)], ['C', (1.93754031, -5.94957419, 5.05688405)], ['H', (1.46239499, -5.49165555, 5.89917107)], ['H', (1.92238977, -7.01309886, 5.1734419)], ['H', (2.95091533, -5.61227499, 4.99207421)], ['C', (3.99823568, -0.71610148, 0.14421916)], ['O', (4.54063921, 0.18499764, 0.78248292)], ['O', (4.6998428, -1.27738694, -0.87269582)], ['O', (2.69271189, -2.53050618, 1.09933364)], ['H', (3.60067733, -2.84219679, 1.1062423)], ['C', (5.98847134, -0.6688573, -0.99113633)], ['H', (6.49970371, -0.7307557, -0.05320774)], ['H', (6.55618159, -1.17887968, -1.74112449)], ['H', (5.87374685, 0.35839671, -1.26770006)], ['C', (2.63486992, 1.58151749, 1.76176538)], ['C', (2.13434327, 2.21842175, 3.11643757)], ['C', (3.90461234, 2.4538709, 1.74128354)], ['O', (2.44467967, 0.78466796, 2.96396625)], ['C', (3.35337126, 2.9870945, 3.792439)], ['C', (0.74513758, 2.60743687, 3.44136489)], ['O', (5.00327683, 3.1919637, 1.11718214)], ['C', (4.47769203, 2.16352749, 3.16877423)], ['H', (3.15573566, 3.35599353, 1.51547111)], ['C', (3.84794511, 4.41584726, 3.25643717)], ['H', (3.24116904, 2.9988907, 4.88162906)], ['H', (0.00697023, 1.93995296, 2.97068106)], ['H', (0.55491721, 2.57388288, 4.52449549)], ['H', (0.54134467, 3.63458255, 3.09753074)], ['C', (4.84981258, 4.42246076, 1.92099071)], ['H', (4.49637929, 1.09030804, 3.43212004)], ['H', (5.51163803, 2.50195502, 3.3248999)], ['C', (4.76579887, 5.04464694, 4.26535741)], ['O', (2.75093022, 5.20578033, 2.83652107)], ['H', (4.60685318, 5.22136931, 1.20459035)], ['O', (6.17282363, 4.70855901, 2.47193815)], ['H', (4.42807865, 5.31783232, 5.24674785)], ['C', (6.01838353, 5.12565144, 3.78006571)], ['H', (2.50011685, 5.87405238, 3.50751412)], ['H', (6.95619123, 5.44887224, 4.20308201)]] + elif mol_str == "bench12": mol_str = [['C', (-3.784, 3.918, 3.085)], ['O', (-4.613, 4.803, 2.942)], ['C', (6.713, 1.233, 4.896)], ['C', (-5.602, 2.294, 3.721)], ['C', (-6.05, 1.052, 4.183)], ['C', (-5.13, 0.052, 4.505)], ['C', (-3.764, 0.293, 4.345)], ['C', (-3.323, 1.537, 3.885)], ['C', (-4.234, 2.553, 3.573)], ['C', (5.675, 0.948, 3.846)], ['C', (2.371, 0.849, 2.629)], ['C', (3.105, -0.126, 3.216)], ['C', (1.031, 1.222, 3.239)], ['C', (0.288, -0.094, 3.595)], ['C', (1.167, -1.284, 4.127)], ['C', (2.608, -0.785, 4.511)], ['C', (3.435, -1.97, 5.134)], ['C', (2.606, 0.312, 5.643)], ['C', (2.732, 1.622, 1.37)], ['C', (1.103, -2.5, 3.107)], ['C', (1.799, -2.324, 1.7)], ['C', (3.308, -2.771, 1.513)], ['C', (3.495, -4.317, 1.695)], ['C', (4.372, -2.133, 2.414)], ['O', (5.242, -2.814, 2.938)], ['C', (4.419, -0.643, 2.663)], ['C', (3.725, -2.341, 0.068)], ['C', (2.997, -3.146, -1.039)], ['C', (1.461, -3.215, -0.831)], ['C', (0.848, -2.793, 0.536)], ['C', (0.295, -4.236, 0.673)], ['O', (1.018, -4.542, -0.52)], ['O', (-0.291, -1.918, 0.364)], ['O', (0.129, 1.936, 2.354)], ['C', (0.335, 3.258, 2.231)], ['O', (1.171, 3.887, 2.86)], ['C', (-0.569, 3.953, 1.24)], ['O', (-0.325, 5.368, 1.252)], ['C', (-2.072, 3.652, 1.503)], ['C', (-2.914, 4.255, 0.365)], ['N', (-2.518, 4.225, 2.794)], ['C', (-0.027, -0.744, -0.23)], ['O', (1.059, -0.437, -0.697)], ['C', (-1.198, 0.194, -0.318)], ['O', (5.148, -2.365, -0.154)], ['O', (5.475, -0.357, 3.604)], ['O', (0.499, -1.683, 5.342)], ['O', (-0.305, -2.723, 2.831)], ['C', (-1.05, -3.316, 3.775)], ['O', (-0.564, -3.881, 4.742)], ['C', (-2.555, -3.283, 3.613)], ['C', (-3.131, -2.573, 2.553)], ['C', (-4.519, -2.531, 2.398)], ['C', (-5.343, -3.21, 3.3)], ['C', (-4.774, -3.919, 4.362)], ['C', (-3.386, -3.955, 4.518)], ['C', (-3.219, 5.62, 0.367)], ['C', (-3.981, 6.179, -0.664)], ['C', (-4.449, 5.372, -1.705)], ['C', (-4.152, 4.007, -1.71)], ['C', (-3.391, 3.452, -0.678)], ['O', (5.075, 1.851, 3.284)], ['H', (7.656, 0.728, 4.637)], ['H', (6.359, 0.859, 5.868)], ['H', (6.901, 2.314, 4.977)], ['H', (-6.335, 3.057, 3.477)], ['H', (-7.113, 0.865, 4.293)], ['H', (-5.475, -0.906, 4.881)], ['H', (-3.046, -0.486, 4.58)], ['H', (-2.258, 1.699, 3.779)], ['H', (1.157, 1.836, 4.142)], ['H', (-0.254, -0.39, 2.686)], ['H', (-0.488, 0.166, 4.333)], ['H', (3.049, -2.247, 6.126)], ['H', (3.418, -2.896, 4.559)], ['H', (4.483, -1.687, 5.296)], ['H', (3.183, 0.036, 6.538)], ['H', (3.066, 1.253, 5.305)], ['H', (1.589, 0.521, 6.005)], ['H', (3.556, 1.242, 0.76)], ['H', (1.871, 1.652, 0.687)], ['H', (2.995, 2.651, 1.65)], ['H', (1.492, -3.432, 3.532)], ['H', (1.789, -1.255, 1.512)], ['H', (2.849, -4.921, 1.053)], ['H', (4.521, -4.633, 1.46)], ['H', (3.301, -4.623, 2.731)], ['H', (4.777, -0.232, 1.716)], ['H', (3.472, -1.28, -0.081)], ['H', (3.206, -2.659, -2.006)], ['H', (3.407, -4.164, -1.109)], ['H', (0.876, -2.865, -1.699)], ['H', (0.597, -4.823, 1.553)], ['H', (-0.79, -4.313, 0.498)], ['H', (-0.297, 3.593, 0.235)], ['H', (-0.508, 5.754, 2.103)], ['H', (-2.211, 2.56, 1.507)], ['H', (-1.856, 3.979, 3.61)], ['H', (-1.756, 0.199, 0.629)], ['H', (-0.85, 1.216, -0.531)], ['H', (-1.869, -0.132, -1.127)], ['H', (5.528, -3.235, -0.126)], ['H', (0.815, -2.49, 5.73)], ['H', (-2.506, -2.044, 1.842)], ['H', (-4.957, -1.972, 1.578)], ['H', (-6.422, -3.185, 3.177)], ['H', (-5.414, -4.443, 5.066)], ['H', (-2.959, -4.507, 5.348)], ['H', (-2.868, 6.259, 1.17)], ['H', (-4.21, 7.24, -0.655)], ['H', (-5.04, 5.805, -2.506)], ['H', (-4.515, 3.378, -2.517)], ['H', (-3.174, 2.389, -0.697)]] + elif mol_str == "bench13": mol_str = [['N', (-4.092316, -1.000975, -0.370529)], ['C', (-4.323298, -0.263368, -1.612788)], ['C', (-3.205279, 0.75529, -1.724502)], ['C', (-4.333065, -1.195831, -2.804809)], ['C', (-5.534893, -2.195586, -2.747894)], ['C', (-4.391021, -0.411072, -4.124619)], ['O', (-2.022551, 0.51326, -1.614742)], ['O', (-3.685654, 1.972297, -2.0328)], ['C', (-2.674446, 3.003316, -2.308876)], ['C', (-2.061835, 3.479071, -0.971537)], ['C', (-3.443134, 4.164203, -2.94218)], ['O', (-0.982616, 4.138519, -1.109289)], ['N', (-2.562361, 3.191384, 0.172976)], ['C', (-1.922153, 3.625377, 1.414867)], ['C', (-0.521384, 3.059577, 1.425755)], ['C', (-2.660285, 3.19546, 2.693552)], ['C', (-4.031723, 3.921459, 2.671174)], ['C', (-1.882265, 3.476688, 3.916073)], ['O', (-0.237963, 1.915325, 1.170709)], ['O', (0.405568, 3.987667, 1.716328)], ['C', (1.781787, 3.582552, 1.870413)], ['C', (2.481074, 3.382195, 0.575318)], ['C', (2.480996, 4.611227, 2.746627)], ['C', (2.47089, 5.974202, 2.063195)], ['C', (1.863565, 4.691095, 4.116433)], ['O', (3.676679, 3.042497, 0.627604)], ['N', (1.882065, 3.63419, -0.5706)], ['C', (2.598479, 3.416843, -1.822129)], ['C', (3.04029, 1.970656, -1.862391)], ['C', (1.746343, 3.771275, -3.063258)], ['C', (1.427287, 5.251532, -3.043648)], ['C', (2.444655, 3.420207, -4.345496)], ['O', (2.308863, 0.99538, -1.690734)], ['O', (4.353206, 1.844014, -2.141135)], ['C', (4.912975, 0.501695, -2.282755)], ['C', (4.866954, -0.251241, -0.969886)], ['C', (6.363862, 0.653789, -2.723359)], ['O', (4.889152, -1.491995, -0.988352)], ['N', (4.852304, 0.411828, 0.187274)], ['C', (4.863721, -0.32053, 1.453267)], ['C', (3.640496, -1.203917, 1.497502)], ['C', (4.913635, 0.624865, 2.635648)], ['C', (6.173744, 1.508032, 2.60033)], ['C', (4.932395, -0.156748, 3.950794)], ['O', (2.51072, -0.84202, 1.236281)], ['O', (3.945608, -2.457328, 1.878505)], ['C', (2.828097, -3.372211, 2.03196)], ['C', (2.445551, -4.040763, 0.77174)], ['C', (3.254346, -4.422058, 3.076046)], ['C', (4.485963, -5.17999, 2.732181)], ['C', (3.328906, -3.808486, 4.458682)], ['O', (1.510141, -4.859848, 0.766457)], ['N', (3.020005, -3.732477, -0.410699)], ['C', (2.468904, -4.259053, -1.6646)], ['C', (1.005575, -3.799167, -1.749489)], ['C', (3.268214, -3.838849, -2.88078)], ['C', (4.693123, -4.426191, -2.823146)], ['C', (2.559891, -4.254352, -4.172416)], ['O', (0.635499, -2.673457, -1.581456)], ['O', (0.19673, -4.821967, -2.053653)], ['C', (-1.207969, -4.498197, -2.24626)], ['C', (-1.919179, -4.182296, -0.920514)], ['C', (-1.832879, -5.699367, -2.974726)], ['O', (-3.059153, -3.717116, -0.961162)], ['N', (-1.303471, -4.547234, 0.218128)], ['C', (-1.982414, -4.214913, 1.486175)], ['C', (-2.200782, -2.712629, 1.511514)], ['C', (-1.134918, -4.689135, 2.671762)], ['C', (-0.944827, -6.191286, 2.623437)], ['C', (-1.829816, -4.316462, 3.967327)], ['O', (-1.385866, -1.9006, 1.279362)], ['O', (-3.460876, -2.442907, 1.878543)], ['C', (-3.804869, -1.044508, 2.117691)], ['C', (-4.050817, -0.299492, 0.783208)], ['C', (-5.016031, -1.022831, 3.027983)], ['C', (-6.211473, -1.582777, 2.357053)], ['C', (-5.252923, 0.363926, 3.607945)], ['O', (-4.233373, 0.913002, 0.812934)], ['H', (-4.825052, -1.677193, -0.281154)], ['H', (-5.217026, 0.186571, -1.599372)], ['H', (-3.479308, -1.715455, -2.769455)], ['H', (-5.508697, -2.792601, -3.549257)], ['H', (-6.393354, -1.68153, -2.7406)], ['H', (-5.468672, -2.74598, -1.915801)], ['H', (-4.39429, -1.049299, -4.894845)], ['H', (-3.59014, 0.186846, -4.194563)], ['H', (-5.225516, 0.143589, -4.147272)], ['H', (-1.941915, 2.666583, -2.901546)], ['H', (-2.806324, 4.905979, -3.158884)], ['H', (-4.135677, 4.492828, -2.303388)], ['H', (-3.887482, 3.852542, -3.781836)], ['H', (-3.485062, 3.580046, 0.182699)], ['H', (-1.927597, 4.624579, 1.428033)], ['H', (-2.79044, 2.201843, 2.715678)], ['H', (-4.555076, 3.671645, 3.487594)], ['H', (-4.539107, 3.649591, 1.856395)], ['H', (-3.884359, 4.909505, 2.658473)], ['H', (-2.397552, 3.181021, 4.719518)], ['H', (-1.703444, 4.460941, 3.979504)], ['H', (-1.011857, 2.981721, 3.879574)], ['H', (1.807192, 2.684489, 2.310092)], ['H', (3.430162, 4.320142, 2.868487)], ['H', (2.930026, 6.643402, 2.6424)], ['H', (1.526351, 6.263032, 1.907788)], ['H', (2.9481, 5.912261, 1.184104)], ['H', (2.346361, 5.375005, 4.664527)], ['H', (1.934535, 3.796949, 4.566903)], ['H', (0.900706, 4.946579, 4.034185)], ['H', (1.064858, 3.062686, -0.602928)], ['H', (3.389388, 4.028458, -1.850492)], ['H', (0.906884, 3.231889, -3.020692)], ['H', (0.875632, 5.481426, -3.847705)], ['H', (2.276787, 5.774253, -3.061772)], ['H', (0.915762, 5.470666, -2.216102)], ['H', (1.859529, 3.668329, -5.120102)], ['H', (2.629183, 2.440334, -4.368647)], ['H', (3.304846, 3.923639, -4.404726)], ['H', (4.379382, -0.01605, -2.953946)], ['H', (6.778208, -0.252132, -2.830428)], ['H', (6.872358, 1.169804, -2.034528)], ['H', (6.401161, 1.139812, -3.594077)], ['H', (5.654846, 1.007647, 0.191146)], ['H', (5.688734, -0.885157, 1.507712)], ['H', (4.0981, 1.195862, 2.574451)], ['H', (6.173595, 2.119258, 3.394792)], ['H', (6.176517, 2.055384, 1.762697)], ['H', (6.990721, 0.931248, 2.624915)], ['H', (4.963944, 0.484942, 4.718965)], ['H', (5.737428, -0.749023, 3.979542)], ['H', (4.10688, -0.714856, 4.021343)], ['H', (2.017946, -2.856984, 2.316729)], ['H', (2.532306, -5.110371, 3.074232)], ['H', (4.687279, -5.836148, 3.4622)], ['H', (5.250346, -4.546334, 2.632862)], ['H', (4.342276, -5.672236, 1.87173)], ['H', (3.607288, -4.506348, 5.120102)], ['H', (2.432946, -3.447169, 4.712392)], ['H', (4.000576, -3.066883, 4.456643)], ['H', (3.00183, -2.736612, -0.48491)], ['H', (2.523818, -5.256397, -1.663438)], ['H', (3.338853, -2.839418, -2.879074)], ['H', (5.208256, -4.142571, -3.633434)], ['H', (4.639769, -5.426967, -2.799486)], ['H', (5.158021, -4.099125, -2.001722)], ['H', (3.10569, -3.967766, -4.960551)], ['H', (1.661957, -3.821054, -4.213784)], ['H', (2.450854, -5.248405, -4.18741)], ['H', (-1.307259, -3.664356, -2.791012)], ['H', (-2.80549, -5.528762, -3.131782)], ['H', (-1.727941, -6.522351, -2.411801)], ['H', (-1.373006, -5.834787, -3.850316)], ['H', (-1.185777, -5.539039, 0.189424)], ['H', (-2.868035, -4.680768, 1.554632)], ['H', (-0.239333, -4.245968, 2.624327)], ['H', (-0.389239, -6.483047, 3.400717)], ['H', (-0.481563, -6.441532, 1.772248)], ['H', (-1.838177, -6.643402, 2.663197)], ['H', (-1.275974, -4.622937, 4.742494)], ['H', (-2.727713, -4.754904, 4.006495)], ['H', (-1.943531, -3.32249, 4.015512)], ['H', (-3.048477, -0.563112, 2.562609)], ['H', (-4.828053, -1.619266, 3.807471)], ['H', (-6.990721, -1.552747, 2.981064)], ['H', (-6.033213, -2.529521, 2.093305)], ['H', (-6.424515, -1.044284, 1.539849)], ['H', (-6.059472, 0.344963, 4.201915)], ['H', (-5.404125, 1.013925, 2.864268)], ['H', (-4.452484, 0.643772, 4.140208)]] elif mol_str.split('_')[0].lower() == "c": num_c_atoms = int(mol_str.split('_')[1]) mol_str = [["C", (0, 0, i)] for i in range(num_c_atoms)] @@ -224,3 +237,18 @@ def prepare(val): writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg')) writer.close() print("Numerical error visualisation saved in", gif_path) + + +from tessellate_ipu import tile_map, ipu_cycle_count, tile_put_sharded +from typing import List + +def get_ipu_cycles(data_to_be_sharded: List[float], num_items_to_be_sharded: int = 1) -> List[float]: + tmp = data_to_be_sharded[0:num_items_to_be_sharded] + tiles = tuple(range(len(tmp))) + tmp = tile_put_sharded(tmp, tiles) + tmp, cycles_count = ipu_cycle_count(tmp) + tmp = tmp.array + for idx in tiles: + data_to_be_sharded = data_to_be_sharded.at[idx].set(tmp[idx]) + + return data_to_be_sharded, cycles_count \ No newline at end of file diff --git a/requirements_test.txt b/requirements_test.txt index eeac8a08..ac4d3561 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -11,3 +11,4 @@ pre-commit flake8 flake8-copyright isort +tessellate_ipu diff --git a/test/test_benchmark_performance.py b/test/test_benchmark_performance.py new file mode 100644 index 00000000..78548493 --- /dev/null +++ b/test/test_benchmark_performance.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import jax +import numpy as np +import pytest +from tessellate_ipu import ipu_cycle_count, tile_map, tile_put_sharded + +from pyscf_ipu.nanoDFT.nanoDFT import build_mol, nanoDFT, nanoDFT_options + + +@pytest.mark.skip(reason="Skipping IPU test in CI!") +@pytest.mark.ipu +def test_basic_demonstration(): + dummy = np.random.rand(2, 3).astype(np.float32) + dummier = np.random.rand(2, 3).astype(np.float32) + + @jax.jit + def jitted_inner_test(dummy, dummier): + tiles = tuple(range(len(dummy))) + dummy = tile_put_sharded(dummy, tiles) + tiles = tuple(range(len(dummier))) + dummier = tile_put_sharded(dummier, tiles) + + dummy, dummier, start = ipu_cycle_count(dummy, dummier) + out = tile_map(jax.lax.add_p, dummy, dummier) + out, end = ipu_cycle_count(out) + + return out, start, end + + _, start, end = jitted_inner_test(dummy, dummier) + print("Start cycle count:", start, start.shape) + print("End cycle count:", end, end.shape) + print("Diff cycle count:", end.array - start.array) + + assert True + + +@pytest.mark.skip(reason="Skipping IPU test in CI!") +@pytest.mark.ipu +@pytest.mark.parametrize("molecule", ["methane", "benzene"]) +def test_dense_eri(molecule): + opts, mol_str = nanoDFT_options(float32=True, mol_str=molecule, backend="ipu") + mol = build_mol(mol_str, opts.basis) + + _, _, ipu_cycles_stamps = nanoDFT(mol, opts, profile_performance=True) + + start, end = ipu_cycles_stamps + start = np.asarray(start) + end = np.asarray(end) + + diff = (end - start)[0][0][0] + print( + "----------------------------------------------------------------------------" + ) + print(" Diff cycle count:", diff) + print(" Diff cycle count [M]:", diff / 1e6) + print("Estimated time of execution on Bow-IPU [seconds]:", diff / (1.85 * 1e9)) + print( + "----------------------------------------------------------------------------" + ) + + assert True + + +@pytest.mark.skip(reason="Skipping IPU test in CI!") +@pytest.mark.ipu +@pytest.mark.parametrize("molecule", ["methane", "benzene", "c20"]) +def test_sparse_eri(molecule): + opts, mol_str = nanoDFT_options( + float32=True, + mol_str=molecule, + backend="ipu", + dense_ERI=False, + eri_threshold=1e-9, + ) + mol = build_mol(mol_str, opts.basis) + + _, _, ipu_cycles_stamps = nanoDFT(mol, opts, profile_performance=True) + + start, end = ipu_cycles_stamps + start = np.asarray(start) + end = np.asarray(end) + + diff = (end - start)[0][0][0] + print( + "----------------------------------------------------------------------------" + ) + print(" Diff cycle count:", diff) + print(" Diff cycle count [M]:", diff / 1e6) + print("Estimated time of execution on Bow-IPU [seconds]:", diff / (1.85 * 1e9)) + print( + "----------------------------------------------------------------------------" + ) + + assert True