From 6e3838daad7d4bd18541f2c27c754787eb34af4c Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Mon, 25 Nov 2024 13:28:34 +0800 Subject: [PATCH] AttackModel support `timm/` and `tv/` prefix --- README.md | 6 ++++++ torchattack/attack_model.py | 25 ++++++++++++++++++++----- torchattack/eval/runner.py | 15 ++++----------- torchattack/pna_patchout.py | 13 ++----------- torchattack/tgr.py | 12 +++--------- torchattack/vdc.py | 5 ++--- 6 files changed, 37 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 4333242..3cd4ca1 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,12 @@ from torchattack import AttackModel model = AttackModel.from_pretrained(model_name='resnet50', device=device) # `AttackModel` automatically attach the model's `transform` and `normalize` functions transform, normalize = model.transform, model.normalize + +# Additionally, to explicitly specify where to load the pretrained model from (timm or torchvision), +# prepend the model name with 'timm/' or 'tv/' respectively, or use the `from_timm` argument, e.g. +vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224', device=device) +inception_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device) +pit_b = AttackModel.from_pretrained(model_name='pit_b_224', device=device, from_timm=True) ``` Initialize an attack by importing its attack class. diff --git a/torchattack/attack_model.py b/torchattack/attack_model.py index 08f7887..9cd3806 100644 --- a/torchattack/attack_model.py +++ b/torchattack/attack_model.py @@ -63,9 +63,12 @@ def from_pretrained( Loads a pretrained model and initializes an AttackModel instance. Args: - model_name: The name of the model to load. + model_name: The name of the model to load. Accept specifying the model from + `timm` or `torchvision.models` by prefixing the model name with `timm/` + or `tv/`. Takes precedence over the `from_timm` flag. device: The device on which to load the model. - from_timm: Whether to load the model from timm. Defaults to False. + from_timm: Explicitly specifying to load the model from timm or torchvision. + Priority lower than argument `model_name`. Defaults to False. Returns: AttackModel: An instance of AttackModel initialized with pretrained model. @@ -73,6 +76,14 @@ def from_pretrained( import torchvision.transforms as t + # Accept `timm/` or `tv/` as model_name, + # which takes precedence over the `from_timm` flag. + if model_name.startswith('timm/'): + model_name, from_timm = model_name[5:], True + elif model_name.startswith('tv/'): + model_name, from_timm = model_name[3:], False + + # Load the model from timm if specified if from_timm: import timm @@ -122,9 +133,13 @@ def from_pretrained( return cls(model_name, device, model, transform, normalize) except ValueError: - print( - f'Warning: Model `{model_name}` not found in torchvision.models, ' - 'falling back to loading weights from timm.' + from warnings import warn + + warn( + f'model `{model_name}` not found in torchvision.models, ' + 'falling back to loading weights from timm.', + category=UserWarning, + stacklevel=2, ) return cls.from_pretrained(model_name, device, from_timm=True) diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index 3850bf0..dbce171 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -9,7 +9,6 @@ def run_attack( dataset_root: str = 'datasets/nips2017', max_samples: int = 100, batch_size: int = 4, - from_timm: bool = False, ) -> None: """Helper function to run attacks in `__main__`. @@ -26,7 +25,6 @@ def run_attack( dataset_root: Root directory of the dataset. Defaults to "datasets/nips2017". max_samples: Max number of samples to attack. Defaults to 100. batch_size: Batch size for the dataloader. Defaults to 16. - from_timm: Use timm to load the model. Defaults to True. """ import torch @@ -42,7 +40,7 @@ def run_attack( # Setup model device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = AttackModel.from_pretrained(model_name, device, from_timm) + model = AttackModel.from_pretrained(model_name, device) transform, normalize = model.transform, model.normalize # Set up dataloader @@ -70,10 +68,7 @@ def run_attack( # Setup victim models if provided if victim_model_names: - victim_models = [ - AttackModel.from_pretrained(name, device, from_timm) - for name in victim_model_names - ] + victims = [AttackModel.from_pretrained(n, device) for n in victim_model_names] victim_frms = [FoolingRateMetric() for _ in victim_model_names] # Run attack over the dataset (100 images by default) @@ -98,7 +93,7 @@ def run_attack( # Track transfer fooling rates if victim models are provided if victim_model_names: - for _, (vmodel, vfrm) in enumerate(zip(victim_models, victim_frms)): + for _, (vmodel, vfrm) in enumerate(zip(victims, victim_frms)): v_cln_outs = vmodel(vmodel.normalize(x)) v_adv_outs = vmodel(vmodel.normalize(advs)) vfrm.update(y, v_cln_outs, v_adv_outs) @@ -108,7 +103,7 @@ def run_attack( print(f'Surrogate ({model_name}): {cln_acc=:.2%}, {adv_acc=:.2%} ({fr=:.2%})') if victim_model_names: - for vmodel, vfrm in zip(victim_models, victim_frms): + for vmodel, vfrm in zip(victims, victim_frms): vcln_acc, vadv_acc, vfr = vfrm.compute() print( f'Victim ({vmodel.model_name}): cln_acc={vcln_acc:.2%}, ' @@ -127,7 +122,6 @@ def run_attack( parser.add_argument('--dataset-root', type=str, default='datasets/nips2017') parser.add_argument('--max-samples', type=int, default=100) parser.add_argument('--batch-size', type=int, default=4) - parser.add_argument('--from-timm', action='store_true') args = parser.parse_args() run_attack( @@ -138,5 +132,4 @@ def run_attack( dataset_root=args.dataset_root, max_samples=args.max_samples, batch_size=args.batch_size, - from_timm=args.from_timm, ) diff --git a/torchattack/pna_patchout.py b/torchattack/pna_patchout.py index 56a8fd9..1da7b19 100644 --- a/torchattack/pna_patchout.py +++ b/torchattack/pna_patchout.py @@ -79,7 +79,6 @@ def __init__( # Register hooks if self.pna_skip: - self.hooks: list[torch.utils.hooks.RemovableHandle] = [] self._register_vit_model_hook() # Set default image size and number of patches for PatchOut @@ -137,9 +136,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: delta.grad.detach_() delta.grad.zero_() - for hook in self.hooks: - hook.remove() - return x + delta def _register_vit_model_hook(self): @@ -169,8 +165,7 @@ def attn_drop_mask_grad( # Register backward hook for layers specified in supported_vit_cfg for layer in supported_vit_cfg[self.model_name]: module = rgetattr(self.model, layer) - hook = module.register_backward_hook(drop_hook_func) - self.hooks.append(hook) + module.register_backward_hook(drop_hook_func) def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor: delta_mask = torch.zeros_like(delta) @@ -196,8 +191,4 @@ def _apply_patch_out(self, delta: torch.Tensor, seed: int) -> torch.Tensor: if __name__ == '__main__': from torchattack.eval import run_attack - run_attack( - PNAPatchOut, - model_name='vit_base_patch16_224', - from_timm=True, - ) + run_attack(PNAPatchOut, model_name='timm/vit_base_patch16_224') diff --git a/torchattack/tgr.py b/torchattack/tgr.py index 3d78521..09f1cc4 100644 --- a/torchattack/tgr.py +++ b/torchattack/tgr.py @@ -74,7 +74,6 @@ def __init__( self.lossfn = nn.CrossEntropyLoss() # Register hooks - self.hooks: list[torch.utils.hooks.RemovableHandle] = [] self._register_tgr_model_hooks() def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -124,9 +123,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: delta.grad.detach_() delta.grad.zero_() - for hook in self.hooks: - hook.remove() - return x + delta def _register_tgr_model_hooks(self): @@ -346,8 +342,7 @@ def mlp_tgr( for hook_func, layers in supported_vit_cfg[self.model_name]: for layer in layers: module = rgetattr(self.model, layer) - hook = module.register_backward_hook(hook_func) - self.hooks.append(hook) + module.register_backward_hook(hook_func) if __name__ == '__main__': @@ -355,7 +350,6 @@ def mlp_tgr( run_attack( TGR, - model_name='vit_base_patch16_224', - victim_model_names=['cait_s24_224', 'visformer_small'], - from_timm=True, + model_name='timm/vit_base_patch16_224', + victim_model_names=['timm/cait_s24_224', 'timm/visformer_small'], ) diff --git a/torchattack/vdc.py b/torchattack/vdc.py index 4c0ebf7..7fcabdf 100644 --- a/torchattack/vdc.py +++ b/torchattack/vdc.py @@ -709,7 +709,6 @@ def attn_add_vis( run_attack( VDC, - model_name='pit_b_224', - victim_model_names=['cait_s24_224', 'visformer_small'], - from_timm=True, + model_name='timm/pit_b_224', + victim_model_names=['timm/cait_s24_224', 'timm/visformer_small'], )