Extend tests for composable influence
schroedk committed Jun 5, 2024
1 parent eb7f7ed commit dff2d40
Showing 1 changed file with 139 additions and 30 deletions.
169 changes: 139 additions & 30 deletions tests/influence/torch/test_influence_model.py
@@ -1,5 +1,5 @@
from functools import partial
from typing import Callable, NamedTuple, Tuple
from typing import Callable, Dict, NamedTuple, OrderedDict, Tuple

import numpy as np
import pytest
@@ -25,6 +25,7 @@
PreConditioner,
)
from pydvl.influence.torch.util import BlockMode, SecondOrderMode
from pydvl.influence.types import UnsupportedInfluenceModeException
from tests.influence.torch.conftest import minimal_training

torch = pytest.importorskip("torch")
@@ -246,46 +247,112 @@ def model_and_data(


@fixture
def direct_influence_function_model(model_and_data, test_case: TestCase):
def block_structure(request):
return getattr(request, "param", BlockMode.FULL)


@fixture
def second_order_mode(request):
return getattr(request, "param", SecondOrderMode.HESSIAN)


@fixture
def direct_influence_function_model(
model_and_data, test_case: TestCase, block_structure: BlockMode, second_order_mode
):
model, loss, x_train, y_train, x_test, y_test = model_and_data
train_dataloader = DataLoader(
TensorDataset(x_train, y_train), batch_size=test_case.batch_size
)
return DirectInfluence(model, loss, test_case.hessian_reg).fit(train_dataloader)
return DirectInfluence(
model,
loss,
test_case.hessian_reg,
block_structure=block_structure,
second_order_mode=second_order_mode,
).fit(train_dataloader)


@fixture
def direct_influences(
direct_influence_function_model: DirectInfluence,
model_and_data,
test_case: TestCase,
) -> NDArray:
) -> torch.Tensor:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influences(
x_test, y_test, x_train, y_train, mode=test_case.mode
).numpy()
)


@fixture
def direct_influences_by_block(
direct_influence_function_model: DirectInfluence,
model_and_data,
test_case: TestCase,
) -> Dict[str, torch.Tensor]:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influences_by_block(
x_test, y_test, x_train, y_train, mode=test_case.mode
)


@fixture
def direct_sym_influences(
direct_influence_function_model: DirectInfluence,
model_and_data,
test_case: TestCase,
) -> NDArray:
) -> torch.Tensor:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influences(
x_train, y_train, mode=test_case.mode
).numpy()
)


@fixture
def direct_factors(
direct_influence_function_model: DirectInfluence,
model_and_data,
test_case: TestCase,
) -> NDArray:
) -> torch.Tensor:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influence_factors(x_train, y_train)


@fixture
def direct_factors_by_block(
direct_influence_function_model: DirectInfluence,
model_and_data,
test_case: TestCase,
) -> OrderedDict[str, torch.Tensor]:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influence_factors_by_block(x_train, y_train)


@fixture
def direct_influences_from_factors_by_block(
direct_influence_function_model: DirectInfluence,
direct_factors_by_block,
model_and_data,
test_case: TestCase,
) -> Dict[str, torch.Tensor]:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influence_factors(x_train, y_train).numpy()
return direct_influence_function_model.influences_from_factors_by_block(
direct_factors_by_block, x_train, y_train, mode=test_case.mode
)


@fixture
def direct_influences_from_factors(
direct_influence_function_model: DirectInfluence,
direct_factors,
model_and_data,
test_case: TestCase,
) -> torch.Tensor:
model, loss, x_train, y_train, x_test, y_test = model_and_data
return direct_influence_function_model.influences_from_factors(
direct_factors, x_train, y_train, mode=test_case.mode
)


@pytest.mark.parametrize(
@@ -668,10 +735,10 @@ def test_influences_ekfac(
assert np.allclose(ekfac_influence_values, influence_from_factors)
assert np.allclose(ekfac_influence_values, accumulated_inf_by_layer)
check_influence_correlations(
direct_influences, ekfac_influence_values, threshold=0.94
direct_influences.numpy(), ekfac_influence_values, threshold=0.94
)
check_influence_correlations(
direct_sym_influences, ekfac_self_influence, threshold=0.94
direct_sym_influences.numpy(), ekfac_self_influence, threshold=0.94
)


@@ -766,8 +833,13 @@ def test_influences_cg(
]


def check_correlation(arr_1, arr_2, corr_val):
assert np.all(pearsonr(arr_1, arr_2).statistic > corr_val)
assert np.all(spearmanr(arr_1, arr_2).statistic > corr_val)


@pytest.mark.parametrize("composable_influence_factory", composable_influence_factories)
@pytest.mark.parametrize("block_mode", [mode for mode in BlockMode])
@pytest.mark.parametrize("block_structure", [mode for mode in BlockMode], indirect=True)
@pytest.mark.torch
def test_composable_influence(
test_case: TestCase,
@@ -781,8 +853,12 @@ def test_composable_influence(
],
direct_influences,
direct_sym_influences,
direct_factors,
direct_influences_by_block,
direct_factors_by_block,
direct_influences_from_factors_by_block,
device: torch.device,
block_mode,
block_structure,
composable_influence_factory,
):
model, loss, x_train, y_train, x_test, y_test = model_and_data
@@ -791,25 +867,58 @@
TensorDataset(x_train, y_train), batch_size=test_case.batch_size
)

harmonic_mean_influence = composable_influence_factory(
model, loss, test_case.hessian_reg, block_structure=block_mode
infl_model = composable_influence_factory(
model, loss, test_case.hessian_reg, block_structure=block_structure
).to(device)
harmonic_mean_influence = harmonic_mean_influence.fit(train_dataloader)
harmonic_mean_influence_values = (
harmonic_mean_influence.influences(
x_test, y_test, x_train, y_train, mode=test_case.mode
)
.cpu()
.numpy()

with pytest.raises(NotFittedException):
infl_model.influences(x_test, y_test, x_train, y_train, mode=test_case.mode)

infl_model = infl_model.fit(train_dataloader)

with pytest.raises(UnsupportedInfluenceModeException):
infl_model.influences(x_test, y_test, mode="strange_mode")

infl_values = infl_model.influences(
x_test, y_test, x_train, y_train, mode=test_case.mode
)

threshold = 0.999
flat_direct_influences = direct_influences.reshape(-1)
flat_harmonic_influences = harmonic_mean_influence_values.reshape(-1)
assert np.all(
pearsonr(flat_direct_influences, flat_harmonic_influences).statistic > threshold
threshold = 1 - 1e-3
check_correlation(
direct_influences.reshape(-1), infl_values.reshape(-1), corr_val=threshold
)
assert np.all(
spearmanr(flat_direct_influences, flat_harmonic_influences).statistic
> threshold

sym_infl_values = infl_model.influences(x_train, y_train, mode=test_case.mode)
check_correlation(
direct_sym_influences.reshape(-1), sym_infl_values.reshape(-1), threshold
)

infl_factors = infl_model.influence_factors(x_train, y_train)
flat_factors = infl_factors.reshape(-1)
flat_direct_factors = direct_factors.reshape(-1)
check_correlation(flat_factors, flat_direct_factors, threshold)

infl_factors_by_block = infl_model.influence_factors_by_block(x_train, y_train)
for block, infl in infl_factors_by_block.items():
check_correlation(
infl.reshape(-1), direct_factors_by_block[block].reshape(-1), threshold
)

infl_by_block = infl_model.influences_by_block(
x_test, y_test, x_train, y_train, mode=test_case.mode
)
for block, infl in infl_by_block.items():
check_correlation(
infl.reshape(-1), direct_influences_by_block[block].reshape(-1), threshold
)

infl_from_factors_by_block = infl_model.influences_from_factors_by_block(
infl_factors_by_block, x_train, y_train, mode=test_case.mode
)

for block, infl in infl_from_factors_by_block.items():
check_correlation(
infl.reshape(-1),
direct_influences_from_factors_by_block[block].reshape(-1),
threshold,
)
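
The new fixtures above (block_structure, second_order_mode) rely on pytest's indirect parametrization: with @pytest.mark.parametrize(..., indirect=True) each parametrized value reaches the fixture as request.param, and without parametrization the fixture falls back to its default via getattr. A minimal, self-contained sketch of that pattern, with plain strings standing in for the BlockMode members used in this file:

import pytest


@pytest.fixture
def block_structure(request):
    # Without parametrization the request has no "param" attribute,
    # so the fixture falls back to its default value.
    return getattr(request, "param", "full")


def test_uses_default(block_structure):
    # Not parametrized: the fixture returns its default.
    assert block_structure == "full"


@pytest.mark.parametrize("block_structure", ["full", "layer_wise"], indirect=True)
def test_uses_param(block_structure):
    # indirect=True routes each parametrized value into the fixture
    # as request.param instead of passing it to the test directly.
    assert block_structure in ("full", "layer_wise")

The same mechanism lets test_composable_influence receive every BlockMode value through the block_structure fixture, while tests that do not parametrize it keep the BlockMode.FULL default.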
