Skip to content

Commit

Permalink
v3.2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
Harry24k committed Dec 9, 2021
1 parent 528ab80 commit 3d18299
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ Performance Comparison (CIFAR10).ipynb
utils.py
demos/models/*
demos/robustbench/
a.pt
15 changes: 15 additions & 0 deletions UPDATE_HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,19 @@







### v3.2.3

* `save`, `MultiAttack`: Now supports saving predictions.







2 changes: 1 addition & 1 deletion torchattacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
from .attacks.tifgsm import TIFGSM
from .attacks.jitter import Jitter

__version__ = '3.2.2'
__version__ = '3.2.3'
39 changes: 26 additions & 13 deletions torchattacks/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def set_training_mode(self, model_training=False, batchnorm_training=False, drop
self._batchnorm_training = batchnorm_training
self._dropout_training = dropout_training

def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
def save(self, data_loader, save_path=None, verbose=True, return_verbose=False, save_pred=False):
r"""
Save adversarial images as torch.tensor from given torch.utils.data.DataLoader.
Expand All @@ -155,6 +155,7 @@ def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
data_loader (torch.utils.data.DataLoader): data loader.
verbose (bool): True for displaying detailed information. (Default: True)
return_verbose (bool): True for returning detailed information. (Default: False)
save_pred (bool): True for saving predicted labels (Default: False)
"""
if (verbose==False) and (return_verbose==True):
Expand All @@ -163,6 +164,8 @@ def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
if save_path is not None:
image_list = []
label_list = []
if save_pred:
pre_list = []

correct = 0
total = 0
Expand All @@ -171,28 +174,23 @@ def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
total_batch = len(data_loader)

given_training = self.model.training
given_return_type = self._return_type
self._return_type = 'float'

for step, (images, labels) in enumerate(data_loader):
start = time.time()
adv_images = self.__call__(images, labels)

batch_size = len(images)

if save_path is not None:
image_list.append(adv_images.cpu())
label_list.append(labels.cpu())

if self._return_type == 'int':
adv_images = adv_images.float()/255

if verbose:
with torch.no_grad():
if given_training:
self.model.eval()
outputs = self.model(adv_images)
_, predicted = torch.max(outputs.data, 1)
_, pred = torch.max(outputs.data, 1)
total += labels.size(0)
right_idx = (predicted == labels.to(self.device))
right_idx = (pred == labels.to(self.device))
correct += right_idx.sum()
end = time.time()
delta = (adv_images - images.to(self.device)).view(batch_size, -1)
Expand All @@ -204,14 +202,29 @@ def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
elapsed_time = end-start
self._save_print(progress, rob_acc, l2, elapsed_time, end='\r')

if save_path is not None:
if given_return_type == 'int':
adv_images = self._to_uint(adv_images.detach().cpu())
image_list.append(adv_images)
else:
image_list.append(adv_images.detach().cpu())

label_list.append(labels.detach().cpu())
if save_pred:
pre_list.append(pred.detach().cpu())

# To avoid erasing the printed information.
if verbose:
self._save_print(progress, rob_acc, l2, elapsed_time, end='\n')

if save_path is not None:
x = torch.cat(image_list, 0)
y = torch.cat(label_list, 0)
torch.save((x, y), save_path)
image_list = torch.cat(image_list, 0)
label_list = torch.cat(label_list, 0)
if save_pred:
pre_list = torch.cat(pre_list, 0)
torch.save((image_list, label_list, pre_list), save_path)
else:
torch.save((image_list, label_list), save_path)
print('- Save complete!')

if given_training:
Expand Down
9 changes: 6 additions & 3 deletions torchattacks/attacks/multiattack.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _update_multi_atk_records(self, multi_atk_records):
for i, item in enumerate(multi_atk_records):
self._multi_atk_records[i] += item

def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
def save(self, data_loader, save_path=None, verbose=True, return_verbose=False, save_pred=False):
r"""
Overridden.
"""
Expand All @@ -106,10 +106,13 @@ def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
self._multi_atk_records.append(0.0)

if verbose:
rob_acc, l2, elapsed_time = super().save(data_loader, save_path, verbose=True, return_verbose=True)
rob_acc, l2, elapsed_time = super().save(data_loader, save_path,
verbose=True, return_verbose=True,
save_pred=save_pred)
sr = self._covert_to_success_rates(self._multi_atk_records)
else:
super().save(data_loader, save_path, verbose=False, return_verbose=False)
super().save(data_loader, save_path, verbose=False,
return_verbose=False, save_pred=save_pred)

self._clear_multi_atk_records()
self._accumulate_multi_atk_records = False
Expand Down

0 comments on commit 3d18299

Please sign in to comment.