diff --git a/autoattack/autoattack.py b/autoattack/autoattack.py
index e633308..c8fdd4b 100644
--- a/autoattack/autoattack.py
+++ b/autoattack/autoattack.py
@@ -11,7 +11,7 @@ class AutoAttack():
     def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
-                 attacks_to_run=[], version='standard', is_tf_model=False,
+                 attacks_to_run=[], version='standard', eval_iter=None, is_tf_model=False,
                  device='cuda', log_path=None):
         self.model = model
         self.norm = norm
@@ -21,6 +21,7 @@ def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=True,
         self.verbose = verbose
         self.attacks_to_run = attacks_to_run
         self.version = version
+        self.eval_iter = eval_iter
         self.is_tf_model = is_tf_model
         self.device = device
         self.logger = Logger(log_path)
@@ -108,9 +109,19 @@ def run_standard_evaluation(self,
             self.logger.log('{} was/were already run.'.format(', '.join(state.run_attacks)))

         # checks on type of defense
+        is_randomized_defense = True
         if self.version != 'rand':
-            checks.check_randomized(self.get_logits, x_orig[:bs].to(self.device),
+            is_randomized_defense = checks.is_randomized(self.get_logits, x_orig[:bs].to(self.device),
                 y_orig[:bs].to(self.device), bs=bs, logger=self.logger)
+
+        if self.eval_iter is None:
+            if is_randomized_defense:
+                self.logger.log("randomized defense, using default eval_iter=20")
+                self.eval_iter = 20
+            else:
+                self.logger.log("deterministic defense, using default eval_iter=1")
+                self.eval_iter = 1
+
         n_cls = checks.check_range_output(self.get_logits, x_orig[:bs].to(self.device),
             logger=self.logger)
         checks.check_dynamic(self.model, x_orig[:bs].to(self.device), self.is_tf_model,
@@ -122,19 +133,23 @@ def run_standard_evaluation(self,
         # calculate accuracy
         n_batches = int(np.ceil(x_orig.shape[0] / bs))
         if state.robust_flags is None:
-            robust_flags = torch.zeros(x_orig.shape[0], dtype=torch.bool, device=x_orig.device)
+            #robust_flags = torch.zeros(x_orig.shape[0], dtype=torch.bool, device=x_orig.device)
+            robust_flags = torch.zeros(x_orig.shape[0], device=x_orig.device)
             y_adv = torch.empty_like(y_orig)
             for batch_idx in range(n_batches):
                 start_idx = batch_idx * bs
-                end_idx = min( (batch_idx + 1) * bs, x_orig.shape[0])
+                end_idx = min((batch_idx + 1) * bs, x_orig.shape[0])

                 x = x_orig[start_idx:end_idx, :].clone().to(self.device)
                 y = y_orig[start_idx:end_idx].clone().to(self.device)
-                output = self.get_logits(x).max(dim=1)[1]
-                y_adv[start_idx: end_idx] = output
-                correct_batch = y.eq(output)
-                robust_flags[start_idx:end_idx] = correct_batch.detach().to(robust_flags.device)
+                for _ in range(self.eval_iter):
+                    output = self.get_logits(x).max(dim=1)[1]
+                    y_adv[start_idx:end_idx] = output
+                    correct_batch = y.eq(output)
+                    robust_flags[start_idx:end_idx] += correct_batch.detach().to(robust_flags.device)
+
+            robust_flags /= self.eval_iter
             state.robust_flags = robust_flags
             robust_accuracy = torch.sum(robust_flags).item() / x_orig.shape[0]
             robust_accuracy_dict = {'clean': robust_accuracy}
@@ -154,7 +169,8 @@ def run_standard_evaluation(self,
             startt = time.time()
             for attack in attacks_to_run:
                 # item() is super important as pytorch int division uses floor rounding
-                num_robust = torch.sum(robust_flags).item()
+                #num_robust = torch.sum(robust_flags).item()
+                num_robust = torch.sum(robust_flags != 0).item()

                 if num_robust == 0:
                     break
@@ -218,17 +234,31 @@ def run_standard_evaluation(self,
                     else:
                         raise ValueError('Attack not supported')

-                    output = self.get_logits(adv_curr).max(dim=1)[1]
-                    false_batch = ~y.eq(output).to(robust_flags.device)
-                    non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
-                    robust_flags[non_robust_lin_idcs] = False
-                    state.robust_flags = robust_flags
+                    # output = self.get_logits(adv_curr).max(dim=1)[1]
+                    # false_batch = ~y.eq(output).to(robust_flags.device)
+                    # non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
+                    # robust_flags[non_robust_lin_idcs] = False
+                    # state.robust_flags = robust_flags

-                    x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)
-                    y_adv[non_robust_lin_idcs] = output[false_batch].detach().to(x_adv.device)
+                    # x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)
+                    # y_adv[non_robust_lin_idcs] = output[false_batch].detach().to(x_adv.device)
+
+                    correct_batch = torch.zeros_like(y, dtype=torch.float, device=robust_flags.device)
+                    for _ in range(self.eval_iter):
+                        output = self.get_logits(adv_curr).max(dim=1)[1]
+                        correct_batch += y.eq(output).to(robust_flags.device)
+
+                    correct_batch = correct_batch / self.eval_iter
+
+                    smaller_indices = correct_batch < robust_flags[batch_datapoint_idcs]
+                    robust_flags[batch_datapoint_idcs[smaller_indices]] = correct_batch[smaller_indices]
+                    x_adv[batch_datapoint_idcs[smaller_indices]] = adv_curr[smaller_indices].detach().to(x_adv.device)
+                    y_adv[batch_datapoint_idcs[smaller_indices]] = output[smaller_indices].detach().to(x_adv.device)
+
                     if self.verbose:
-                        num_non_robust_batch = torch.sum(false_batch)
+                        #num_non_robust_batch = torch.sum(false_batch)
+                        num_non_robust_batch = torch.sum(1 - correct_batch)
                         self.logger.log('{} - {}/{} - {} out of {} successfully perturbed'.format(
                             attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))
diff --git a/autoattack/checks.py b/autoattack/checks.py
index 964a479..2cb9927 100644
--- a/autoattack/checks.py
+++ b/autoattack/checks.py
@@ -15,7 +15,7 @@ checks_doc_path = 'flags_doc.md'

-def check_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
+def is_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
    acc = []
    corrcl = []
    outputs = []
@@ -39,6 +39,8 @@ def check_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
            warnings.warn(Warning(msg))
        else:
            logger.log(f'Warning: {msg}')
+        return True
+    return False


 def check_range_output(model, x, alpha=1e-5, logger=None):
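
A minimal usage sketch of the new knob (not part of the patch). NoisyNet and the dummy tensors below are hypothetical stand-ins; only the AutoAttack arguments come from the diff above.

import torch
import torch.nn as nn
import torch.nn.functional as F
from autoattack import AutoAttack

class NoisyNet(nn.Module):
    """Hypothetical randomized defense: the noise stays active at eval time,
    so repeated forward passes on the same input can disagree."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.dropout(x, p=0.5, training=True)  # kept on even in eval mode
        return self.fc(x)

model = NoisyNet().eval()
x_test = torch.rand(16, 3, 32, 32)    # dummy inputs in [0, 1]
y_test = torch.randint(0, 10, (16,))  # dummy labels

# eval_iter=None lets the patched defense check pick the default (20 repeated
# forward passes for a randomized defense, 1 otherwise); an explicit value
# overrides that choice.
adversary = AutoAttack(model, norm='Linf', eps=8/255, version='rand',
                       eval_iter=20, device='cpu')
x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=8)

Note that with version='rand' the is_randomized check is skipped and is_randomized_defense stays True, so leaving eval_iter=None here would also resolve to 20.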
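
The semantic change the patch makes, distilled into a standalone sketch (an assumed reading of the diff, not code from the repo): robust_flags moves from a hard per-point bool to the fraction of eval_iter runs on which the point is still classified correctly, and each attack may only lower that fraction.

import torch

def fraction_correct(get_logits, x, y, eval_iter):
    # Fraction of eval_iter (possibly stochastic) forward passes that predict y.
    correct = torch.zeros(y.shape[0], dtype=torch.float, device=y.device)
    for _ in range(eval_iter):
        correct += get_logits(x).argmax(dim=1).eq(y).float()
    return correct / eval_iter

def worst_case_update(robust_flags, idcs, frac_correct):
    # Mirrors the smaller_indices logic above: an attack can only lower a
    # point's robust fraction, never raise it.
    robust_flags[idcs] = torch.minimum(robust_flags[idcs], frac_correct)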