Commit 2e3694f

Extend functionality of CgInfluence:
* add pre-conditioning, with new NystroemPreConditioner and JacobiPreConditioner classes
* add stable block CG variant to solve several right-hand sides simultaneously
Kristof Schroeder committed Mar 4, 2024
1 parent 568426a commit 2e3694f
Showing 7 changed files with 509 additions and 23 deletions.
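For orientation, a minimal usage sketch of the two options this commit adds (pre_conditioner and use_block_cg). The toy model, loss, data, the NystroemPreConditioner rank argument, and the remaining constructor arguments are illustrative assumptions sketched from the repository layout, not taken verbatim from this diff:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch import CgInfluence
from pydvl.influence.torch.pre_conditioner import NystroemPreConditioner

# Toy model and data, purely illustrative.
model = torch.nn.Linear(10, 1)
loss = torch.nn.functional.mse_loss
data = DataLoader(TensorDataset(torch.randn(32, 10), torch.randn(32, 1)), batch_size=8)

# New in this commit: an optional pre-conditioner for CG, and a block CG
# mode that solves all right-hand sides simultaneously.
if_model = CgInfluence(
    model,
    loss,
    hessian_regularization=0.01,
    pre_conditioner=NystroemPreConditioner(rank=5),
    use_block_cg=True,
)
if_model = if_model.fit(data)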
3 changes: 3 additions & 0 deletions src/pydvl/influence/torch/functional.py
@@ -47,8 +47,11 @@
"create_per_sample_mixed_derivative_function",
"model_hessian_low_rank",
"LowRankProductRepresentation",
"randomized_nystroem_approximation",
"model_hessian_nystroem_approximation",
]


logger = logging.getLogger(__name__)


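The two new exports point at a randomized Nyström approximation of the model Hessian. As background, a self-contained sketch of the standard randomized Nyström approximation of a PSD operator accessed only through matrix products (in the style of the fixed-rank PSD approximation of Tropp et al.); this illustrates the technique and is not pyDVL's implementation:

import torch

def randomized_nystroem_sketch(mat_mat, dim, rank, dtype=torch.float64):
    # Orthonormal Gaussian test matrix Omega of shape (dim, rank).
    omega, _ = torch.linalg.qr(torch.randn(dim, rank, dtype=dtype))
    y = mat_mat(omega)  # Y = A @ Omega, one pass over the operator
    # Small shift for numerical stability of the Cholesky factorization.
    nu = torch.finfo(dtype).eps * torch.linalg.matrix_norm(y)
    y_nu = y + nu * omega
    m = omega.T @ y_nu
    c = torch.linalg.cholesky((m + m.T) / 2)  # symmetrize before factorizing
    b = torch.linalg.solve_triangular(c, y_nu.T, upper=False).T
    u, sigma, _ = torch.linalg.svd(b, full_matrices=False)
    # Undo the shift; A is approximated by U @ diag(lambdas) @ U.T.
    lambdas = torch.clamp(sigma**2 - nu, min=0.0)
    return u, lambdas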
173 changes: 159 additions & 14 deletions src/pydvl/influence/torch/influence_function_model.py
Expand Up @@ -35,6 +35,11 @@
model_hessian_low_rank,
model_hessian_nystroem_approximation,
)
from .pre_conditioner import (
JacobiPreConditioner,
NystroemPreConditioner,
PreConditioner,
)
from .util import (
EkfacRepresentation,
empirical_cross_entropy_loss_fn,
@@ -466,8 +471,12 @@ def __init__(
maxiter: Optional[int] = None,
progress: bool = False,
precompute_grad: bool = False,
pre_conditioner: Optional[PreConditioner] = None,
use_block_cg: bool = False,
):
super().__init__(model, loss)
self.use_block_cg = use_block_cg
self.pre_conditioner = pre_conditioner
self.precompute_grad = precompute_grad
self.progress = progress
self.maxiter = maxiter
@@ -487,6 +496,25 @@ def is_fitted(self):

def fit(self, data: DataLoader) -> CgInfluence:
self.train_dataloader = data
if self.pre_conditioner is not None:

hvp = create_hvp_function(
self.model,
self.loss,
self.train_dataloader,
precompute_grad=self.precompute_grad,
)

def model_hessian_mat_mat_prod(x: torch.Tensor):
return torch.func.vmap(hvp, in_dims=1, randomness="same")(x).t()

self.pre_conditioner.fit(
model_hessian_mat_mat_prod,
self.n_parameters,
self.model_dtype,
self.model_device,
self.hessian_regularization,
)
return self
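The nested model_hessian_mat_mat_prod helper lifts the Hessian-vector product over the columns of its input with torch.func.vmap. A tiny self-contained illustration of that trick, with a dense symmetric matrix standing in for the model Hessian:

import torch

h = torch.randn(6, 6, dtype=torch.float64)
h = (h + h.T) / 2  # symmetric stand-in for the Hessian

def hvp(v: torch.Tensor) -> torch.Tensor:
    return h @ v

x = torch.randn(6, 3, dtype=torch.float64)
# Maps hvp over the columns of x, stacks results as rows, then transposes.
hx = torch.func.vmap(hvp, in_dims=1)(x).t()
assert torch.allclose(hx, h @ x)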

@log_duration
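fit passes the pre-conditioner a Hessian matrix-matrix product together with the parameter count, dtype, device, and regularization. To make that fit/solve contract concrete, a minimal Jacobi-style pre-conditioner sketch; the Hutchinson-type diagonal estimate and all names here are illustrative assumptions, not the code in pre_conditioner.py:

import torch

class DiagonalPreConditionerSketch:
    def __init__(self, num_probes: int = 10):
        self.num_probes = num_probes
        self._diag = None
        self._reg = 0.0

    def fit(self, mat_mat_prod, size, dtype, device, regularization):
        # Rademacher probes V; E[v * (Hv)] recovers diag(H) (Hutchinson).
        v = torch.randint(0, 2, (size, self.num_probes), device=device).to(dtype)
        v = 2.0 * v - 1.0
        hv = mat_mat_prod(v)  # H @ V, shape (size, num_probes)
        self._diag = (v * hv).mean(dim=1)
        self._reg = regularization

    def solve(self, rhs: torch.Tensor) -> torch.Tensor:
        # Apply (diag(H) + reg * I)^{-1}, columnwise for matrix inputs.
        d = self._diag + self._reg
        return rhs / d if rhs.ndim == 1 else rhs / d.unsqueeze(-1)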
@@ -537,9 +565,16 @@ def influences(

@log_duration
def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:

if len(self.train_dataloader) == 0:
raise ValueError("Training dataloader must not be empty.")

if self.use_block_cg:
return self._solve_pbcg(rhs)

hvp = create_hvp_function(
self.model,
self.loss,
@@ -550,30 +585,143 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
def reg_hvp(v: torch.Tensor):
return hvp(v) + self.hessian_regularization * v.type(rhs.dtype)

y_norm = torch.linalg.norm(rhs, dim=1)

stopping_val = torch.clamp((self.rtol * y_norm) ** 2, min=self.atol**2)

batch_cg = torch.zeros_like(rhs)
cg_fun = self._solve_cg if self.pre_conditioner is None else self._solve_pcg

for idx, (bi, _tol) in enumerate(
tqdm(
zip(rhs, stopping_val),
disable=not self.progress,
desc="Conjugate gradient",
)
):
batch_result = cg_fun(
reg_hvp,
bi,
tol=_tol,
x0=self.x0,
maxiter=self.maxiter,
)
batch_cg[idx] = batch_result

return batch_cg

def _solve_pcg(
self,
hvp: Callable[[torch.Tensor], torch.Tensor],
b: torch.Tensor,
*,
tol: float,
x0: Optional[torch.Tensor] = None,
maxiter: Optional[int] = None,
):

assert self.pre_conditioner is not None

if x0 is None:
x0 = torch.clone(b)
if maxiter is None:
maxiter = len(b) * 10

x = x0

r0 = (b - hvp(x)).squeeze()
p = z0 = self.pre_conditioner.solve(r0).squeeze()

for k in range(maxiter):
if torch.dot(r0, r0) < tol:  # tol is a squared-norm threshold
break
Ap = hvp(p).squeeze()
alpha = torch.dot(r0, z0) / torch.dot(p, Ap)
x += alpha * p
r = r0 - alpha * Ap
z = self.pre_conditioner.solve(r)
beta = torch.dot(r, z) / torch.dot(r0, z0)
r0 = r
p = z + beta * p
z0 = z

return x
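A quick self-contained check of the recurrence above on a small SPD system, with a plain Jacobi diagonal standing in for the pre-conditioner; everything in this snippet is illustrative:

import torch

torch.manual_seed(0)
n = 50
a = torch.randn(n, n, dtype=torch.float64)
spd = a @ a.T + n * torch.eye(n, dtype=torch.float64)  # well-conditioned SPD matrix
b = torch.randn(n, dtype=torch.float64)
diag_inv = 1.0 / torch.diagonal(spd)  # Jacobi: M^{-1} = diag(A)^{-1}

x = torch.zeros(n, dtype=torch.float64)
r = b - spd @ x
z = diag_inv * r
p = z.clone()
for _ in range(10 * n):
    if torch.dot(r, r) < 1e-16:  # squared-norm stopping rule, as above
        break
    ap = spd @ p
    alpha = torch.dot(r, z) / torch.dot(p, ap)
    x += alpha * p
    r_new = r - alpha * ap
    z_new = diag_inv * r_new
    beta = torch.dot(r_new, z_new) / torch.dot(r, z)
    p = z_new + beta * p
    r, z = r_new, z_new

assert torch.allclose(x, torch.linalg.solve(spd, b))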

def _solve_pbcg(
self,
rhs: torch.Tensor,
):
hvp = create_hvp_function(
self.model,
self.loss,
self.train_dataloader,
precompute_grad=self.precompute_grad,
)

def mat_mat(x: torch.Tensor):
return torch.vmap(
lambda u: hvp(u) + self.hessian_regularization * u,
in_dims=1,
randomness="same",
)(x)

X = torch.clone(rhs.T)

R = (rhs - mat_mat(X)).T
Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R)
P, _, _ = torch.linalg.svd(Z, full_matrices=False)
active_indices = torch.arange(X.shape[-1], dtype=torch.long)

maxiter = self.maxiter if self.maxiter is not None else len(rhs) * 10
y_norm = torch.linalg.norm(rhs, dim=1)
tol = torch.clamp(self.rtol * y_norm, min=self.atol)

shrink_finished_indices = rhs.shape[0] <= rhs.shape[1]

for k in range(maxiter):
Q = mat_mat(P).T
p_t_ap = P.T @ Q
alpha = torch.linalg.solve(p_t_ap, P.T @ R)
X[:, active_indices] += P @ alpha
R -= Q @ alpha

B = torch.linalg.norm(R, dim=0)
non_finished_indices = torch.nonzero(B > tol)
num_remaining_indices = non_finished_indices.numel()
non_finished_indices = non_finished_indices.squeeze()

if num_remaining_indices == 1:
non_finished_indices = non_finished_indices.unsqueeze(-1)

if num_remaining_indices == 0:
break

if shrink_finished_indices:
active_indices = active_indices[non_finished_indices]
R = R[:, non_finished_indices]
P = P[:, non_finished_indices]
Q = Q[:, non_finished_indices]
p_t_ap = p_t_ap[:, non_finished_indices][non_finished_indices, :]
tol = tol[non_finished_indices]

Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R)
beta = -torch.linalg.solve(p_t_ap, Q.T @ Z)
Z_tmp = Z + P @ beta

if Z_tmp.ndim == 1:
Z_tmp = Z_tmp.unsqueeze(-1)

P, _, _ = torch.linalg.svd(Z_tmp, full_matrices=False)

return X.T
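For intuition, the same stabilized block CG recurrence on a small dense SPD system, without pre-conditioning or shrinking of converged columns. Re-orthonormalizing the search directions with an SVD each iteration is what keeps the block update stable; this snippet is purely illustrative:

import torch

torch.manual_seed(1)
n, k = 40, 4
a = torch.randn(n, n, dtype=torch.float64)
spd = a @ a.T + n * torch.eye(n, dtype=torch.float64)
rhs = torch.randn(k, n, dtype=torch.float64)  # right-hand sides as rows, as in _solve_pbcg

x = rhs.T.clone()
r = rhs.T - spd @ x  # block residual, one column per right-hand side
p, _, _ = torch.linalg.svd(r, full_matrices=False)  # orthonormal search directions
for _ in range(n):
    q = spd @ p
    ptap = p.T @ q
    alpha = torch.linalg.solve(ptap, p.T @ r)
    x += p @ alpha
    r -= q @ alpha
    if torch.linalg.norm(r, dim=0).max() < 1e-10:
        break
    beta = -torch.linalg.solve(ptap, q.T @ r)
    p, _, _ = torch.linalg.svd(r + p @ beta, full_matrices=False)

assert torch.allclose(x, torch.linalg.solve(spd, rhs.T))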

@staticmethod
def _solve_cg(
hvp: Callable[[torch.Tensor], torch.Tensor],
b: torch.Tensor,
*,
tol: float,
x0: Optional[torch.Tensor] = None,
maxiter: Optional[int] = None,
) -> torch.Tensor:
r"""
Expand All @@ -596,21 +744,18 @@ def _solve_cg(
if maxiter is None:
maxiter = len(b) * 10

x = x0
p = r = (b - hvp(x)).squeeze()
gamma = torch.dot(r, r)  # squared residual norm

for k in range(maxiter):
if gamma < tol:
break
Ap = hvp(p).squeeze()
alpha = gamma / torch.dot(p, Ap)
x += alpha * p
r -= alpha * Ap
gamma_ = torch.dot(r, r)
beta = gamma_ / gamma
gamma = gamma_
p = r + beta * p