In the chapter Class_activation_maps, the functions were originally:
im2fmap = nn.Sequential(*(list(model.model[:5].children()) + list(model.model[5][:2].children())))
def im2gradCAM(x):
    model.eval()
    logits = model(x)
    heatmaps = []
    activations = im2fmap(x)
    print(activations.shape)
    pred = logits.max(-1)[-1]
    # get the model's prediction
    model.zero_grad()
    # compute gradients with respect to model's most confident logit
    logits[0,pred].backward(retain_graph=True)
    # get the gradients at the required featuremap location
    # and take the avg gradient for every featuremap
    pooled_grads = model.model[-7][1].weight.grad.data.mean((0,2,3))
    # multiply each activation map with corresponding gradient average
    for i in range(activations.shape[1]):
        activations[:,i,:,:] *= pooled_grads[i]
    # take the mean of all weighted activation maps
    # (that has been weighted by avg. grad at each fmap)
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()
    return heatmap, 'Uninfected' if pred.item() else 'Parasitized'
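The Grad-CAM paper assumes that the activations and pooled_grads come from the same convolution layer, but im2fmap and pooled_grads were pointing to different layers in the ResNet.
The correction is simple: point pooled_grads to the layer corresponding to im2fmap, i.e., model.model[-6][1].weight.grad.data.mean((1,2,3)).
Note that the averaging axes also change, from (0,2,3) to (1,2,3). For a convolution weight of shape (out_channels, in_channels, kH, kW), averaging over (1,2,3) gives one value per output channel, which is what the loop over activations.shape[1] expects.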
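The effect of the averaging axes can be checked on a toy layer. The sketch below is not the book's model: the Conv2d sizes and the input shape are made up for illustration, and the `.sum().backward()` call is only there to populate `weight.grad`. It shows that averaging the weight gradient over (1,2,3) gives a vector whose length matches the channel dimension of that layer's output, while (0,2,3) matches the layer's input channels instead.

```python
import torch
import torch.nn as nn

# Toy layer with hypothetical sizes: 3 input channels, 8 output channels
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)

activations = conv(x)           # shape: (1, 8, 16, 16)
activations.sum().backward()    # populates conv.weight.grad

# conv.weight.grad has shape (out_channels, in_channels, kH, kW) = (8, 3, 3, 3)
per_out = conv.weight.grad.mean((1, 2, 3))  # one value per output channel
per_in = conv.weight.grad.mean((0, 2, 3))   # one value per input channel

print(activations.shape[1], per_out.shape[0], per_in.shape[0])
```

Here `per_out` has 8 entries, the same as `activations.shape[1]`, whereas `per_in` has only 3.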
Here is the corrected version:
def im2gradCAM(x):
    model.eval()
    logits = model(x)
    heatmaps = []
    activations = im2fmap(x)
    print(activations.shape)
    pred = logits.max(-1)[-1]
    # get the model's prediction
    model.zero_grad()
    # compute gradients with respect to model's most confident logit
    logits[0,pred].backward(retain_graph=True)
    # get the gradients at the required featuremap location
    # and take the avg gradient for every featuremap
    pooled_grads = model.model[-6][1].weight.grad.data.mean((1,2,3))
    # multiply each activation map with corresponding gradient average
    for i in range(activations.shape[1]):
        activations[:,i,:,:] *= pooled_grads[i]
    # take the mean of all weighted activation maps
    # (that has been weighted by avg. grad at each fmap)
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()
    return heatmap, 'Uninfected' if pred.item() else 'Parasitized'
This is now incorporated into the notebook.
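As a side note, the per-channel loop in im2gradCAM can also be written as a single broadcast multiply. The following is a minimal sketch on random tensors, not a change to the notebook; the tensor sizes are hypothetical and chosen only to check that the two forms agree.

```python
import torch

torch.manual_seed(0)
activations = torch.randn(1, 8, 4, 4)  # (batch, channels, H, W), made-up sizes
pooled_grads = torch.randn(8)          # one weight per channel

# Loop form, as in im2gradCAM
looped = activations.clone()
for i in range(looped.shape[1]):
    looped[:, i, :, :] *= pooled_grads[i]

# Broadcast form: reshape the weights to (1, C, 1, 1)
broadcast = activations * pooled_grads.view(1, -1, 1, 1)

# Mean over the channel dimension, then take the first image in the batch
heatmap = broadcast.mean(dim=1)[0]     # shape: (H, W)
print(torch.allclose(looped, broadcast), tuple(heatmap.shape))
```

Both forms scale each channel by its own weight; the broadcast form simply avoids the Python loop and the in-place update.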