Skip to content

Commit

Permalink
Clean up DIIS
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse committed Sep 18, 2024
1 parent 7af8b09 commit 3900bae
Showing 1 changed file with 87 additions and 82 deletions.
169 changes: 87 additions & 82 deletions ebcc/core/damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
T = float64


class DIIS(diis.DIIS):
class DIIS:
"""Direct inversion in the iterative subspace.
Adapted from PySCF.
Notes:
This code is adapted from PySCF.
"""

_head: int
_buffer: dict[str, NDArray[T]]
_bookkeep: list[int]
_err_vec_touched: bool
_H: Optional[NDArray[T]]
_xprev: Optional[NDArray[T]]
# Intermediates
_index: int
_indices: list[int]
_arrays: dict[int, NDArray[T]]
_errors: dict[int, NDArray[T]]
_matrix: Optional[NDArray[T]]

def __init__(self, space: int = 6, min_space: int = 1, damping: float = 0.0) -> None:
"""Initialize the DIIS object.
Expand All @@ -39,120 +40,124 @@ def __init__(self, space: int = 6, min_space: int = 1, damping: float = 0.0) ->
min_space: The minimum number of vectors to store in the DIIS space.
damping: The damping factor to apply to the extrapolated vector.
"""
super().__init__(incore=True)
self.verbose = 0
# Options
self.space = space
self.min_space = min_space
self.damping = damping

def _store(self, key: str, value: NDArray[T]) -> None:
"""Store the given values in the DIIS buffer."""
self._buffer[key] = value

def push_err_vec(self, xerr: NDArray[T]) -> None:
"""Push the error vectors into the DIIS subspace."""
self._err_vec_touched = True
if self._head >= self.space:
self._head = 0
self._store(f"e{self._head}", xerr)

def push_vec(self, x: NDArray[T]) -> None:
"""Push the vectors into the DIIS subspace."""
if len(self._bookkeep) >= self.space:
self._bookkeep = self._bookkeep[1 - self.space :]

if self._err_vec_touched:
self._bookkeep.append(self._head)
self._store(f"x{self._head}", x)
self._head += 1
elif self._xprev is None:
self._xprev = x
self._store("xprev", x)
else:
if self._head >= self.space:
self._head = 0
self._bookkeep.append(self._head)
self._store(f"x{self._head}", x)
self._store(f"e{self._head}", x - self._xprev)
self._head += 1
# Intermediates
self._index = 0
self._indices = []
self._arrays = {}
self._errors = {}
self._matrix = None

def get_err_vec(self, idx: int) -> NDArray[T]:
"""Get the error vectors at the given index."""
return self._buffer[f"e{idx}"]
def push(self, x: NDArray[T], xerr: Optional[NDArray[T]] = None) -> None:
"""Push the vectors and error vectors into the DIIS subspace.
def get_vec(self, idx: int) -> NDArray[T]:
"""Get the vectors at the given index."""
return self._buffer[f"x{idx}"]
Args:
x: The vector to push into the DIIS subspace.
xerr: The error vector to push into the DIIS subspace.
"""
if len(self._indices) >= self.space:
self._indices = self._indices[1 - self.space:]

def get_num_vec(self) -> int:
"""Get the number of vector groups stored in the DIIS object."""
return len(self._bookkeep)
if xerr is not None:
if self._index >= self.space:
self._index = 0
self._errors[self._index] = xerr
self._indices.append(self._index)
self._arrays[self._index] = x
self._index += 1
elif -1 not in self._arrays:
self._arrays[-1] = x
else:
if self._index >= self.space:
self._index = 0
self._indices.append(self._index)
self._arrays[self._index] = x
self._errors[self._index] = x - self._arrays[-1]
self._index += 1

@property
def narrays(self) -> int:
"""Get the number of arrays stored in the DIIS object."""
return len(self._indices)

def update(self, x: NDArray[T], xerr: Optional[NDArray[T]] = None) -> NDArray[T]:
"""Extrapolate a vector."""
"""Extrapolate a vector.
Args:
x: The vector to extrapolate.
xerr: The error vector to extrapolate.
Returns:
The extrapolated vector.
"""
# Push the vector and error vector into the DIIS subspace
if xerr is not None:
self.push_err_vec(xerr)
self.push_vec(x)
self.push(x, xerr)

# Check if the DIIS space is less than the minimum space
nd = self.get_num_vec()
nd = self.narrays
if nd < self.min_space:
return x

# Build the error matrix
x1 = self.get_err_vec(self._head - 1)
if self._H is None:
self._H = np.block(
x1 = self._errors[self._index - 1]
if self._matrix is None:
self._matrix = np.block(
[
[np.zeros((1, 1)), np.ones((1, self.space))],
[np.ones((self.space, 1)), np.zeros((self.space, self.space))],
]
)
# this looks crazy, but it's just updating the `self._head`th row and
# this looks crazy, but it's just updating the `self._index`th row and
# column with the new errors, it's just done this way to avoid using
# calls to `__setitem__` in immutable backends
Hi = np.array([np.dot(x1.ravel().conj(), self.get_err_vec(i).ravel()) for i in range(nd)])
Hi = np.concatenate([np.array([1.0]), Hi, np.zeros(self.space - nd)])
Hi = Hi.reshape(-1, 1)
Hj = Hi.T.conj()
pre = slice(0, self._head)
pos = slice(self._head + 1, self.space + 1)
self._H = np.block(
m_i = np.array([np.dot(x1.ravel().conj(), self._errors[i].ravel()) for i in range(nd)])
m_i = np.concatenate([np.array([1.0]), m_i, np.zeros(self.space - nd)])
m_i = m_i.reshape(-1, 1)
m_j = m_i.T.conj()
pre = slice(0, self._index)
pos = slice(self._index + 1, self.space + 1)
self._matrix = np.block(
[
[self._H[pre, pre], Hi[pre, :], self._H[pre, pos]],
[Hj[:, pre], Hi[self._head, :].reshape(1, 1), Hj[:, pos]],
[self._H[pos, pre], Hi[pos, :], self._H[pos, pos]],
[self._matrix[pre, pre], m_i[pre, :], self._matrix[pre, pos]],
[m_j[:, pre], m_i[self._index, :].reshape(1, 1), m_j[:, pos]],
[self._matrix[pos, pre], m_i[pos, :], self._matrix[pos, pos]],
]
)

if self._xprev is None:
xnew = self.extrapolate(nd)
else:
self._xprev = None # release memory first
self._xprev = xnew = self.extrapolate(nd)
self._store("xprev", xnew)
xnew = self.extrapolate(nd)
self._arrays[-1] = xnew

# Apply damping
if self.damping:
nd = self.get_num_vec()
nd = self.narrays
if nd > 1:
xprev = self.get_vec(self.get_num_vec() - 1)
xprev = self._arrays[self.narrays - 1]
xnew = (1.0 - self.damping) * xnew + self.damping * xprev

return xnew

def extrapolate(self, nd: Optional[int] = None) -> NDArray[T]:
"""Extrapolate the next vector."""
"""Extrapolate the next vector.
Args:
nd: The number of arrays to use in the extrapolation.
Returns:
The extrapolated vector.
"""
if nd is None:
nd = self.get_num_vec()
nd = self.narrays
if nd == 0:
raise RuntimeError("No vector found in DIIS object.")

# Get the linear problem to solve
if self._H is None:
if self._matrix is None:
raise RuntimeError("DIIS object not initialised.")
h = self._H[: nd + 1, : nd + 1]
h = self._matrix[: nd + 1, : nd + 1]
g = np.concatenate([np.ones((1,), h.dtype), np.zeros((nd,), h.dtype)])

# Solve the linear problem
Expand All @@ -161,9 +166,9 @@ def extrapolate(self, nd: Optional[int] = None) -> NDArray[T]:
c = util.einsum("pi,qi,i,q->p", v[:, mask], v[:, mask].conj(), 1 / w[mask], g)

# Construct the new vector
xnew: NDArray[T] = np.zeros_like(self.get_vec(0))
xnew: NDArray[T] = np.zeros_like(self._arrays[0])
for i, ci in enumerate(c[1:]):
xi = self.get_vec(i)
xi = self._arrays[i]
xnew += xi * ci

return xnew

0 comments on commit 3900bae

Please sign in to comment.