Modularize influence computations
sangkeun00 committed Jan 11, 2024
1 parent c87fa6f commit 0e1e96c
Showing 1 changed file with 148 additions and 37 deletions.
185 changes: 148 additions & 37 deletions analog/analysis/influence_function.py
@@ -1,10 +1,10 @@
from typing import Dict, Optional, Tuple
import pandas as pd
import torch

from einops import einsum, rearrange, reduce
from analog.utils import get_logger
from analog.utils import get_logger, nested_dict
from analog.analysis import AnalysisBase
from analog.analysis.utils import reconstruct_grad, do_decompose, rescaled_dot_product


class InfluenceFunction(AnalysisBase):
@@ -16,9 +16,20 @@ def parse_config(self):
return

@torch.no_grad()
def precondition(self, src_log, damping=None):
def precondition(
self,
src_log: Dict[str, Dict[str, torch.Tensor]],
damping: Optional[float] = None,
):
"""
Precondition gradients using the Hessian.
Args:
src_log (Dict[str, Dict[str, torch.Tensor]]): Log of source gradients
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
src_ids, src = src_log
preconditioned = {}
preconditioned = nested_dict()
(
covariance_eigval,
covariance_eigvec,
@@ -50,7 +61,7 @@ def precondition(self, src_log, damping=None):
if damping is None:
damping = 0.1 * torch.mean(scale)
prec_rotated_grad = rotated_grad / (scale + damping)
preconditioned[module_name] = einsum(
preconditioned[module_name]["grad"] = einsum(
module_eigvec["backward"].to(device=device),
prec_rotated_grad,
module_eigvec["forward"].to(device=device).t(),
@@ -59,16 +70,17 @@
return (src_ids, preconditioned)
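
For intuition, the preconditioning above applies a Kronecker-factored damped Hessian inverse per module: rotate the per-sample gradient into the covariance eigenbasis, divide by damped eigenvalue products, and rotate back. A minimal sketch under those assumptions; the rotation pattern and variable names are illustrative, not the repository's stored names (the rotate-in step is collapsed out of the hunk above):

import torch
from einops import einsum

def precondition_one_module(grad, eigvec_bwd, eigvec_fwd, eigval_bwd, eigval_fwd, damping=None):
    # Rotate the per-sample gradient (n, a, b) into the eigenbasis.
    rotated_grad = einsum(
        eigvec_bwd.t(), grad, eigvec_fwd, "a b, n b c, c d -> n a d"
    )
    # Kronecker-factored eigenvalue scale, with the same default damping
    # rule as the hunk above (0.1 * mean of the scale).
    scale = torch.outer(eigval_bwd, eigval_fwd)
    if damping is None:
        damping = 0.1 * torch.mean(scale)
    prec_rotated_grad = rotated_grad / (scale + damping)
    # Rotate back to the original basis.
    return einsum(
        eigvec_bwd, prec_rotated_grad, eigvec_fwd.t(), "a b, n b c, c d -> n a d"
    )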

@torch.no_grad()
def compute_influence(self, src_log, tgt_log, precondition=True, damping=None):
if precondition:
src_log = self.precondition(src_log, damping)
src_ids, src = src_log
tgt_ids, tgt = tgt_log

# Compute influence scores
total_influence = 0.0
def dot(self, src: Dict[str, torch.Tensor], tgt: Dict[str, torch.Tensor]):
"""
Compute the dot product between two gradient dictionaries.
Args:
src (Dict[str, torch.Tensor]): Dictionary of source gradients
tgt (Dict[str, torch.Tensor]): Dictionary of target gradients
"""
results = 0
for module_name in src.keys():
src_module, tgt_module = src[module_name], tgt[module_name]["grad"]
src_module, tgt_module = src[module_name]["grad"], tgt[module_name]["grad"]
tgt_module = tgt_module.to(device=src_module.device)
assert src_module.shape[1:] == tgt_module.shape[1:]
src_module_expanded = rearrange(src_module, "n ... -> n 1 ...")
@@ -78,60 +90,159 @@ def compute_influence(self, src_log, tgt_log, precondition=True, damping=None):
"n m a b -> n m",
"sum",
)
total_influence += module_influence
results += module_influence
return results
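
The broadcasting in `dot` yields an (n_src, n_tgt) score matrix per module, accumulated across modules. A toy equivalence check with random tensors (illustrative only; the target-side rearrange is collapsed out of the hunk above and assumed symmetric to the source side):

import torch
from einops import rearrange, reduce

src = torch.randn(3, 4, 5)  # 3 source samples, per-sample gradient (4, 5)
tgt = torch.randn(2, 4, 5)  # 2 target samples, same per-sample shape

# Broadcast to (3, 2, 4, 5) and sum over the gradient axes:
# a pairwise dot product between every source/target pair.
pairwise = reduce(
    rearrange(src, "n ... -> n 1 ...") * rearrange(tgt, "m ... -> 1 m ..."),
    "n m a b -> n m",
    "sum",
)
assert torch.allclose(pairwise, src.flatten(1) @ tgt.flatten(1).t())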

@torch.no_grad()
def norm(
self,
src: Dict[str, torch.Tensor],
tgt: Optional[Dict[str, torch.Tensor]] = None,
):
"""
Compute the norm of a gradient dictionary.
Args:
src (Dict[str, torch.Tensor]): Dictionary of source gradients
tgt (Optional[Dict[str, torch.Tensor]]): Dictionary of target gradients
"""
results = 0
for module_name in src.keys():
src_module = src[module_name]["grad"]
tgt_module = tgt[module_name]["grad"] if tgt is not None else src_module
module_influence = reduce(src_module * tgt_module, "n a b -> n", "sum")
results += module_influence.reshape(-1)
return results
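
Despite its name, `norm` returns per-sample squared norms (sums of elementwise products) when `tgt` is None, which is exactly the quantity the cosine and l2 modes below consume. A quick check (illustrative):

import torch
from einops import reduce

g = torch.randn(3, 4, 5)
sq_norm = reduce(g * g, "n a b -> n", "sum")
assert torch.allclose(sq_norm, g.flatten(1).pow(2).sum(dim=1))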

def compute_influence(
self,
src_log: Dict[str, Dict[str, torch.Tensor]],
tgt_log: Dict[str, Dict[str, torch.Tensor]],
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute influence scores between two gradient dictionaries.
Args:
src_log (Dict[str, Dict[str, torch.Tensor]]): Log of source gradients
tgt_log (Dict[str, Dict[str, torch.Tensor]]): Log of target gradients
mode (Optional[str], optional): Influence function mode. Defaults to "dot".
precondition (Optional[bool], optional): Whether to precondition the gradients. Defaults to True.
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
assert mode in ["dot", "l2", "cosine"], f"Invalid mode: {mode}"

if precondition:
src_log = self.precondition(src_log, damping)

src_ids, src = src_log
tgt_ids, tgt = tgt_log

# Compute influence scores
total_influence = None
if mode == "dot":
total_influence = self.dot(src, tgt)
elif mode == "cosine":
dot = self.dot(src, tgt)
src_norm = self.norm(src)
tgt_norm = self.norm(tgt).to(device=src_norm.device)
total_influence = dot / torch.sqrt(torch.outer(src_norm, tgt_norm))
elif mode == "l2":
dot = self.dot(src, tgt)
src_norm = self.norm(src)
tgt_norm = self.norm(tgt).to(device=src_norm.device)
total_influence = 2 * dot - src_norm.unsqueeze(1) - tgt_norm.unsqueeze(0)
total_influence = total_influence.cpu()

# Log influence scores to pd.DataFrame
assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
# Ensure src_ids and tgt_ids are in the DataFrame's index and columns, respectively
# Ensure src_ids and tgt_ids are in the DataFrame's index and columns
self.influence_scores = self.influence_scores.reindex(
index=self.influence_scores.index.union(src_ids),
columns=self.influence_scores.columns.union(tgt_ids),
)

# Assign total_influence values to the corresponding locations in influence_scores
# Assign total_influence values to the corresponding locations
src_indices = [
self.influence_scores.index.get_loc(src_id) for src_id in src_ids
]
tgt_indices = [
self.influence_scores.columns.get_loc(tgt_id) for tgt_id in tgt_ids
]

self.influence_scores.iloc[
src_indices, tgt_indices
] = total_influence.cpu().numpy()
self.influence_scores.iloc[src_indices, tgt_indices] = total_influence.numpy()

return total_influence
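
The three modes reduce to simple identities: "dot" is the pairwise inner product, "cosine" divides it by the square root of the outer product of squared norms, and "l2" computes 2*dot(s, t) - |s|^2 - |t|^2, which equals the negated squared Euclidean distance -|s - t|^2 (larger means more similar). A numeric check of the l2 identity (illustrative):

import torch

s, t = torch.randn(6), torch.randn(6)
lhs = 2 * torch.dot(s, t) - s.pow(2).sum() - t.pow(2).sum()
assert torch.allclose(lhs, -(s - t).pow(2).sum())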

@torch.no_grad()
def compute_self_influence(self, src_log, precondition=True, damping=None):
def compute_self_influence(
self,
src_log: Dict[str, Dict[str, torch.Tensor]],
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute self-influence scores. This can be used for uncertainty estimation.
Args:
src_log (Dict[str, Dict[str, torch.Tensor]]): Log of source gradients
precondition (Optional[bool], optional): Whether to precondition the gradients. Defaults to True.
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
src = src_log[1]
preconditioned_src = None
if precondition:
pc_src_log = self.precondition(src_log, damping)
pc_src, src = pc_src_log[1], src_log[1]
preconditioned_src = self.precondition(src_log, damping)[1]

# Compute self-influence scores
total_influence = 0.0
for module_name in pc_src.keys():
pc_src_module = pc_src[module_name]["grad"]
src_module = src[module_name]["grad"]
module_influence = reduce(pc_src_module * src_module, "n a b -> n", "sum")
total_influence += module_influence.squeeze()
return total_influence

def compute_influence_all(self, src_log, loader, precondition=True, damping=None):
self_influence_scores = self.norm(src, preconditioned_src)

return self_influence_scores
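
After this refactor, self-influence is just `norm` evaluated between the raw and preconditioned gradients, i.e. a per-sample <g, H^-1 g>. A hypothetical call; the `analysis` object and `test_log` pair are assumptions, not names from this file:

# Hypothetical usage: `analysis` is an InfluenceFunction instance and
# `test_log` an (ids, gradient-dict) pair produced elsewhere in analog.
uncertainty = analysis.compute_self_influence(test_log, precondition=True)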

def compute_influence_all(
self,
src_log: Dict[str, Dict[str, torch.Tensor]],
loader: torch.utils.data.DataLoader,
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
damping: Optional[float] = None,
):
"""
Compute influence scores against all train data in the log. This can be used
for training data attribution.
Args:
src_log (Dict[str, Dict[str, torch.Tensor]]): Log of source gradients
loader (torch.utils.data.DataLoader): DataLoader of train data
mode (Optional[str], optional): Influence function mode. Defaults to "dot".
precondition (Optional[bool], optional): Whether to precondition the gradients. Defaults to True.
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
if precondition:
src_log = self.precondition(src_log, damping)

if_scores = []
if_scores_total = []
for tgt_log in loader:
if_scores.append(
self.compute_influence(src_log, tgt_log, precondition=False)
if_scores = self.compute_influence(
src_log, tgt_log, mode=mode, precondition=False
)
return torch.cat(if_scores, dim=-1)
if_scores_total.append(if_scores)
return torch.cat(if_scores_total, dim=-1)
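
A sketch of end-to-end attribution with the new `mode` argument (names are hypothetical; assumes a DataLoader over stored train-data logs, as in the signature above):

# Hypothetical usage: rank training examples by influence on a test batch.
scores = analysis.compute_influence_all(test_log, train_log_loader, mode="cosine")
top_train = scores.argsort(dim=-1, descending=True)[:, :10]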

def get_influence_scores(self):
"""
Return influence scores as a pd.DataFrame.
"""
return self.influence_scores

def save_influence_scores(self, filename="influence_scores.csv"):
"""
Save influence scores as a csv file.
Args:
filename (str, optional): save filename. Defaults to "influence_scores.csv".
"""
self.influence_scores.to_csv(filename, index=True, header=True)
get_logger().info(f"Influence scores saved to {filename}")
