diff --git a/src/hwave/solver/uhfr.py b/src/hwave/solver/uhfr.py index 03cf0f3..000d0e8 100644 --- a/src/hwave/solver/uhfr.py +++ b/src/hwave/solver/uhfr.py @@ -59,6 +59,77 @@ def _transform_interall(self, ham_info): """ return ham_info + def _calc_hartree(self): + """Calculate Hartree terms in the Hamiltonian. + + Calculates the diagonal Hartree terms in the Hamiltonian by iterating through + interaction parameters and adding contributions to Ham_tmp and Ham_trans_tmp. + + Parameters + ---------- + None + + Returns + ------- + None + Updates self.Ham_tmp and self.Ham_trans_tmp arrays + """ + site = np.zeros(4, dtype=np.int32) + for site_info, value in self.param_ham.items(): + for i in range(4): + site[i] = site_info[2 * i] + site_info[2 * i + 1] * self.Nsize + # Diagonal Fock term + self.Ham_tmp[site[0]][site[1]][site[2]][site[3]] += value + self.Ham_tmp[site[2]][site[3]][site[0]][site[1]] += value + if site[1] == site[2]: + self.Ham_trans_tmp[site[1]][site[2]] += value + pass + + def _calc_fock(self): + """Calculate Fock exchange terms in the Hamiltonian. + + Calculates the off-diagonal Fock exchange terms in the Hamiltonian by iterating + through interaction parameters and adding contributions to Ham_tmp. + + Parameters + ---------- + None + + Returns + ------- + None + Updates self.Ham_tmp array + """ + site = np.zeros(4,dtype=np.int32) + for site_info, value in self.param_ham.items(): + for i in range(4): + site[i] = site_info[2 * i] + site_info[2 * i + 1] * self.Nsize + # OffDiagonal Fock term + self.Ham_tmp[site[0]][site[3]][site[2]][site[1]] -= value + self.Ham_tmp[site[2]][site[1]][site[0]][site[3]] -= value + pass + + def get_ham(self, type): + """Get the Hamiltonian matrices. + + Calculates and returns the Hartree and Fock terms of the Hamiltonian. + + Parameters + ---------- + type : str + Type of calculation - either "hartree" or "hartreefock" + + Returns + ------- + tuple of ndarray + (Ham_tmp, Ham_trans_tmp) containing the Hamiltonian matrices + """ + self._calc_hartree() + if type == "hartreefock": + self._calc_fock() + return self.Ham_tmp, self.Ham_trans_tmp + + def _check_range(self): """Check that site indices are within valid range.""" err = 0