Skip to content

Commit

Permalink
add support for model training during attack, if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics committed Mar 7, 2024
1 parent 08c1340 commit 5ae9af6
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions deckard/base/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from art.utils import to_categorical, compute_success
from sklearn.utils.validation import check_is_fitted
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from random import randint
from ..data import Data
from ..model import Model
Expand Down Expand Up @@ -138,7 +141,6 @@ def __call__(
adv_probabilities_file=None,
adv_predictions_file=None,
adv_losses_file=None,
**kwargs,
):
time_dict = {}
results = {}
Expand All @@ -149,6 +151,11 @@ def __call__(
if attack_file is not None and Path(attack_file).exists():
samples = self.data.load(attack_file)
else:
print(f"Type of self.init: {type(self.init)}")
print(f"Type of self.init.model: {type(self.init.model)}")
print(f"Type of model: {type(model)}")


atk = self.init(model=model, attack_size=self.attack_size)

if targeted is True:
Expand Down Expand Up @@ -301,7 +308,6 @@ def __call__(
adv_probabilities_file=None,
adv_predictions_file=None,
adv_losses_file=None,
**kwargs,
):
time_dict = {}
results = {}
Expand Down Expand Up @@ -489,7 +495,6 @@ def __call__(
adv_probabilities_file=None,
adv_predictions_file=None,
adv_losses_file=None,
**kwargs,
):
data_shape = data[0][0].shape
time_dict = {}
Expand Down Expand Up @@ -609,7 +614,6 @@ def __call__(
adv_probabilities_file=None,
adv_predictions_file=None,
adv_losses_file=None,
**kwargs,
):
results = {}
time_dict = {}
Expand Down Expand Up @@ -805,8 +809,16 @@ def __call__(
**kwargs,
):
name = self.init.name
kwargs = deepcopy(self.kwargs)
kwargs.update({"init": self.init.kwargs})
data = self.data()
data, model = self.model.initialize(data)
if isinstance(model, BaseEstimator):
try:
check_is_fitted(model), "Model must be fitted before calling attack."
except NotFittedError as e:
logger.warning(f"Model not fitted. Fitting model before attack. Error: {e}")
model, _ = self.model.fit(data=data, model=model)
if "art" not in str(type(model)):
model = self.model.art(model=model, data=data)
if self.method == "evasion":
Expand All @@ -815,7 +827,7 @@ def __call__(
data=self.data,
model=self.model,
attack_size=self.attack_size,
**self.init.kwargs,
**kwargs,
)
result = attack(
data,
Expand All @@ -831,7 +843,7 @@ def __call__(
data=self.data,
model=self.model,
attack_size=self.attack_size,
**self.init.kwargs,
**kwargs,
)
result = attack(
data,
Expand All @@ -847,7 +859,7 @@ def __call__(
data=self.data,
model=self.model,
attack_size=self.attack_size,
**self.init.kwargs,
**kwargs,
)
result = attack(
data,
Expand All @@ -863,7 +875,7 @@ def __call__(
data=self.data,
model=self.model,
attack_size=self.attack_size,
**self.init.kwargs,
**kwargs,
)
result = attack(
data,
Expand Down

0 comments on commit 5ae9af6

Please sign in to comment.