Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up ham module #87

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions ebcc/cc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ebcc.core.dump import Dump
from ebcc.core.logging import ANSI
from ebcc.core.precision import astype, types
from ebcc.ham.base import BaseERIs
from ebcc.util import _BaseOptions

if TYPE_CHECKING:
Expand All @@ -26,7 +27,7 @@

from ebcc.core.damping import BaseDamping
from ebcc.core.logging import Logger
from ebcc.ham.base import BaseElectronBoson, BaseERIs, BaseFock
from ebcc.ham.base import BaseElectronBoson, BaseFock
from ebcc.opt.base import BaseBruecknerEBCC
from ebcc.util import Namespace

Expand Down Expand Up @@ -891,16 +892,21 @@ def init_space(self) -> SpaceType:
"""
pass

@abstractmethod
def get_fock(self) -> BaseFock:
"""Get the Fock matrix.

Returns:
Fock matrix.
"""
pass
return self.Fock(
self.mf,
space=(self.space, self.space),
mo_coeff=(self.mo_coeff, self.mo_coeff),
g=self.g,
shift=self.options.shift,
xi=self.xi if self.boson_ansatz else None,
)

@abstractmethod
def get_eris(self, eris: Optional[ERIsInputType] = None) -> BaseERIs:
"""Get the electron repulsion integrals.

Expand All @@ -910,7 +916,23 @@ def get_eris(self, eris: Optional[ERIsInputType] = None) -> BaseERIs:
Returns:
Electron repulsion integrals.
"""
pass
use_df = getattr(self.mf, "with_df", None) is not None
if isinstance(eris, BaseERIs):
return eris
elif eris is not None:
raise TypeError(f"`eris` must be an `BaseERIs` object, got {eris.__class__.__name__}.")
elif use_df:
return self.CDERIs(
self.mf,
space=(self.space, self.space, self.space, self.space),
mo_coeff=(self.mo_coeff, self.mo_coeff, self.mo_coeff, self.mo_coeff),
)
else:
return self.ERIs(
self.mf,
space=(self.space, self.space, self.space, self.space),
mo_coeff=(self.mo_coeff, self.mo_coeff, self.mo_coeff, self.mo_coeff),
)

def get_g(self) -> BaseElectronBoson:
"""Get the blocks of the electron-boson coupling matrix.
Expand All @@ -920,7 +942,14 @@ def get_g(self) -> BaseElectronBoson:
Returns:
Electron-boson coupling matrix.
"""
return self.ElectronBoson(self, array=self.bare_g)
if self.bare_g is None:
raise ValueError("Bare electron-boson coupling matrix not provided.")
return self.ElectronBoson(
self.mf,
self.bare_g,
(self.space, self.space),
(self.mo_coeff, self.mo_coeff),
)

@abstractmethod
def get_mean_field_G(self) -> Any:
Expand All @@ -931,18 +960,6 @@ def get_mean_field_G(self) -> Any:
"""
pass

@property
@abstractmethod
def bare_fock(self) -> Any:
"""Get the mean-field Fock matrix in the MO basis, including frozen parts.

Returns an array and not a `BaseFock` object.

Returns:
Mean-field Fock matrix.
"""
pass

@property
@abstractmethod
def xi(self) -> NDArray[T]:
Expand Down
35 changes: 0 additions & 35 deletions ebcc/cc/gebcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,19 +856,6 @@ def get_mean_field_G(self) -> NDArray[T]:
val += self.bare_G
return val

@property
def bare_fock(self) -> NDArray[T]:
"""Get the mean-field Fock matrix in the MO basis, including frozen parts.

Returns an array and not a `BaseFock` object.

Returns:
Mean-field Fock matrix.
"""
fock_ao: NDArray[T] = np.asarray(self.mf.get_fock(), dtype=types[float])
fock = util.einsum("pq,pi,qj->ij", fock_ao, self.mo_coeff, self.mo_coeff)
return fock

@property
def xi(self) -> NDArray[T]:
"""Get the shift in the bosonic operators to diagonalise the photon Hamiltonian.
Expand All @@ -888,28 +875,6 @@ def xi(self) -> NDArray[T]:
xi = np.zeros(self.omega.shape)
return xi

def get_fock(self) -> GFock:
"""Get the Fock matrix.

Returns:
Fock matrix.
"""
return self.Fock(self, array=self.bare_fock, g=self.g)

def get_eris(self, eris: Optional[ERIsInputType] = None) -> GERIs:
"""Get the electron repulsion integrals.

Args:
eris: Input electron repulsion integrals.

Returns:
Electron repulsion integrals.
"""
if isinstance(eris, GERIs):
return eris
else:
return self.ERIs(self, array=eris)

@property
def nmo(self) -> int:
"""Get the number of molecular orbitals.
Expand Down
38 changes: 0 additions & 38 deletions ebcc/cc/rebcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,19 +614,6 @@ def get_mean_field_G(self) -> NDArray[T]:
val += self.bare_G
return val

@property
def bare_fock(self) -> NDArray[T]:
"""Get the mean-field Fock matrix in the MO basis, including frozen parts.

Returns an array and not a `BaseFock` object.

Returns:
Mean-field Fock matrix.
"""
fock_ao: NDArray[T] = np.asarray(self.mf.get_fock(), dtype=types[float])
fock = util.einsum("pq,pi,qj->ij", fock_ao, self.mo_coeff, self.mo_coeff)
return fock

@property
def xi(self) -> NDArray[T]:
"""Get the shift in the bosonic operators to diagonalise the photon Hamiltonian.
Expand All @@ -646,31 +633,6 @@ def xi(self) -> NDArray[T]:
xi = np.zeros(self.omega.shape, dtype=types[float])
return xi

def get_fock(self) -> RFock:
"""Get the Fock matrix.

Returns:
Fock matrix.
"""
return self.Fock(self, array=self.bare_fock, g=self.g)

def get_eris(self, eris: Optional[ERIsInputType] = None) -> Union[RERIs, RCDERIs]:
"""Get the electron repulsion integrals.

Args:
eris: Input electron repulsion integrals.

Returns:
Electron repulsion integrals.
"""
use_df = getattr(self.mf, "with_df", None) is not None
if isinstance(eris, (RERIs, RCDERIs)):
return eris
elif (isinstance(eris, np.ndarray) and eris.ndim == 3) or use_df:
return self.CDERIs(self, array=eris)
else:
return self.ERIs(self, array=eris)

@property
def nmo(self) -> int:
"""Get the number of molecular orbitals.
Expand Down
45 changes: 0 additions & 45 deletions ebcc/cc/uebcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,24 +809,6 @@ def get_mean_field_G(self) -> NDArray[T]:
val += self.bare_G
return val

@property
def bare_fock(self) -> Namespace[NDArray[T]]:
"""Get the mean-field Fock matrix in the MO basis, including frozen parts.

Returns an array and not a `BaseFock` object.

Returns:
Mean-field Fock matrix.
"""
fock_array = util.einsum(
"npq,npi,nqj->nij",
np.asarray(self.mf.get_fock(), dtype=types[float]),
self.mo_coeff,
self.mo_coeff,
)
fock = util.Namespace(aa=fock_array[0], bb=fock_array[1])
return fock

@property
def xi(self) -> NDArray[T]:
"""Get the shift in the bosonic operators to diagonalise the photon Hamiltonian.
Expand All @@ -847,33 +829,6 @@ def xi(self) -> NDArray[T]:
xi = np.zeros(self.omega.shape)
return xi

def get_fock(self) -> UFock:
"""Get the Fock matrix.

Returns:
Fock matrix.
"""
return self.Fock(self, array=(self.bare_fock.aa, self.bare_fock.bb), g=self.g)

def get_eris(self, eris: Optional[ERIsInputType] = None) -> Union[UERIs, UCDERIs]:
"""Get the electron repulsion integrals.

Args:
eris: Input electron repulsion integrals.

Returns:
Electron repulsion integrals.
"""
use_df = getattr(self.mf, "with_df", None) is not None
if isinstance(eris, (UERIs, UCDERIs)):
return eris
elif (
isinstance(eris, tuple) and isinstance(eris[0], np.ndarray) and eris[0].ndim == 3
) or use_df:
return self.CDERIs(self, array=eris)
else:
return self.ERIs(self, array=eris)

@property
def nmo(self) -> int:
"""Get the number of molecular orbitals.
Expand Down
Loading
Loading