From f9b9b57f2dadd2c0fcfb744b7478a896fc0cc7de Mon Sep 17 00:00:00 2001
From: Andy-Ko-0620
Date: Sat, 11 May 2024 22:19:20 +0800
Subject: [PATCH] add memory-efficient attention

---
 eval_knn.py           |  5 +++--
 main_dino.py          |  2 +-
 vision_transformer.py | 38 ++++++++++++++++++++++++++++++++++++--
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/eval_knn.py b/eval_knn.py
index fe99a2604..626340178 100644
--- a/eval_knn.py
+++ b/eval_knn.py
@@ -209,8 +209,9 @@ def __getitem__(self, idx):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
     parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
+    parser.add_argument('--num_classes', default=1000, type=int, help='Number of classes in dataset.')
     args = parser.parse_args()
 
     utils.init_distributed_mode(args)
@@ -237,6 +238,6 @@ def __getitem__(self, idx):
         print("Features are ready!\nStart the k-NN classification.")
         for k in args.nb_knn:
             top1, top5 = knn_classifier(train_features, train_labels,
-                test_features, test_labels, k, args.temperature)
+                test_features, test_labels, k, args.temperature, num_classes=args.num_classes)
             print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
     dist.barrier()
diff --git a/main_dino.py b/main_dino.py
index cade9873d..e895a18cc 100644
--- a/main_dino.py
+++ b/main_dino.py
@@ -125,7 +125,7 @@ def get_args_parser():
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
     return parser
 
 
diff --git a/vision_transformer.py b/vision_transformer.py
index f69a7ad05..1b1381870 100644
--- a/vision_transformer.py
+++ b/vision_transformer.py
@@ -20,9 +20,25 @@
 
 import torch
 import torch.nn as nn
-
+import logging
+import os
+import warnings
 from utils import trunc_normal_
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+    if XFORMERS_ENABLED:
+        from xformers.ops import memory_efficient_attention, unbind
+
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Attention)")
+    else:
+        warnings.warn("xFormers is disabled (Attention)")
+        raise ImportError
+except ImportError:
+    XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Attention)")
+
 
 def drop_path(x, drop_prob: float = 0., training: bool = False):
     if drop_prob == 0. or not training:
@@ -90,14 +106,32 @@ def forward(self, x):
         x = self.proj(x)
         x = self.proj_drop(x)
         return x, attn
+
+class MemEffAttention(Attention):
+    def forward(self, x, attn_bias=None):
+        if not XFORMERS_AVAILABLE:
+            if attn_bias is not None:
+                raise AssertionError("xFormers is required for using nested tensors")
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        q, k, v = unbind(qkv, 2)
+
+        attn = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = attn.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
 
 
 class Block(nn.Module):
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(
+        self.attn = MemEffAttention(  #Attention -> MemEffAttention
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
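-- 
Two caveats about the MemEffAttention path added above. First, the call
passes no scale, so memory_efficient_attention falls back to its default
scaling of head_dim ** -0.5; the base Attention class's self.scale
(qk_scale) is not forwarded, and a non-default qk_scale would be silently
ignored. Second, the kernel never materializes the N x N attention
matrix, so the second value returned by MemEffAttention.forward is the
attention output, not an attention map; callers that consume it as a map
(e.g. DINO's get_last_selfattention for visualization) only receive real
weights on the non-xFormers fallback path.

Below is a minimal sketch checking that the xFormers call matches naive
softmax attention under those defaults. It assumes xFormers is installed
and a CUDA device is available; the B/N/H/D constants are illustrative
(ViT-S/16 at 224x224: 196 patches + [CLS] token, 6 heads of dim 64).

    # Compare memory_efficient_attention against a naive reference.
    import torch
    from xformers.ops import memory_efficient_attention

    B, N, H, D = 2, 197, 6, 64
    q, k, v = (torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)
               for _ in range(3))

    # xFormers path, laid out (B, N, heads, head_dim) as in the patch;
    # the scale defaults to D ** -0.5 since none is passed.
    out_xf = memory_efficient_attention(q, k, v)

    # Naive reference in the (B, heads, N, head_dim) layout used by
    # the base Attention class.
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    weights = (qt @ kt.transpose(-2, -1) * D ** -0.5).softmax(dim=-1)
    out_ref = (weights @ vt).transpose(1, 2)

    # Loose tolerances: fp16 fused kernels round differently from the
    # fp16 matmul reference.
    torch.testing.assert_close(out_xf, out_ref, rtol=2e-2, atol=2e-2)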