Skip to content

Commit

Permalink
AttackModel support timm/ and tv/ prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Nov 25, 2024
1 parent a71df81 commit 6e3838d
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 39 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ from torchattack import AttackModel
model = AttackModel.from_pretrained(model_name='resnet50', device=device)
# `AttackModel` automatically attach the model's `transform` and `normalize` functions
transform, normalize = model.transform, model.normalize

# Additionally, to explicitly specify where to load the pretrained model from (timm or torchvision),
# prepend the model name with 'timm/' or 'tv/' respectively, or use the `from_timm` argument, e.g.
vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224', device=device)
inception_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device)
pit_b = AttackModel.from_pretrained(model_name='pit_b_224', device=device, from_timm=True)
```

Initialize an attack by importing its attack class.
Expand Down
25 changes: 20 additions & 5 deletions torchattack/attack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,27 @@ def from_pretrained(
Loads a pretrained model and initializes an AttackModel instance.
Args:
model_name: The name of the model to load.
model_name: The name of the model to load. Accept specifying the model from
`timm` or `torchvision.models` by prefixing the model name with `timm/`
or `tv/`. Takes precedence over the `from_timm` flag.
device: The device on which to load the model.
from_timm: Whether to load the model from timm. Defaults to False.
from_timm: Explicitly specifying to load the model from timm or torchvision.
Priority lower than argument `model_name`. Defaults to False.
Returns:
AttackModel: An instance of AttackModel initialized with pretrained model.
"""

import torchvision.transforms as t

# Accept `timm/<model_name>` or `tv/<model_name>` as model_name,
# which takes precedence over the `from_timm` flag.
if model_name.startswith('timm/'):
model_name, from_timm = model_name[5:], True
elif model_name.startswith('tv/'):
model_name, from_timm = model_name[3:], False

# Load the model from timm if specified
if from_timm:
import timm

Expand Down Expand Up @@ -122,9 +133,13 @@ def from_pretrained(
return cls(model_name, device, model, transform, normalize)

except ValueError:
print(
f'Warning: Model `{model_name}` not found in torchvision.models, '
'falling back to loading weights from timm.'
from warnings import warn

warn(
f'model `{model_name}` not found in torchvision.models, '
'falling back to loading weights from timm.',
category=UserWarning,
stacklevel=2,
)
return cls.from_pretrained(model_name, device, from_timm=True)

Expand Down
15 changes: 4 additions & 11 deletions torchattack/eval/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def run_attack(
dataset_root: str = 'datasets/nips2017',
max_samples: int = 100,
batch_size: int = 4,
from_timm: bool = False,
) -> None:
"""Helper function to run attacks in `__main__`.
Expand All @@ -26,7 +25,6 @@ def run_attack(
dataset_root: Root directory of the dataset. Defaults to "datasets/nips2017".
max_samples: Max number of samples to attack. Defaults to 100.
batch_size: Batch size for the dataloader. Defaults to 16.
from_timm: Use timm to load the model. Defaults to True.
"""

import torch
Expand All @@ -42,7 +40,7 @@ def run_attack(

# Setup model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttackModel.from_pretrained(model_name, device, from_timm)
model = AttackModel.from_pretrained(model_name, device)
transform, normalize = model.transform, model.normalize

# Set up dataloader
Expand Down Expand Up @@ -70,10 +68,7 @@ def run_attack(

# Setup victim models if provided
if victim_model_names:
victim_models = [
AttackModel.from_pretrained(name, device, from_timm)
for name in victim_model_names
]
victims = [AttackModel.from_pretrained(n, device) for n in victim_model_names]
victim_frms = [FoolingRateMetric() for _ in victim_model_names]

# Run attack over the dataset (100 images by default)
Expand All @@ -98,7 +93,7 @@ def run_attack(

# Track transfer fooling rates if victim models are provided
if victim_model_names:
for _, (vmodel, vfrm) in enumerate(zip(victim_models, victim_frms)):
for _, (vmodel, vfrm) in enumerate(zip(victims, victim_frms)):
v_cln_outs = vmodel(vmodel.normalize(x))
v_adv_outs = vmodel(vmodel.normalize(advs))
vfrm.update(y, v_cln_outs, v_adv_outs)
Expand All @@ -108,7 +103,7 @@ def run_attack(
print(f'Surrogate ({model_name}): {cln_acc=:.2%}, {adv_acc=:.2%} ({fr=:.2%})')

if victim_model_names:
for vmodel, vfrm in zip(victim_models, victim_frms):
for vmodel, vfrm in zip(victims, victim_frms):
vcln_acc, vadv_acc, vfr = vfrm.compute()
print(
f'Victim ({vmodel.model_name}): cln_acc={vcln_acc:.2%}, '
Expand All @@ -127,7 +122,6 @@ def run_attack(
parser.add_argument('--dataset-root', type=str, default='datasets/nips2017')
parser.add_argument('--max-samples', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--from-timm', action='store_true')
args = parser.parse_args()

run_attack(
Expand All @@ -138,5 +132,4 @@ def run_attack(
dataset_root=args.dataset_root,
max_samples=args.max_samples,
batch_size=args.batch_size,
from_timm=args.from_timm,
)
13 changes: 2 additions & 11 deletions torchattack/pna_patchout.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(

# Register hooks
if self.pna_skip:
self.hooks: list[torch.utils.hooks.RemovableHandle] = []
self._register_vit_model_hook()

# Set default image size and number of patches for PatchOut
Expand Down Expand Up @@ -137,9 +136,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
delta.grad.detach_()
delta.grad.zero_()

for hook in self.hooks:
hook.remove()

return x + delta

def _register_vit_model_hook(self):
Expand Down Expand Up @@ -169,8 +165,7 @@ def attn_drop_mask_grad(
# Register backward hook for layers specified in supported_vit_cfg
for layer in supported_vit_cfg[self.model_name]:
module = rgetattr(self.model, layer)
hook = module.register_backward_hook(drop_hook_func)
self.hooks.append(hook)
module.register_backward_hook(drop_hook_func)

def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
delta_mask = torch.zeros_like(delta)
Expand All @@ -196,8 +191,4 @@ def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor:
if __name__ == '__main__':
from torchattack.eval import run_attack

run_attack(
PNAPatchOut,
model_name='vit_base_patch16_224',
from_timm=True,
)
run_attack(PNAPatchOut, model_name='timm/vit_base_patch16_224')
12 changes: 3 additions & 9 deletions torchattack/tgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(
self.lossfn = nn.CrossEntropyLoss()

# Register hooks
self.hooks: list[torch.utils.hooks.RemovableHandle] = []
self._register_tgr_model_hooks()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -124,9 +123,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
delta.grad.detach_()
delta.grad.zero_()

for hook in self.hooks:
hook.remove()

return x + delta

def _register_tgr_model_hooks(self):
Expand Down Expand Up @@ -346,16 +342,14 @@ def mlp_tgr(
for hook_func, layers in supported_vit_cfg[self.model_name]:
for layer in layers:
module = rgetattr(self.model, layer)
hook = module.register_backward_hook(hook_func)
self.hooks.append(hook)
module.register_backward_hook(hook_func)


if __name__ == '__main__':
from torchattack.eval import run_attack

run_attack(
TGR,
model_name='vit_base_patch16_224',
victim_model_names=['cait_s24_224', 'visformer_small'],
from_timm=True,
model_name='timm/vit_base_patch16_224',
victim_model_names=['timm/cait_s24_224', 'timm/visformer_small'],
)
5 changes: 2 additions & 3 deletions torchattack/vdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,6 @@ def attn_add_vis(

run_attack(
VDC,
model_name='pit_b_224',
victim_model_names=['cait_s24_224', 'visformer_small'],
from_timm=True,
model_name='timm/pit_b_224',
victim_model_names=['timm/cait_s24_224', 'timm/visformer_small'],
)

0 comments on commit 6e3838d

Please sign in to comment.