-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualize_attention.py
59 lines (47 loc) · 1.31 KB
/
visualize_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from PIL import Image
import requests
import numpy as np
from io import BytesIO
import torch
import matplotlib.pyplot as plt
from torch import nn
from torchvision.models import resnet34, resnet50
from torchvision.models.resnet import ResNet, BasicBlock
import torchvision.transforms as T
import torch.nn.functional as F
# ImageNet-pretrained ResNet-34; its weights are copied into ResNet34AT below.
try:
    # torchvision >= 0.13: the `pretrained` flag is deprecated in favour of
    # `weights`. IMAGENET1K_V1 is the checkpoint `pretrained=True` maps to.
    from torchvision.models import ResNet34_Weights
    base_resnet34 = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
except ImportError:
    # Older torchvision only understands the legacy boolean flag.
    base_resnet34 = resnet34(pretrained=True)
class ResNet34AT(ResNet):
    """ResNet whose forward pass returns per-stage attention maps.

    The constructor is inherited unchanged from ``ResNet``. ``forward`` is
    overridden so that, instead of classification logits, it returns a list
    of four tensors, one for the output of each residual stage
    (``layer1`` .. ``layer4``). Each tensor is the squared activation
    averaged over dimension 1. Assuming the usual (N, C, H, W) image input,
    that is the channel axis, giving maps of shape (N, H_i, W_i).

    ``avgpool`` and ``fc`` are not used by this forward pass.
    """

    def forward(self, x):
        # Initial conv -> batch norm -> ReLU -> max-pool, as in ResNet.
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        attention = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
            # Square element-wise, then average over dim 1.
            attention.append(feats.pow(2).mean(1))
        return attention
# Build the attention-map model with ResNet-34's layout (BasicBlock, [3, 4, 6, 3])
# and copy in the pretrained weights. The state-dict keys line up because
# ResNet34AT only overrides forward(). (The former debug `print(model);
# exit(0)` that made the rest of this script unreachable has been removed.)
model = ResNet34AT(BasicBlock, [3, 4, 6, 3])
model.load_state_dict(base_resnet34.state_dict())

import cv2

# cv2.imread signals failure by returning None instead of raising.
im = cv2.imread('img.jpg')
if im is None:
    raise FileNotFoundError("could not read image 'img.jpg'")
# OpenCV loads channels in BGR order. Both matplotlib and the ImageNet
# normalization below expect RGB.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Show the input in its own figure, so the first attention map is not
# drawn on top of it in the same axes.
plt.imshow(im)
plt.show()

# Resize (int size -> shorter side becomes 256 px), convert to a float
# tensor, and normalize with the ImageNet channel means/stds.
# NOTE(review): despite the name, no crop is applied; the whole image is used.
tr_center_crop = T.Compose([
    T.ToPILImage(),
    T.Resize(256),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

model.eval()
with torch.no_grad():
    x = tr_center_crop(im).unsqueeze(0)  # add a batch dimension of size 1
    gs = model(x)

# One figure per residual stage; g[0] selects the single image in the batch.
for i, g in enumerate(gs):
    plt.imshow(g[0].cpu().numpy(), interpolation='bicubic')
    plt.title(f'g{i}')
    plt.show()