Skip to content

Commit

Permalink
clean influence function code
Browse files Browse the repository at this point in the history
  • Loading branch information
Sang Choe authored and Sang Choe committed Apr 23, 2024
1 parent 4410152 commit cfe3c1b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 69 deletions.
113 changes: 45 additions & 68 deletions logix/analysis/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
precondition_kfac,
precondition_raw,
cross_dot_product,
merge_influence_results,
)
from logix.statistic.utils import make_2d

Expand Down Expand Up @@ -99,12 +100,10 @@ def compute_influence(
src_ids, src = src_log
tgt_ids, tgt = tgt_log

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0
# Initialize influence scores
total_influence = {"total": 0}
for influence_group in influence_groups or []:
total_influence[influence_group] = 0

# Compute influence scores. By default, we should compute the basic influence
# scores, which is essentially the inner product between the source and target
Expand All @@ -123,12 +122,10 @@ def compute_influence(
module_influence = cross_dot_product(
src[module_name]["grad"], tgt[module_name]["grad"]
)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence["total"] += module_influence
if influence_groups is not None:
groups = [g for g in influence_groups if g in module_name]
for group in groups:
total_influence[group] += module_influence

if mode == "cosine":
Expand All @@ -138,40 +135,35 @@ def compute_influence(
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
if influence_groups is None:
total_influence /= torch.sqrt(tgt_norm.unsqueeze(0))
else:
for key in total_influence.keys():
total_influence[key] /= tgt_norm[key]
).pop("influence")
for key in total_influence.keys():
tgt_norm_key = tgt_norm if influence_groups is None else tgt_norm[key]
total_influence[key] /= torch.sqrt(tgt_norm_key.unsqueeze(0))
elif mode == "l2":
tgt_norm = self.compute_self_influence(
tgt_log,
precondition=True,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
if influence_groups is None:
total_influence -= 0.5 * tgt_norm.unsqueeze(0)
else:
for key in total_influence.keys():
total_influence[key] -= 0.5 * tgt_norm[key].unsqueeze(0)
).pop("influence")
for key in total_influence.keys():
tgt_norm_key = tgt_norm if influence_groups is None else tgt_norm[key]
total_influence[key] -= 0.5 * tgt_norm_key.unsqueeze(0)

# Move influence scores to CPU to save memory
if influence_groups is None:
assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert value.shape[0] == len(src_ids)
assert value.shape[1] == len(tgt_ids)
total_influence[key] = value.cpu()
for key, value in total_influence.items():
assert value.shape[0] == len(src_ids)
assert value.shape[1] == len(tgt_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["tgt_ids"] = tgt_ids
result["influence"] = total_influence
result["influence"] = (
total_influence.pop("total")
if influence_groups is None
else total_influence
)

return result

Expand Down Expand Up @@ -205,12 +197,10 @@ def compute_self_influence(
if precondition:
tgt = self.precondition(src_log, hessian=hessian, damping=damping)[1]

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0
# Initialize influence scores
total_influence = {"total": 0}
for influence_group in influence_groups or []:
total_influence[influence_group] = 0

# Compute self-influence scores
for module_name in src.keys():
Expand All @@ -219,25 +209,23 @@ def compute_self_influence(
module_influence = reduce(
src_module * tgt_module, "n a b -> n", "sum"
).reshape(-1)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence["total"] += module_influence
if influence_groups is not None:
groups = [g for g in influence_groups if g in module_name]
for group in groups:
total_influence[group] += module_influence

# Move influence scores to CPU to save memory
if influence_groups is not None:
assert len(total_influence) == len(src_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert len(value) == len(src_ids)
total_influence[key] = value.cpu()
for key, value in total_influence.items():
assert len(value) == len(src_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["influence"] = total_influence
result["influence"] = (
total_influence.pop("total")
if influence_groups is None
else total_influence
)

return result

Expand Down Expand Up @@ -268,29 +256,18 @@ def compute_influence_all(
result_all = None
for tgt_log in tqdm(loader, desc="Compute IF"):
result = self.compute_influence(
src_log,
tgt_log,
src_log=src_log,
tgt_log=tgt_log,
mode=mode,
precondition=False,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)

# Merge results
if result_all is None:
result_all = result
continue
result_all["tgt_ids"].extend(result["tgt_ids"])
if influence_groups is None:
result_all["influence"] = torch.cat(
[result_all["influence"], result["influence"]], dim=1
)
else:
for key in result_all["influence"].keys():
result_all["influence"][key] = torch.cat(
[result_all["influence"][key], result["influence"][key]],
dim=1,
)
merge_influence_results(result_all, result)

return result_all
16 changes: 15 additions & 1 deletion logix/analysis/influence_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def precondition_raw(
return preconditioned


def cross_dot_product(src: torch.Tensor, tgt: torch.Tensor):
def cross_dot_product(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
assert src.shape[1:] == tgt.shape[1:]
src_expanded = rearrange(src, "n ... -> n 1 ...")
tgt_expanded = rearrange(tgt, "m ... -> 1 m ...")
Expand All @@ -82,3 +82,17 @@ def cross_dot_product(src: torch.Tensor, tgt: torch.Tensor):
)

return dot_product_result


def merge_influence_results(result_all, result) -> None:
    """Merge one batch of influence results into an accumulated result, in place.

    The ids in ``result["tgt_ids"]`` are appended to ``result_all["tgt_ids"]``
    and the influence scores are concatenated along dimension 1.

    The merge is all-or-nothing: every concatenation is computed before
    ``result_all`` is touched, so a failure (unsupported type, missing group
    key, mismatched tensor sizes) leaves ``result_all`` exactly as it was.

    Args:
        result_all: Accumulated result with a ``"tgt_ids"`` list and an
            ``"influence"`` entry. Modified in place.
        result: New result to merge. Its ``"influence"`` entry is either a
            ``torch.Tensor`` or a dict of tensors containing at least every
            key present in ``result_all["influence"]``.

    Raises:
        TypeError: If ``result["influence"]`` is neither a dict nor a
            ``torch.Tensor``.
        KeyError: If a group key of ``result_all["influence"]`` is missing
            from ``result["influence"]``.
    """
    new_tgt_ids = result["tgt_ids"]
    accumulated = result_all["influence"]
    incoming = result["influence"]

    if isinstance(incoming, dict):
        # Build every merged tensor first; only the keys already tracked in
        # the accumulated result are merged.
        merged = {
            key: torch.cat([value, incoming[key]], dim=1)
            for key, value in accumulated.items()
        }
        result_all["tgt_ids"].extend(new_tgt_ids)
        # Update the existing dict object so outside references to it stay valid.
        accumulated.update(merged)
        return

    if not isinstance(incoming, torch.Tensor):
        # Explicit check instead of `assert`, which is stripped under `-O`.
        raise TypeError(
            "result['influence'] must be a dict or a torch.Tensor, "
            f"got {type(incoming).__name__}"
        )

    merged = torch.cat([accumulated, incoming], dim=1)
    result_all["tgt_ids"].extend(new_tgt_ids)
    result_all["influence"] = merged

0 comments on commit cfe3c1b

Please sign in to comment.