From e671eaa5d62910c129690e0463afbe4586a4da72 Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 19 Sep 2023 21:56:01 -0400 Subject: [PATCH] added HRNet+OCR, minor changes --- config/model/gdl_hrnet.yaml | 4 + models/hrnet/__init__.py | 0 models/hrnet/backbone.py | 455 ++++++++++++++++++++++++++++++++++++ models/hrnet/hrnet_ocr.py | 53 +++++ models/hrnet/ocr.py | 46 ++++ models/hrnet/ocr_modules.py | 138 +++++++++++ models/hrnet/utils.py | 41 ++++ tests/model/test_models.py | 5 +- train_segmentation.py | 22 +- 9 files changed, 758 insertions(+), 6 deletions(-) create mode 100644 config/model/gdl_hrnet.yaml create mode 100644 models/hrnet/__init__.py create mode 100644 models/hrnet/backbone.py create mode 100644 models/hrnet/hrnet_ocr.py create mode 100644 models/hrnet/ocr.py create mode 100644 models/hrnet/ocr_modules.py create mode 100644 models/hrnet/utils.py diff --git a/config/model/gdl_hrnet.yaml b/config/model/gdl_hrnet.yaml new file mode 100644 index 00000000..7df76c52 --- /dev/null +++ b/config/model/gdl_hrnet.yaml @@ -0,0 +1,4 @@ +# @package _global_ +model: + _target_: models.hrnet.hrnet_ocr.HRNet + pretrained: True \ No newline at end of file diff --git a/models/hrnet/__init__.py b/models/hrnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/hrnet/backbone.py b/models/hrnet/backbone.py new file mode 100644 index 00000000..280f91ca --- /dev/null +++ b/models/hrnet/backbone.py @@ -0,0 +1,455 @@ +""" +This HRNet implementation is modified from the following repository: +https://github.com/HRNet/HRNet-Semantic-Segmentation +""" + +import logging +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from models.hrnet.utils import ModelHelpers +from pytorch_lightning.utilities import rank_zero_only + +BatchNorm2d = ModelHelpers.batchnorm2d() +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + +__all__ = ['hrnetv2'] + + +model_urls = { + 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode='bilinear', + align_corners=False) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HRNetV2(nn.Module): + def __init__(self, n_class, **kwargs): + super(HRNetV2, self).__init__() + extra = { + 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'}, + 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'}, + 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'}, + 'FINAL_CONV_KERNEL': 1 + } + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + self.high_level_ch = np.int_(np.sum(pre_stage_channels)) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + + x = torch.cat([x[0], x1, x2, x3], 1) + + # x = self.last_layer(x) + return x + + +def hrnetv2(num_of_classes, pretrained=False, **kwargs): + model = HRNetV2(n_class=num_of_classes, **kwargs) + if pretrained: + weights_file = ModelHelpers.load_url(model_urls['hrnetv2'], download=True) + model.load_state_dict(torch.load(weights_file, map_location=None), strict=False) + return model + +if __name__ == "__main__": + from torchinfo import summary + + model = hrnetv2(num_of_classes=4, pretrained=True) + batch_size = 8 + summary(model, input_size=(batch_size, 3, 512, 512)) + + diff --git a/models/hrnet/hrnet_ocr.py b/models/hrnet/hrnet_ocr.py new file mode 100644 index 00000000..bcb807d3 --- /dev/null +++ b/models/hrnet/hrnet_ocr.py @@ -0,0 +1,53 @@ +import logging +import torch.nn.functional as F + +from torch import nn +from models.hrnet.ocr import OCR +from models.hrnet.backbone import hrnetv2 + + + +class HRNet(nn.Module): + """ + High Resolution Network (hrnet_w48_v2) with Object Contextual Representation module + + """ + def __init__(self, pretrained, in_channels, classes) -> None: + super(HRNet, self).__init__() + if in_channels != 3: + logging.critical(F"HRNet model expects three channels input") + self.encoder = hrnetv2(num_of_classes=classes, pretrained=pretrained) + high_level_ch = self.encoder.high_level_ch + self.decoder = OCR(num_classes=classes, high_level_ch=high_level_ch) + + def forward(self, input): + high_level_features = self.encoder(input) + cls_out, aux_out, _ = self.decoder(high_level_features) + + input_size = input.shape[2:] + aux_out = F.interpolate(aux_out, size=input_size, mode='bilinear', align_corners=False) + cls_out = F.interpolate(cls_out, size=input_size, mode='bilinear', align_corners=False) + if self.training: + return cls_out, aux_out + else: + return cls_out + +if __name__ == "__main__": + import torch + from torchinfo import summary + + model = HRNet(pretrained=True, in_channels=3, classes=4) + model.to("cuda") + batch_size = 4 + + mask_tensor = torch.randn([batch_size, 3, 512, 512]).cuda() + + output, output_aux = model(mask_tensor) + for name, para in model.named_parameters(): + print("-"*20) + print(f"name: {name}") + print(f"requires_grad: {para.requires_grad}") + # print(output.shape) + # print(output_aux.shape) + # summary(model, input_size=(batch_size, 3, 512, 512)) + \ No newline at end of file diff --git a/models/hrnet/ocr.py b/models/hrnet/ocr.py new file mode 100644 index 00000000..c9f33647 --- /dev/null +++ b/models/hrnet/ocr.py @@ -0,0 +1,46 @@ + +from torch import nn +from models.hrnet.utils import ModelHelpers +from models.hrnet.ocr_modules import SpatialGather_Module, SpatialOCR_Module + + +BNReLU = ModelHelpers.BNReLU + +class OCR(nn.Module): + + def __init__(self, num_classes, high_level_ch) -> None: + super(OCR, self).__init__() + + ocr_mid_channels = 512 + ocr_key_channels = 256 + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(high_level_ch, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + BNReLU(ocr_mid_channels),) + self.ocr_gather_head = SpatialGather_Module(num_classes) + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + ) + + self.cls_head = nn.Conv2d(ocr_mid_channels, num_classes, + kernel_size=1, stride=1, padding=0,bias=True) + + self.aux_head = nn.Sequential(nn.Conv2d(high_level_ch, high_level_ch, + kernel_size=1, stride=1, padding=0), + BNReLU(high_level_ch), + nn.Conv2d(high_level_ch, num_classes, + kernel_size=1, stride=1, padding=0, bias=True)) + + def forward(self, high_level_features): + feats = self.conv3x3_ocr(high_level_features) + aux_out = self.aux_head(high_level_features) + context = self.ocr_gather_head(feats, aux_out) + ocr_feats = self.ocr_distri_head(feats, context) + cls_out = self.cls_head(ocr_feats) + return cls_out, aux_out, ocr_feats + + \ No newline at end of file diff --git a/models/hrnet/ocr_modules.py b/models/hrnet/ocr_modules.py new file mode 100644 index 00000000..975e46db --- /dev/null +++ b/models/hrnet/ocr_modules.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.hrnet.utils import ModelHelpers + +BNReLU = ModelHelpers.BNReLU + +# BatchNorm2d = ModelHelpers.batchnorm2d(bn_type="torch_bn") +# def BNReLU(ch): +# return nn.Sequential(BatchNorm2d(ch), nn.ReLU()) + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + + Output: + The correlation of every class map with every feature map + shape = [n, num_feats, num_classes, 1] + + + """ + def __init__(self, scale=1): + super(SpatialGather_Module, self).__init__() + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, = probs.size(0), probs.size(1) + + # each class image now a vector + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature + maps (save memory cost) + Return: + N X C X H X W + ''' + def __init__(self, in_channels, key_channels, scale=1): + super(ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.key_channels), + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + BNReLU(self.in_channels), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True) + + return context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation + for each pixel. + """ + + def __init__(self, in_channels, key_channels, out_channels, scale=1, dropout=0.1): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock(in_channels, + key_channels, + scale) + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(sum([in_channels,in_channels]), out_channels, + kernel_size=1, padding=0, bias=False), + BNReLU(out_channels), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + return output + \ No newline at end of file diff --git a/models/hrnet/utils.py b/models/hrnet/utils.py new file mode 100644 index 00000000..7fe00208 --- /dev/null +++ b/models/hrnet/utils.py @@ -0,0 +1,41 @@ +import sys +import os +import logging +import torch.nn as nn +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve +import torch +from typing import Union, Optional +from pathlib import Path +from pytorch_lightning.utilities import rank_zero_only + +class ModelHelpers: + + @staticmethod + def batchnorm2d(bn_type: Union[str, "torch_sync_bn", "torch_bn"] = "torch_bn"): + if bn_type == "torch_bn": + return nn.BatchNorm2d + if bn_type == "torch_sync_bn": + return nn.SyncBatchNorm + + @staticmethod + def BNReLU(ch: torch.Tensor): + batchnorm = ModelHelpers.batchnorm2d() + return nn.Sequential( + batchnorm(ch), + nn.ReLU()) + + @rank_zero_only + @staticmethod + def load_url(url: str, download: bool): + model_dir = Path.home() / ".cache" / "torch" / "checkpoints" + if not model_dir.is_dir(): + Path.mkdir(model_dir, parents=True) + filename = url.split('/')[-1] + cached_file = model_dir.joinpath(filename) + if not cached_file.is_file() and download: + logging.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, str(cached_file)) + return cached_file \ No newline at end of file diff --git a/tests/model/test_models.py b/tests/model/test_models.py index 05db8129..5b4d5e28 100644 --- a/tests/model/test_models.py +++ b/tests/model/test_models.py @@ -33,7 +33,10 @@ def test_net(self) -> None: in_channels=3, out_classes=4, ) - output = model(rand_img) + if cfg.model._target_ == "models.hrnet.hrnet_ocr.HRNet": + output, output_aux = model(rand_img) + else: + output = model(rand_img) print(output.shape) diff --git a/train_segmentation.py b/train_segmentation.py index b244c929..335bbc09 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -297,6 +297,7 @@ def training(train_loader, device, scale, vis_params, + aux_output: bool = False, debug=False): """ Train the model and return the metrics of the training epoch @@ -327,7 +328,10 @@ def training(train_loader, # forward optimizer.zero_grad() - outputs = model(inputs) + if aux_output: + outputs, outputs_aux = model(inputs) + else: + outputs = model(inputs) # added for torchvision models that output an OrderedDict with outputs in 'out' key. # More info: https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ if isinstance(outputs, OrderedDict): @@ -349,9 +353,13 @@ def training(train_loader, dataset='trn', ep_num=ep_idx + 1, scale=scale) - - loss = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) - + if aux_output: + loss_main = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) + loss_aux = criterion(outputs_aux, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) + loss = 0.4 * loss_aux + loss_main + else: + loss = criterion(outputs, labels) if num_classes > 1 else criterion(outputs, labels.unsqueeze(1).float()) + train_metrics['loss'].update(loss.item(), batch_size) if device.type == 'cuda' and debug: @@ -628,6 +636,7 @@ def train(cfg: DictConfig) -> None: # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH checkpoint = read_checkpoint(train_state_dict_path) + aux_output = False model = define_model( net_params=cfg.model, in_channels=num_bands, @@ -637,7 +646,9 @@ def train(cfg: DictConfig) -> None: checkpoint_dict=checkpoint, checkpoint_dict_strict_load=state_dict_strict ) - + + if cfg.model._target_ == "models.hrnet.hrnet_ocr.HRNet": + aux_output = True criterion = define_loss(loss_params=cfg.loss, class_weights=class_weights) criterion = criterion.to(device) optimizer = instantiate(cfg.optimizer, params=model.parameters()) @@ -717,6 +728,7 @@ def train(cfg: DictConfig) -> None: device=device, scale=scale, vis_params=vis_params, + aux_output=aux_output, debug=debug) if 'trn_log' in locals(): # only save the value if a tracker is setup trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])