You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
104 lines
4.3 KiB
104 lines
4.3 KiB
from typing import Dict |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from detectron2.utils.registry import Registry |
|
from detectron2.layers import ShapeSpec |
|
|
|
from adet.layers import conv_with_kaiming_uniform |
|
|
|
|
|
# Registry of basis-module implementations, keyed by name; concrete modules
# (e.g. ProtoNet below) register themselves via the decorator and are looked
# up by build_basis_module().
BASIS_MODULE_REGISTRY = Registry("BASIS_MODULE")
BASIS_MODULE_REGISTRY.__doc__ = """
Registry for basis module, which produces global bases from feature maps.

The registered object will be called with `obj(cfg, input_shape)`.
The call should return a `nn.Module` object.
"""
|
def build_basis_module(cfg, input_shape):
    """Instantiate the basis module named by ``cfg.MODEL.BASIS_MODULE.NAME``.

    Args:
        cfg: detectron2-style config node.
        input_shape: mapping from feature name to ``ShapeSpec``; forwarded
            unchanged to the registered module's constructor.

    Returns:
        nn.Module: the constructed basis module.
    """
    name = cfg.MODEL.BASIS_MODULE.NAME
    return BASIS_MODULE_REGISTRY.get(name)(cfg, input_shape)
|
@BASIS_MODULE_REGISTRY.register()
class ProtoNet(nn.Module):
    """Produce global mask bases from multi-level FPN features.

    Refines each selected feature level with a 3x3 conv, sums the levels at
    the resolution of the first (finest) level, then runs a conv "tower" that
    upsamples 2x and projects to ``NUM_BASES`` basis maps. Optionally trains
    an auxiliary semantic-segmentation head on the finest level.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        TODO: support deconv and variable channel width
        """
        # official protonet has a relu after each conv
        super().__init__()
        # fmt: off
        mask_dim         = cfg.MODEL.BASIS_MODULE.NUM_BASES
        planes           = cfg.MODEL.BASIS_MODULE.CONVS_DIM
        self.in_features = cfg.MODEL.BASIS_MODULE.IN_FEATURES
        self.loss_on     = cfg.MODEL.BASIS_MODULE.LOSS_ON
        norm             = cfg.MODEL.BASIS_MODULE.NORM
        num_convs        = cfg.MODEL.BASIS_MODULE.NUM_CONVS
        self.visualize   = cfg.MODEL.BLENDMASK.VISUALIZE
        # fmt: on

        feature_channels = {k: v.channels for k, v in input_shape.items()}

        conv_block = conv_with_kaiming_uniform(norm, True)  # conv relu bn
        # One 3x3 refine conv per input feature level, mapping each level's
        # channel count to the common `planes` width.
        self.refine = nn.ModuleList()
        for in_feature in self.in_features:
            self.refine.append(conv_block(
                feature_channels[in_feature], planes, 3, 1))
        # Tower: num_convs 3x3 convs -> 2x bilinear upsample -> 3x3 conv
        # -> 1x1 conv projecting to mask_dim basis maps.
        tower = []
        for i in range(num_convs):
            tower.append(
                conv_block(planes, planes, 3, 1))
        tower.append(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
        tower.append(
            conv_block(planes, planes, 3, 1))
        tower.append(
            nn.Conv2d(planes, mask_dim, 1))
        self.add_module('tower', nn.Sequential(*tower))

        if self.loss_on:
            # fmt: off
            self.common_stride   = cfg.MODEL.BASIS_MODULE.COMMON_STRIDE
            num_classes          = cfg.MODEL.BASIS_MODULE.NUM_CLASSES + 1
            self.sem_loss_weight = cfg.MODEL.BASIS_MODULE.LOSS_WEIGHT
            # fmt: on

            # Auxiliary thing-semantic head on the finest input feature;
            # +1 class above accounts for background.
            inplanes = feature_channels[self.in_features[0]]
            self.seg_head = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=3,
                                                    stride=1, padding=1, bias=False),
                                          nn.BatchNorm2d(planes),
                                          nn.ReLU(),
                                          nn.Conv2d(planes, planes, kernel_size=3,
                                                    stride=1, padding=1, bias=False),
                                          nn.BatchNorm2d(planes),
                                          nn.ReLU(),
                                          nn.Conv2d(planes, num_classes, kernel_size=1,
                                                    stride=1))

    def forward(self, features, targets=None):
        """Compute basis maps (and, in training, the auxiliary semantic loss).

        Args:
            features: dict mapping feature names to (N, C, H, W) tensors;
                must contain every name in ``self.in_features``.
            targets: per-pixel semantic labels, assumed (N, H, W) integer
                tensor — required only when training with ``loss_on``.

        Returns:
            (outputs, losses): ``outputs["bases"]`` is a one-element list with
            the (N, mask_dim, h, w) basis tensor; ``losses`` holds
            ``loss_basis_sem`` during training when ``loss_on`` is set.
        """
        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.refine[i](features[f])
            else:
                # Upsample coarser levels to the finest level's size and sum.
                x_p = self.refine[i](features[f])
                x_p = F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False)
                # x_p = aligned_bilinear(x_p, x.size(3) // x_p.size(3))
                x = x + x_p
        outputs = {"bases": [self.tower(x)]}
        losses = {}
        # auxiliary thing semantic loss
        if self.training and self.loss_on:
            sem_out = self.seg_head(features[self.in_features[0]])
            # resize target to reduce memory
            gt_sem = targets.unsqueeze(1).float()
            gt_sem = F.interpolate(
                gt_sem, scale_factor=1 / self.common_stride)
            # squeeze(1) (not squeeze()) so a batch of size 1 keeps its batch
            # dim: bare squeeze() would yield a 2-D target that cross_entropy
            # rejects against the 4-D sem_out.
            seg_loss = F.cross_entropy(
                sem_out, gt_sem.squeeze(1).long())
            losses['loss_basis_sem'] = seg_loss * self.sem_loss_weight
        elif self.visualize and hasattr(self, "seg_head"):
            outputs["seg_thing_out"] = self.seg_head(features[self.in_features[0]])
        return outputs, losses
|
|
|