From 00e1bccc88390b25dbe6a363c4b6e23567348e14 Mon Sep 17 00:00:00 2001 From: Mathilde Caron Date: Sun, 2 May 2021 22:10:03 +0200 Subject: [PATCH 1/9] logs for example runs --- README.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4d3ea3bfb..3a34d1660 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ python main_dino.py --help ``` ### Vanilla DINO training :sauropod: -Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide [training](/to/do) and [linear evaluation](/to/do) logs for this run to help reproducibility. +Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and ~73.8% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility. ``` python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir ``` @@ -133,14 +133,27 @@ python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 - -The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide [training](/to/do) and [linear evaluation](/to/do) logs for this run to help reproducibility. +The resulting pretrained model should reach 73.3% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility. ### ResNet-50 and other convnets trainings -This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs: +This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) logs for this run. ``` python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir ``` +## Self-attention visualization +You can look at the self-attention of the [CLS] token on the different heads of the last layer by running: +``` +python visualize_attention.py +``` + +Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for video inference. + +
+ Self-attention from a Vision Transformer with 8x8 patches trained with DINO +
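For readers who want to compute these [CLS] attention maps in their own code rather than through `visualize_attention.py`, here is a minimal sketch. It follows the same calls used by the `video_generation.py` script added later in this series (`vits.__dict__[...]`, `forward_selfattention`, keeping the [CLS]-to-patch attention); the checkpoint URL is the `dino_deitsmall8_300ep` file that script falls back to, and the image path and resize value are placeholders.

```python
# Minimal sketch, assuming this revision's vision_transformer.py provides
# forward_selfattention as used in video_generation.py; "img.png" is a placeholder.
import torch
from PIL import Image
from torchvision import transforms as pth_transforms

import vision_transformer as vits

patch_size = 8
model = vits.__dict__["deit_small"](patch_size=patch_size, num_classes=0).eval()
state_dict = torch.hub.load_state_dict_from_url(
    "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
)
model.load_state_dict(state_dict, strict=True)

transform = pth_transforms.Compose([
    pth_transforms.Resize(256),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(Image.open("img.png").convert("RGB"))

# crop so height and width are multiples of the patch size
w, h = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
img = img[:, :w, :h].unsqueeze(0)

with torch.no_grad():
    attn = model.forward_selfattention(img)  # (1, num_heads, num_tokens, num_tokens)

nh = attn.shape[1]
# attention of the [CLS] token (query 0) over the patch tokens, one map per head
cls_attn = attn[0, :, 0, 1:].reshape(nh, w // patch_size, h // patch_size)
print(cls_attn.shape)
```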
+ + ## Evaluation: k-NN classification on ImageNet To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run: ``` @@ -157,15 +170,6 @@ To train a supervised linear classifier on frozen weights on a single node with python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet ``` -## Self-attention visualization -You can look at the self-attention of the [CLS] token on the different heads of the last layer by running: -``` -python visualize_attention.py -``` -
- Self-attention from a Vision Transformer with 8x8 patches trained with DINO -
- ## License See the [LICENSE](LICENSE) file for more details. From 8aa93fdc90eae4b183c4e3c005174a9f634ecfbf Mon Sep 17 00:00:00 2001 From: Mathilde Caron Date: Sun, 2 May 2021 22:13:13 +0200 Subject: [PATCH 2/9] logs for example runs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3a34d1660..97ce7197e 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 - -The resulting pretrained model should reach 73.3% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility. +The resulting pretrained model should reach 73.3% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility. ### ResNet-50 and other convnets trainings This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) logs for this run. From 21169953a60d3611241a1c74aaf6f8e477f305ee Mon Sep 17 00:00:00 2001 From: aquadzn Date: Sun, 2 May 2021 23:17:09 +0200 Subject: [PATCH 3/9] Added a video generation script and instructions to README --- README.md | 28 +++- video_generation.py | 303 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 video_generation.py diff --git a/README.md b/README.md index 97ce7197e..bed08eeda 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,33 @@ You can look at the self-attention of the [CLS] token on the different heads of python visualize_attention.py ``` -Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for video inference. +## Self-attention video generation +You can generate videos like the one on the blog post with `video_generation.py`. + +Extract frames from input video and generate attention video: +``` +python video_generation.py --input_path ../video.mp4 \ + --output_dir ../output/ \ + --resize 256 \ +``` + +Use folder of frames already extracted and attention video: +``` +python video_generation.py --input_path ../frames/ \ + --output_dir ../output/ \ + --resize 720 1280 \ + --video_format avi +``` + +Only generate video from folder of attention maps images: +``` +python video_generation.py --output_dir ../output/ \ + --resize 256 \ + --fps 60 \ + --video_only +``` + +Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook.
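A note on the `--threshold` argument exposed by the script added below: each head's attention map is masked by keeping the smallest set of patches that together hold that fraction (default 0.6) of the head's attention mass. A toy illustration of that step, not part of the repository; the tensor shapes are made up for the example:

```python
# Toy sketch of the per-head mass thresholding used in video_generation.py below;
# shapes are arbitrary (6 heads, a 30x30 grid of patches).
import torch

threshold = 0.6
attn = torch.rand(6, 900)                 # (num_heads, num_patches)
val, idx = torch.sort(attn)               # ascending sort per head
val = val / val.sum(dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - threshold)        # top patches holding `threshold` of the mass (sorted order)
idx2 = torch.argsort(idx)                 # inverse permutation of the sort
for head in range(th_attn.shape[0]):
    th_attn[head] = th_attn[head][idx2[head]]   # map the mask back to patch order
print(th_attn.float().mean(dim=1))        # fraction of patches kept, per head
```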
Self-attention from a Vision Transformer with 8x8 patches trained with DINO diff --git a/video_generation.py b/video_generation.py new file mode 100644 index 000000000..4160820d5 --- /dev/null +++ b/video_generation.py @@ -0,0 +1,303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import os +import glob +import sys +import argparse +import cv2 + +from tqdm import tqdm +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torchvision +from torchvision import transforms as pth_transforms +import numpy as np +from PIL import Image + +import utils +import vision_transformer as vits + + +def extract_frames_from_video(): + print("Extracting frames from", args.input_path) + vidcap = cv2.VideoCapture(args.input_path) + success, image = vidcap.read() + count = 0 + while success: + cv2.imwrite(os.path.join(args.output_dir, f"frame-{count:04}.jpg"), image) + success, image = vidcap.read() + count += 1 + + +def generate_video_from_images(format="mp4"): + print("Generating video...") + img_array = [] + # Change format to png if needed + for filename in tqdm(sorted(glob.glob(os.path.join(args.output_dir, "*.jpg")))): + with open(filename, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") + size = (img.width, img.height) + + img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) + + if args.video_format == "avi": + out = cv2.VideoWriter( + "video.avi", cv2.VideoWriter_fourcc(*"XVID"), args.fps, size + ) + else: + out = cv2.VideoWriter( + "video.mp4", cv2.VideoWriter_fourcc(*"MP4V"), args.fps, size + ) + + for i in range(len(img_array)): + out.write(img_array[i]) + out.release() + print("Done") + + +def inference(images_folder_list: str): + for img_path in tqdm(images_folder_list): + with open(img_path, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") + + if args.resize is not None: + transform = pth_transforms.Compose( + [ + pth_transforms.ToTensor(), + pth_transforms.Resize(args.resize), + pth_transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), + ] + ) + else: + transform = pth_transforms.Compose( + [ + pth_transforms.ToTensor(), + pth_transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), + ] + ) + + img = transform(img) + + # make the image divisible by the patch size + w, h = ( + img.shape[1] - img.shape[1] % args.patch_size, + img.shape[2] - img.shape[2] % args.patch_size, + ) + img = img[:, :w, :h].unsqueeze(0) + + w_featmap = img.shape[-2] // args.patch_size + h_featmap = img.shape[-1] // args.patch_size + + attentions = model.forward_selfattention(img.to(device)) + + nh = attentions.shape[1] # number of head + + # we keep only the output patch attention + attentions = attentions[0, :, 0, 1:].reshape(nh, -1) + + # we keep only a certain percentage of the mass + val, idx = torch.sort(attentions) + val /= torch.sum(val, dim=1, keepdim=True) + cumval = torch.cumsum(val, dim=1) + th_attn = cumval > (1 - args.threshold) + idx2 = torch.argsort(idx) + for head in range(nh): + th_attn[head] = th_attn[head][idx2[head]] + th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() + # interpolate + th_attn = ( + nn.functional.interpolate( + th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest" + )[0] + .cpu() + .numpy() + ) + + attentions = attentions.reshape(nh, w_featmap, h_featmap) + attentions = ( + nn.functional.interpolate( + attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest" + )[0] + .cpu() + .numpy() + ) + + # save attentions heatmaps + 
os.makedirs(args.output_dir, exist_ok=True) + fname = os.path.join(args.output_dir, "attn-" + os.path.basename(img_path)) + plt.imsave( + fname=fname, + arr=sum( + attentions[i] * 1 / attentions.shape[0] + for i in range(attentions.shape[0]) + ), + cmap="inferno", + format="jpg", + ) + + generate_video_from_images(args.video_format) + + +def load_model(): + # build model + model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) + for p in model.parameters(): + p.requires_grad = False + model.eval() + model.to(device) + if os.path.isfile(args.pretrained_weights): + state_dict = torch.load(args.pretrained_weights, map_location="cpu") + if args.checkpoint_key is not None and args.checkpoint_key in state_dict: + print(f"Take key {args.checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[args.checkpoint_key] + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + print( + "Pretrained weights found at {} and loaded with msg: {}".format( + args.pretrained_weights, msg + ) + ) + else: + print( + "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." + ) + url = None + if args.arch == "deit_small" and args.patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + elif args.arch == "deit_small" and args.patch_size == 8: + url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper + elif args.arch == "vit_base" and args.patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" + elif args.arch == "vit_base" and args.patch_size == 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" + if url is not None: + print( + "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." + ) + state_dict = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/dino/" + url + ) + model.load_state_dict(state_dict, strict=True) + else: + print( + "There is no reference weights available for this model => We use random weights." + ) + return model + + +def parse_args(): + parser = argparse.ArgumentParser("Visualize Self-Attention maps") + parser.add_argument( + "--arch", + default="deit_small", + type=str, + choices=["deit_tiny", "deit_small", "vit_base"], + help="Architecture (support only ViT atm).", + ) + parser.add_argument( + "--patch_size", default=8, type=int, help="Patch resolution of the model." + ) + parser.add_argument( + "--pretrained_weights", + default="", + type=str, + help="Path to pretrained weights to load.", + ) + parser.add_argument( + "--checkpoint_key", + default="teacher", + type=str, + help='Key to use in the checkpoint (example: "teacher")', + ) + parser.add_argument( + "--input_path", + default=None, + type=str, + help="""Path to a video file if you want to extract frames + or to a folder of images already extracted by yourself.""", + ) + parser.add_argument( + "--output_dir", + required=True, + type=str, + help="Path where to save visualizations and / or video.", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.6, + help="""We visualize masks + obtained by thresholding the self-attention maps to keep xx percent of the mass.""", + ) + parser.add_argument( + "--resize", + default=None, + type=int, + nargs="+", + help="""Apply a resize transformation to input image(s). Use if OOM error. 
+ Usage (single or W H): --resize 512, --resize 720 1280""", + ) + parser.add_argument( + "--fps", + default=30.0, + type=float, + help="FPS of input / output video. Default: 30", + ) + parser.add_argument( + "--video_only", + action="store_true", + help="""Use this flag if you only want to generate a video and not all attention images. + If used, --output_dir must be set to the folder containing attention images.""", + ) + parser.add_argument( + "--video_format", + default="mp4", + type=str, + choices=["mp4", "avi"], + help="Format of generated video (mp4 or avi).", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = load_model() + + # If you only want a video + if args.video_only: + generate_video_from_images(args.video_format) + else: + # If input path isn't set + if args.input_path is None: + print(f"Provided input path {args.input_path} is non valid.") + sys.exit(1) + else: + # If input path exists + if os.path.exists(args.input_path): + # If input is a video file + if os.path.isfile(args.input_path): + extract_frames_from_video() + imgs_list = [ + os.path.join(args.output_dir, i) + for i in sorted(os.listdir(args.output_dir)) + ] + inference(imgs_list) + # If input is an images folder + if os.path.isdir(args.input_path): + imgs_list = [ + os.path.join(args.input_path, i) + for i in sorted(os.listdir(args.input_path)) + ] + inference(imgs_list) + # If input path doesn't exists + else: + print(f"Provided video file path {args.input_path} is non valid.") + sys.exit(1) From 4bb64c1ad2d57568177035513fc3f84f5262ccc4 Mon Sep 17 00:00:00 2001 From: William <46140458+aquadzn@users.noreply.github.com> Date: Sun, 2 May 2021 23:19:33 +0200 Subject: [PATCH 4/9] Update README.md Free of use video --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index bed08eeda..d300fae3d 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,8 @@ python visualize_attention.py ## Self-attention video generation You can generate videos like the one on the blog post with `video_generation.py`. 
+https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4 + Extract frames from input video and generate attention video: ``` python video_generation.py --input_path ../video.mp4 \ From a9b19584505ec22e24e84c607a26f83ba96ac8f9 Mon Sep 17 00:00:00 2001 From: aquadzn Date: Mon, 3 May 2021 11:11:24 +0200 Subject: [PATCH 5/9] Fixed generate_video_from_images() --- video_generation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/video_generation.py b/video_generation.py index 4160820d5..0d4584fb7 100644 --- a/video_generation.py +++ b/video_generation.py @@ -19,8 +19,10 @@ def extract_frames_from_video(): - print("Extracting frames from", args.input_path) vidcap = cv2.VideoCapture(args.input_path) + args.fps = vidcap.get(cv2.CAP_PROP_FPS) + print(f"Video: {args.input_path} ({args.fps} fps)") + print("Extracting frames...") success, image = vidcap.read() count = 0 while success: @@ -33,7 +35,7 @@ def generate_video_from_images(format="mp4"): print("Generating video...") img_array = [] # Change format to png if needed - for filename in tqdm(sorted(glob.glob(os.path.join(args.output_dir, "*.jpg")))): + for filename in tqdm(sorted(glob.glob(os.path.join(args.output_dir, "attn-*.jpg")))): with open(filename, "rb") as f: img = Image.open(f) img = img.convert("RGB") @@ -247,7 +249,7 @@ def parse_args(): "--fps", default=30.0, type=float, - help="FPS of input / output video. Default: 30", + help="FPS of input / output video. Automatically set if you extract frames from a video.", ) parser.add_argument( "--video_only", From 197efd49bb8a285d57419c9c27b0d36bcf896e90 Mon Sep 17 00:00:00 2001 From: aquadzn Date: Mon, 3 May 2021 15:29:48 +0200 Subject: [PATCH 6/9] Refactoring. Much cleaner --- README.md | 25 +-- video_generation.py | 443 +++++++++++++++++++++++++------------------- 2 files changed, 264 insertions(+), 204 deletions(-) diff --git a/README.md b/README.md index d300fae3d..28330efaf 100644 --- a/README.md +++ b/README.md @@ -154,25 +154,26 @@ https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb- Extract frames from input video and generate attention video: ``` -python video_generation.py --input_path ../video.mp4 \ - --output_dir ../output/ \ - --resize 256 \ +python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ + --input_path input/video.mp4 \ + --output_path output/ \ + --fps 25 ``` -Use folder of frames already extracted and attention video: +Use folder of frames already extracted and generate attention video: ``` -python video_generation.py --input_path ../frames/ \ - --output_dir ../output/ \ - --resize 720 1280 \ - --video_format avi +python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ + --input_path output/frames/ \ + --output_path output/ \ + --resize 256 \ ``` Only generate video from folder of attention maps images: ``` -python video_generation.py --output_dir ../output/ \ - --resize 256 \ - --fps 60 \ - --video_only +python video_generation.py --input_path output/attention \ + --output_path output/ \ + --video_only \ + --video_format avi ``` Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook. 
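The refactor below wraps the whole pipeline in a `VideoGenerator` class. If you would rather drive it from Python than from the command line, a minimal sketch follows; the field names mirror the script's `parse_args()`, the paths and checkpoint file are placeholders, and programmatic use like this is not an officially supported entry point.

```python
# Hypothetical programmatic use of the VideoGenerator class introduced below.
# Field names mirror parse_args(); paths are placeholders.
from types import SimpleNamespace

from video_generation import VideoGenerator

args = SimpleNamespace(
    arch="deit_small", patch_size=8,
    pretrained_weights="dino_deitsmall8_pretrain.pth", checkpoint_key="teacher",
    input_path="input/video.mp4", output_path="output/",
    threshold=0.6, resize=None,
    video_only=False, fps=30.0, video_format="mp4",
)
VideoGenerator(args).run()
```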
diff --git a/video_generation.py b/video_generation.py index 0d4584fb7..8c0723d3c 100644 --- a/video_generation.py +++ b/video_generation.py @@ -18,183 +18,270 @@ import vision_transformer as vits -def extract_frames_from_video(): - vidcap = cv2.VideoCapture(args.input_path) - args.fps = vidcap.get(cv2.CAP_PROP_FPS) - print(f"Video: {args.input_path} ({args.fps} fps)") - print("Extracting frames...") - success, image = vidcap.read() - count = 0 - while success: - cv2.imwrite(os.path.join(args.output_dir, f"frame-{count:04}.jpg"), image) +FOURCC = { + "mp4": cv2.VideoWriter_fourcc(*"MP4V"), + "avi": cv2.VideoWriter_fourcc(*"XVID"), +} +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +class VideoGenerator: + def __init__(self, args): + self.args = args + # self.model = None + # Don't need to load model if you only want a video + if not self.args.video_only: + self.model = self.__load_model() + + def run(self): + if self.args.input_path is None: + print(f"Provided input path {self.args.input_path} is non valid.") + sys.exit(1) + else: + if self.args.video_only: + self._generate_video_from_images( + self.args.input_path, self.args.output_path + ) + else: + # If input path exists + if os.path.exists(self.args.input_path): + # If input is a video file + if os.path.isfile(self.args.input_path): + frames_folder = os.path.join(self.args.output_path, "frames") + attention_folder = os.path.join( + self.args.output_path, "attention" + ) + + os.makedirs(frames_folder, exist_ok=True) + os.makedirs(attention_folder, exist_ok=True) + + self._extract_frames_from_video( + self.args.input_path, frames_folder + ) + + self._inference( + frames_folder, + attention_folder, + ) + + self._generate_video_from_images( + attention_folder, self.args.output_path + ) + + # If input is a folder of already extracted frames + if os.path.isdir(self.args.input_path): + attention_folder = os.path.join( + self.args.output_path, "attention" + ) + + os.makedirs(attention_folder, exist_ok=True) + + self._inference(self.args.input_path, attention_folder) + + self._generate_video_from_images( + attention_folder, self.args.output_path + ) + + # If input path doesn't exists + else: + print(f"Provided input path {self.args.input_path} doesn't exists.") + sys.exit(1) + + def _extract_frames_from_video(self, inp: str, out: str): + vidcap = cv2.VideoCapture(inp) + self.args.fps = vidcap.get(cv2.CAP_PROP_FPS) + + print(f"Video: {inp} ({self.args.fps} fps)") + print(f"Extracting frames to {out}") + success, image = vidcap.read() - count += 1 + count = 0 + while success: + cv2.imwrite( + os.path.join(out, f"frame-{count:04}.jpg"), + image, + ) + success, image = vidcap.read() + count += 1 + def _generate_video_from_images(self, inp: str, out: str): + img_array = [] + attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg"))) -def generate_video_from_images(format="mp4"): - print("Generating video...") - img_array = [] - # Change format to png if needed - for filename in tqdm(sorted(glob.glob(os.path.join(args.output_dir, "attn-*.jpg")))): - with open(filename, "rb") as f: + # Get size of the first image + with open(attention_images_list[0], "rb") as f: img = Image.open(f) img = img.convert("RGB") size = (img.width, img.height) - img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) - if args.video_format == "avi": - out = cv2.VideoWriter( - "video.avi", cv2.VideoWriter_fourcc(*"XVID"), args.fps, size - ) - else: + print(f"Generating video {size} to {out}") + + for filename 
in tqdm(attention_images_list[1:]): + with open(filename, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") + img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) + out = cv2.VideoWriter( - "video.mp4", cv2.VideoWriter_fourcc(*"MP4V"), args.fps, size + os.path.join(out, "video." + self.args.video_format), + FOURCC[self.args.video_format], + self.args.fps, + size, ) - for i in range(len(img_array)): - out.write(img_array[i]) - out.release() - print("Done") + for i in range(len(img_array)): + out.write(img_array[i]) + out.release() + print("Done") + def _inference(self, inp: str, out: str): + print(f"Generating attention images to {out}") -def inference(images_folder_list: str): - for img_path in tqdm(images_folder_list): - with open(img_path, "rb") as f: - img = Image.open(f) - img = img.convert("RGB") + for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))): + with open(img_path, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") - if args.resize is not None: - transform = pth_transforms.Compose( - [ - pth_transforms.ToTensor(), - pth_transforms.Resize(args.resize), - pth_transforms.Normalize( - (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) - ), - ] - ) - else: - transform = pth_transforms.Compose( - [ - pth_transforms.ToTensor(), - pth_transforms.Normalize( - (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) - ), - ] + if self.args.resize is not None: + transform = pth_transforms.Compose( + [ + pth_transforms.ToTensor(), + pth_transforms.Resize(self.args.resize), + pth_transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), + ] + ) + else: + transform = pth_transforms.Compose( + [ + pth_transforms.ToTensor(), + pth_transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), + ] + ) + + img = transform(img) + + # make the image divisible by the patch size + w, h = ( + img.shape[1] - img.shape[1] % self.args.patch_size, + img.shape[2] - img.shape[2] % self.args.patch_size, ) + img = img[:, :w, :h].unsqueeze(0) - img = transform(img) + w_featmap = img.shape[-2] // self.args.patch_size + h_featmap = img.shape[-1] // self.args.patch_size - # make the image divisible by the patch size - w, h = ( - img.shape[1] - img.shape[1] % args.patch_size, - img.shape[2] - img.shape[2] % args.patch_size, - ) - img = img[:, :w, :h].unsqueeze(0) - - w_featmap = img.shape[-2] // args.patch_size - h_featmap = img.shape[-1] // args.patch_size - - attentions = model.forward_selfattention(img.to(device)) - - nh = attentions.shape[1] # number of head - - # we keep only the output patch attention - attentions = attentions[0, :, 0, 1:].reshape(nh, -1) - - # we keep only a certain percentage of the mass - val, idx = torch.sort(attentions) - val /= torch.sum(val, dim=1, keepdim=True) - cumval = torch.cumsum(val, dim=1) - th_attn = cumval > (1 - args.threshold) - idx2 = torch.argsort(idx) - for head in range(nh): - th_attn[head] = th_attn[head][idx2[head]] - th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() - # interpolate - th_attn = ( - nn.functional.interpolate( - th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest" - )[0] - .cpu() - .numpy() - ) + attentions = self.model.forward_selfattention(img.to(DEVICE)) - attentions = attentions.reshape(nh, w_featmap, h_featmap) - attentions = ( - nn.functional.interpolate( - attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest" - )[0] - .cpu() - .numpy() - ) + nh = attentions.shape[1] # number of head - # save attentions heatmaps - os.makedirs(args.output_dir, 
exist_ok=True) - fname = os.path.join(args.output_dir, "attn-" + os.path.basename(img_path)) - plt.imsave( - fname=fname, - arr=sum( - attentions[i] * 1 / attentions.shape[0] - for i in range(attentions.shape[0]) - ), - cmap="inferno", - format="jpg", - ) + # we keep only the output patch attention + attentions = attentions[0, :, 0, 1:].reshape(nh, -1) - generate_video_from_images(args.video_format) - - -def load_model(): - # build model - model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) - for p in model.parameters(): - p.requires_grad = False - model.eval() - model.to(device) - if os.path.isfile(args.pretrained_weights): - state_dict = torch.load(args.pretrained_weights, map_location="cpu") - if args.checkpoint_key is not None and args.checkpoint_key in state_dict: - print(f"Take key {args.checkpoint_key} in provided checkpoint dict") - state_dict = state_dict[args.checkpoint_key] - state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} - msg = model.load_state_dict(state_dict, strict=False) - print( - "Pretrained weights found at {} and loaded with msg: {}".format( - args.pretrained_weights, msg + # we keep only a certain percentage of the mass + val, idx = torch.sort(attentions) + val /= torch.sum(val, dim=1, keepdim=True) + cumval = torch.cumsum(val, dim=1) + th_attn = cumval > (1 - self.args.threshold) + idx2 = torch.argsort(idx) + for head in range(nh): + th_attn[head] = th_attn[head][idx2[head]] + th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() + # interpolate + th_attn = ( + nn.functional.interpolate( + th_attn.unsqueeze(0), + scale_factor=self.args.patch_size, + mode="nearest", + )[0] + .cpu() + .numpy() ) + + attentions = attentions.reshape(nh, w_featmap, h_featmap) + attentions = ( + nn.functional.interpolate( + attentions.unsqueeze(0), + scale_factor=self.args.patch_size, + mode="nearest", + )[0] + .cpu() + .numpy() + ) + + # save attentions heatmaps + fname = os.path.join(out, "attn-" + os.path.basename(img_path)) + plt.imsave( + fname=fname, + arr=sum( + attentions[i] * 1 / attentions.shape[0] + for i in range(attentions.shape[0]) + ), + cmap="inferno", + format="jpg", + ) + + def __load_model(self): + # build model + model = vits.__dict__[self.args.arch]( + patch_size=self.args.patch_size, num_classes=0 ) - else: - print( - "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." 
- ) - url = None - if args.arch == "deit_small" and args.patch_size == 16: - url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" - elif args.arch == "deit_small" and args.patch_size == 8: - url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper - elif args.arch == "vit_base" and args.patch_size == 16: - url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" - elif args.arch == "vit_base" and args.patch_size == 8: - url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" - if url is not None: + for p in model.parameters(): + p.requires_grad = False + model.eval() + model.to(DEVICE) + + if os.path.isfile(self.args.pretrained_weights): + state_dict = torch.load(self.args.pretrained_weights, map_location="cpu") + if ( + self.args.checkpoint_key is not None + and self.args.checkpoint_key in state_dict + ): + print( + f"Take key {self.args.checkpoint_key} in provided checkpoint dict" + ) + state_dict = state_dict[self.args.checkpoint_key] + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) print( - "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." - ) - state_dict = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/dino/" + url + "Pretrained weights found at {} and loaded with msg: {}".format( + self.args.pretrained_weights, msg + ) ) - model.load_state_dict(state_dict, strict=True) else: print( - "There is no reference weights available for this model => We use random weights." + "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." ) - return model + url = None + if self.args.arch == "deit_small" and self.args.patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + elif self.args.arch == "deit_small" and self.args.patch_size == 8: + url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper + elif self.args.arch == "vit_base" and self.args.patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" + elif self.args.arch == "vit_base" and self.args.patch_size == 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" + if url is not None: + print( + "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." + ) + state_dict = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/dino/" + url + ) + model.load_state_dict(state_dict, strict=True) + else: + print( + "There is no reference weights available for this model => We use random weights." + ) + return model def parse_args(): - parser = argparse.ArgumentParser("Visualize Self-Attention maps") + parser = argparse.ArgumentParser("Generation self-attention video") parser.add_argument( "--arch", default="deit_small", @@ -203,7 +290,7 @@ def parse_args(): help="Architecture (support only ViT atm).", ) parser.add_argument( - "--patch_size", default=8, type=int, help="Patch resolution of the model." + "--patch_size", default=8, type=int, help="Patch resolution of the self.model." 
) parser.add_argument( "--pretrained_weights", @@ -219,16 +306,18 @@ def parse_args(): ) parser.add_argument( "--input_path", - default=None, + required=True, type=str, help="""Path to a video file if you want to extract frames - or to a folder of images already extracted by yourself.""", + or to a folder of images already extracted by yourself. + or to a folder of attention images.""", ) parser.add_argument( - "--output_dir", - required=True, + "--output_path", + default="./", type=str, - help="Path where to save visualizations and / or video.", + help="""Path to store a folder of frames and / or a folder of attention images. + and / or a final video. Default to current directory.""", ) parser.add_argument( "--threshold", @@ -245,18 +334,18 @@ def parse_args(): help="""Apply a resize transformation to input image(s). Use if OOM error. Usage (single or W H): --resize 512, --resize 720 1280""", ) + parser.add_argument( + "--video_only", + action="store_true", + help="""Use this flag if you only want to generate a video and not all attention images. + If used, --input_path must be set to the folder of attention images. Ex: ./attention/""", + ) parser.add_argument( "--fps", default=30.0, type=float, help="FPS of input / output video. Automatically set if you extract frames from a video.", ) - parser.add_argument( - "--video_only", - action="store_true", - help="""Use this flag if you only want to generate a video and not all attention images. - If used, --output_dir must be set to the folder containing attention images.""", - ) parser.add_argument( "--video_format", default="mp4", @@ -270,36 +359,6 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model = load_model() - - # If you only want a video - if args.video_only: - generate_video_from_images(args.video_format) - else: - # If input path isn't set - if args.input_path is None: - print(f"Provided input path {args.input_path} is non valid.") - sys.exit(1) - else: - # If input path exists - if os.path.exists(args.input_path): - # If input is a video file - if os.path.isfile(args.input_path): - extract_frames_from_video() - imgs_list = [ - os.path.join(args.output_dir, i) - for i in sorted(os.listdir(args.output_dir)) - ] - inference(imgs_list) - # If input is an images folder - if os.path.isdir(args.input_path): - imgs_list = [ - os.path.join(args.input_path, i) - for i in sorted(os.listdir(args.input_path)) - ] - inference(imgs_list) - # If input path doesn't exists - else: - print(f"Provided video file path {args.input_path} is non valid.") - sys.exit(1) + + vg = VideoGenerator(args) + vg.run() From f76e88a6b5f3a27ef7d87a40f1eb97f0b231d057 Mon Sep 17 00:00:00 2001 From: Fabio Uechi <308613+fabito@users.noreply.github.com> Date: Wed, 5 May 2021 07:45:12 +1200 Subject: [PATCH 7/9] add new argument num_classes to eval_knn --- eval_knn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval_knn.py b/eval_knn.py index 1db6821ad..d761dffd3 100644 --- a/eval_knn.py +++ b/eval_knn.py @@ -186,6 +186,7 @@ def __getitem__(self, idx): distributed training; see https://pytorch.org/docs/stable/distributed.html""") parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) + parser.add_argument("--num_classes", default=1000, type=int, help="Num classes") args = parser.parse_args() 
utils.init_distributed_mode(args) @@ -212,6 +213,6 @@ def __getitem__(self, idx): print("Features are ready!\nStart the k-NN classification.") for k in args.nb_knn: top1, top5 = knn_classifier(train_features, train_labels, - test_features, test_labels, k, args.temperature) + test_features, test_labels, k, args.temperature, args.num_classes) print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") dist.barrier() From cb812ca98abead9f52bcfecfe4eed3583f5f6fa9 Mon Sep 17 00:00:00 2001 From: Syed Adeel Date: Tue, 4 May 2021 20:12:49 +0000 Subject: [PATCH 8/9] Add num_labels to eval_linear, change max_accuracy to best_acc --- eval_linear.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/eval_linear.py b/eval_linear.py index b17cf7309..649b95788 100644 --- a/eval_linear.py +++ b/eval_linear.py @@ -60,7 +60,7 @@ def eval_linear(args): # load weights to evaluate utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) - linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))) + linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels) linear_classifier = linear_classifier.cuda() linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) @@ -112,7 +112,7 @@ def eval_linear(args): } torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar")) print("Training of the supervised linear classifier on frozen features completed.\n" - "Top-1 test accuracy: {acc:.1f}".format(acc=max_accuracy)) + "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool): @@ -165,14 +165,22 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool): output = linear_classifier(output) loss = nn.CrossEntropyLoss()(output, target) - acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + if linear_classifier.module.num_labels >= 5: + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + else: + acc1, = utils.accuracy(output, target, topk=(1,)) batch_size = inp.shape[0] metric_logger.update(loss=loss.item()) metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + if linear_classifier.module.num_labels >= 5: + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + if linear_classifier.module.num_labels >= 5: + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + else: + print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, losses=metric_logger.loss)) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} @@ -180,6 +188,7 @@ class LinearClassifier(nn.Module): """Linear layer to train on top of frozen features""" def __init__(self, dim, num_labels=1000): super(LinearClassifier, self).__init__() + self.num_labels = num_labels self.linear = nn.Linear(dim, num_labels) self.linear.weight.data.normal_(mean=0.0, std=0.01) self.linear.bias.data.zero_() @@ -217,5 +226,6 @@ def forward(self, x): parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 
parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.") parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') + parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') args = parser.parse_args() eval_linear(args) From 052a443ac37c92cf13ed4a2dfbfe436e97db8280 Mon Sep 17 00:00:00 2001 From: Fabio Uechi <308613+fabito@users.noreply.github.com> Date: Wed, 5 May 2021 08:31:21 +1200 Subject: [PATCH 9/9] add num_classes argument to eval_linear --- eval_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval_linear.py b/eval_linear.py index b17cf7309..5658d1a91 100644 --- a/eval_linear.py +++ b/eval_linear.py @@ -60,7 +60,7 @@ def eval_linear(args): # load weights to evaluate utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) - linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))) + linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), args.num_classes) linear_classifier = linear_classifier.cuda() linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) @@ -217,5 +217,6 @@ def forward(self, x): parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.") parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') + parser.add_argument("--num_classes", default=1000, type=int, help="Num classes") args = parser.parse_args() eval_linear(args)
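The last three patches make the number of classes configurable (`--num_classes` for `eval_knn.py`, and `--num_labels` / `--num_classes` in the two `eval_linear.py` variants), so the evaluation scripts can be used on datasets other than ImageNet-1k. As a quick illustration of what the patched classifier head amounts to, here is a small sketch; it mirrors `eval_linear.py`'s `LinearClassifier` but is not the repository code, and the `embed_dim=384` (deit_small) and `n_last_blocks=4` values are assumed defaults that do not appear in these diffs.

```python
# Sketch of the linear head these patches parameterize; embed_dim=384 and
# n_last_blocks=4 are assumptions, not values shown in the diffs above.
import torch
import torch.nn as nn


class LinearClassifier(nn.Module):
    """Linear layer trained on top of frozen features."""

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        return self.linear(x.flatten(start_dim=1))


embed_dim, n_last_blocks, avgpool_patchtokens = 384, 4, False
head = LinearClassifier(embed_dim * (n_last_blocks + int(avgpool_patchtokens)), num_labels=100)
features = torch.randn(8, embed_dim * n_last_blocks)  # frozen backbone features, batch of 8
print(head(features).shape)  # torch.Size([8, 100])
```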