-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
30 lines (22 loc) · 898 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
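"""
Build a DeepLabV3 ResNet101 semantic-segmentation model from torchvision,
optionally starting from torchvision's pretrained segmentation weights, and
replace the final layers of its main and auxiliary heads so they predict
this project's classes.
"""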
import torchvision.models as models
import torch.nn as nn
def model(pretrained, requires_grad, num_classes=15):
    """Build a torchvision DeepLabV3-ResNet101 with new output layers.

    Args:
        pretrained: If truthy, load torchvision's pretrained segmentation
            weights; otherwise those weights are not loaded. (torchvision
            may still initialize the backbone from ImageNet weights by
            default - TODO confirm for the installed version.)
        requires_grad: If truthy, every parameter of the loaded network is
            trainable; if falsy, all of them are frozen. The replaced output
            layers are created afterwards and are always trainable.
        num_classes: Number of output channels of the new main and
            auxiliary classifier layers. Defaults to 15.

    Returns:
        The torchvision DeepLabV3 model with both final 1x1 convolutions
        replaced.
    """
    # aux_loss=True forces torchvision to build the auxiliary head. By
    # default it is only built when pretrained weights are requested, so
    # without this flag ``net.aux_classifier`` would be None for
    # pretrained=False and the assignment further down would raise.
    net = models.segmentation.deeplabv3_resnet101(
        pretrained=pretrained, progress=True, aux_loss=True)
    # bool() so any truthy/falsy flag takes effect, not only exact True/False.
    trainable = bool(requires_grad)
    for param in net.parameters():
        param.requires_grad = trainable
    # Replace the final 1x1 conv (index 4) of the main segmentation head.
    # Freshly created layers default to requires_grad=True, so they remain
    # learnable even when the rest of the network was frozen above.
    net.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1))
    # Likewise replace the final 1x1 conv of the auxiliary FCNHead.
    net.aux_classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1))
    return net
# Script-style smoke check: build the model with pretrained weights and all
# parameters trainable, then print its architecture.
# NOTE(review): these lines run at import time (possibly downloading weights)
# and rebind the module-level name ``model`` from the factory function to the
# built network, so after this point ``model`` is no longer callable.
# Consider an ``if __name__ == "__main__":`` guard and a distinct variable
# name, once it is confirmed no caller imports the module-level instance.
model = model(pretrained=True, requires_grad=True)
print(model)