# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" |
|
This code is based on https://github.com/PeizeSun/SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py |
|
Ths copyright of PeizeSun/SparseR-CNN is as follows: |
|
MIT License [see LICENSE for details] |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import math |
|
import copy |
|
import paddle |
|
import paddle.nn as nn |
|
|
|
from paddlers.models.ppdet.core.workspace import register |
|
from paddlers.models.ppdet.modeling.heads.roi_extractor import RoIAlign |
|
from paddlers.models.ppdet.modeling.bbox_utils import delta2bbox |
|
from .. import initializer as init |
|
|
|
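
# Clamp for predicted log-space box scale deltas (log dw, log dh); kept with
# the same default value as the reference SparseR-CNN head.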
_DEFAULT_SCALE_CLAMP = math.log(100000. / 16)


class DynamicConv(nn.Layer):
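    """Dynamic instance interaction module from SparseR-CNN.

    Each proposal feature is projected by ``dynamic_layer`` into two sets of
    per-proposal parameters, which are applied to that proposal's pooled RoI
    features as two consecutive dynamic 1x1 "convolutions" (batched matmuls),
    each followed by LayerNorm and ReLU, before a final linear output layer.
    """
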
    def __init__(
            self,
            head_hidden_dim,
            head_dim_dynamic,
            head_num_dynamic, ):
        super().__init__()

        self.hidden_dim = head_hidden_dim
        self.dim_dynamic = head_dim_dynamic
        self.num_dynamic = head_num_dynamic
        self.num_params = self.hidden_dim * self.dim_dynamic
        self.dynamic_layer = nn.Linear(self.hidden_dim,
                                       self.num_dynamic * self.num_params)

        self.norm1 = nn.LayerNorm(self.dim_dynamic)
        self.norm2 = nn.LayerNorm(self.hidden_dim)

        self.activation = nn.ReLU()

        pooler_resolution = 7
        num_output = self.hidden_dim * pooler_resolution**2
        self.out_layer = nn.Linear(num_output, self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)

    def forward(self, pro_features, roi_features):
        '''
        pro_features: (1, N * nr_boxes, self.d_model)
        roi_features: (49, N * nr_boxes, self.d_model)
        '''
        features = roi_features.transpose(perm=[1, 0, 2])
        parameters = self.dynamic_layer(pro_features).transpose(perm=[1, 0, 2])

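        # Two sets of per-proposal dynamic weights generated from the proposal
        # feature: param1 has shape (N * nr_boxes, hidden_dim, dim_dynamic),
        # param2 has shape (N * nr_boxes, dim_dynamic, hidden_dim).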
        param1 = parameters[:, :, :self.num_params].reshape(
            [-1, self.hidden_dim, self.dim_dynamic])
        param2 = parameters[:, :, self.num_params:].reshape(
            [-1, self.dim_dynamic, self.hidden_dim])

        features = paddle.bmm(features, param1)
        features = self.norm1(features)
        features = self.activation(features)

        features = paddle.bmm(features, param2)
        features = self.norm2(features)
        features = self.activation(features)

        features = features.flatten(1)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)

        return features


class RCNNHead(nn.Layer):
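    """A single refinement stage of the SparseR-CNN head.

    The proposal features first interact with each other through multi-head
    self-attention, then with their pooled RoI features through DynamicConv,
    and finally pass through a feed-forward network; separate stacks of
    Linear-LayerNorm-ReLU layers produce the classification logits and the
    box deltas decoded by ``delta2bbox``.
    """
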
    def __init__(
            self,
            d_model,
            num_classes,
            dim_feedforward,
            nhead,
            dropout,
            head_cls,
            head_reg,
            head_dim_dynamic,
            head_num_dynamic,
            scale_clamp: float=_DEFAULT_SCALE_CLAMP,
            bbox_weights=(2.0, 2.0, 1.0, 1.0), ):
        super().__init__()

        self.d_model = d_model

        # dynamic.
        self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(d_model, head_dim_dynamic,
                                         head_num_dynamic)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # cls.
        num_cls = head_cls
        cls_module = list()
        for _ in range(num_cls):
            cls_module.append(nn.Linear(d_model, d_model, bias_attr=False))
            cls_module.append(nn.LayerNorm(d_model))
            cls_module.append(nn.ReLU())
        self.cls_module = nn.LayerList(cls_module)

        # reg.
        num_reg = head_reg
        reg_module = list()
        for _ in range(num_reg):
            reg_module.append(nn.Linear(d_model, d_model, bias_attr=False))
            reg_module.append(nn.LayerNorm(d_model))
            reg_module.append(nn.ReLU())
        self.reg_module = nn.LayerList(reg_module)

        # pred.
        self.class_logits = nn.Linear(d_model, num_classes)
        self.bboxes_delta = nn.Linear(d_model, 4)
        self.scale_clamp = scale_clamp
        self.bbox_weights = bbox_weights

    def forward(self, features, bboxes, pro_features, pooler):
        """
        :param bboxes: (N, nr_boxes, 4)
        :param pro_features: (N, nr_boxes, d_model)
        """

        N, nr_boxes = bboxes.shape[:2]

        proposal_boxes = list()
        for b in range(N):
            proposal_boxes.append(bboxes[b])
        roi_num = paddle.full([N], nr_boxes).astype("int32")

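        # Pool each proposal's box from the FPN features (7 x 7 x d_model per
        # box), then flatten the spatial grid into 49 tokens of size d_model:
        # (49, N * nr_boxes, d_model), the layout expected by DynamicConv.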
        roi_features = pooler(features, proposal_boxes, roi_num)
        roi_features = roi_features.reshape(
            [N * nr_boxes, self.d_model, -1]).transpose(perm=[2, 0, 1])

        # self_att.
        pro_features = pro_features.reshape([N, nr_boxes, self.d_model])
        pro_features2 = self.self_attn(
            pro_features, pro_features, value=pro_features)
        pro_features = pro_features.transpose(perm=[1, 0, 2]) + self.dropout1(
            pro_features2.transpose(perm=[1, 0, 2]))
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.reshape(
            [nr_boxes, N, self.d_model]).transpose(perm=[1, 0, 2]).reshape(
                [1, N * nr_boxes, self.d_model])
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(
            self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(perm=[1, 0, 2]).reshape(
            [N * nr_boxes, -1])
        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = delta2bbox(bboxes_deltas,
                                 bboxes.reshape([-1, 4]), self.bbox_weights)

        return class_logits.reshape([N, nr_boxes, -1]), pred_bboxes.reshape(
            [N, nr_boxes, -1]), obj_features


@register
class SparseRCNNHead(nn.Layer):
    '''
    SparseRCNNHead
    Args:
        roi_input_shape (list[ShapeSpec]): The output shape of FPN,
        num_classes (int): Number of classes,
        head_hidden_dim (int): The param of MultiHeadAttention,
        head_dim_feedforward (int): The param of MultiHeadAttention,
        nhead (int): The param of MultiHeadAttention,
        head_dropout (float): The p of dropout,
        head_cls (int): The number of class head,
        head_reg (int): The number of regression head,
        head_dim_dynamic (int): The hidden dim of DynamicConv,
        head_num_dynamic (int): The number of DynamicConv's param,
        head_num_heads (int): The number of RCNNHead,
        deep_supervision (bool): whether to supervise the intermediate results,
        num_proposals (int): the number of proposal boxes and features
    '''
    __inject__ = ['loss_func']
    __shared__ = ['num_classes']

    def __init__(
            self,
            head_hidden_dim,
            head_dim_feedforward,
            nhead,
            head_dropout,
            head_cls,
            head_reg,
            head_dim_dynamic,
            head_num_dynamic,
            head_num_heads,
            deep_supervision,
            num_proposals,
            num_classes=80,
            loss_func="SparseRCNNLoss",
            roi_input_shape=None, ):
        super().__init__()

        # Build RoI.
        box_pooler = self._init_box_pooler(roi_input_shape)
        self.box_pooler = box_pooler

        # Build heads.
        rcnn_head = RCNNHead(
            head_hidden_dim,
            num_classes,
            head_dim_feedforward,
            nhead,
            head_dropout,
            head_cls,
            head_reg,
            head_dim_dynamic,
            head_num_dynamic, )
        self.head_series = nn.LayerList(
            [copy.deepcopy(rcnn_head) for i in range(head_num_heads)])
        self.return_intermediate = deep_supervision

        self.num_classes = num_classes

        # build init proposal
        self.init_proposal_features = nn.Embedding(num_proposals,
                                                   head_hidden_dim)
        self.init_proposal_boxes = nn.Embedding(num_proposals, 4)

        self.lossfunc = loss_func

        # Init parameters.
        init.reset_initialized_parameter(self)
        self._reset_parameters()

    def _reset_parameters(self):
        # init all parameters.
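        # The classification bias is set to -log((1 - p) / p) so every class
        # starts with a predicted probability of roughly prior_prob, the
        # focal-loss-style prior initialization used by SparseR-CNN.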
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)

        for m in self.sublayers():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight, reverse=True)
            elif not isinstance(m, nn.Embedding) and hasattr(
                    m, "weight") and m.weight.dim() > 1:
                init.xavier_normal_(m.weight, reverse=False)

            if hasattr(m, "bias") and m.bias is not None and m.bias.shape[
                    -1] == self.num_classes:
                init.constant_(m.bias, bias_value)

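        # Every learnable proposal box starts as the whole image in normalized
        # (cx, cy, w, h) form: center (0.5, 0.5), size (1.0, 1.0).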
        init_bboxes = paddle.empty_like(self.init_proposal_boxes.weight)
        init_bboxes[:, :2] = 0.5
        init_bboxes[:, 2:] = 1.0
        self.init_proposal_boxes.weight.set_value(init_bboxes)

    @staticmethod
    def _init_box_pooler(input_shape):

        pooler_resolution = 7
        sampling_ratio = 2

        if input_shape is not None:
            pooler_scales = tuple(1.0 / input_shape[k].stride
                                  for k in range(len(input_shape)))
            in_channels = [
                input_shape[f].channels for f in range(len(input_shape))
            ]
            end_level = len(input_shape) - 1
            # Check all channel counts are equal
            assert len(set(in_channels)) == 1, in_channels
        else:
            pooler_scales = [1.0 / 4.0, 1.0 / 8.0, 1.0 / 16.0, 1.0 / 32.0]
            end_level = 3

        box_pooler = RoIAlign(
            resolution=pooler_resolution,
            spatial_scale=pooler_scales,
            sampling_ratio=sampling_ratio,
            end_level=end_level,
            aligned=True)
        return box_pooler

    def forward(self, features, input_whwh):

        bs = len(features[0])
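        # Initial proposals are the learned normalized (cx, cy, w, h) boxes,
        # converted to (x1, y1, x2, y2) and scaled by each image's
        # (w, h, w, h) to obtain absolute coordinates.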
        bboxes = box_cxcywh_to_xyxy(self.init_proposal_boxes.weight.clone(
        )).unsqueeze(0)
        bboxes = bboxes * input_whwh.unsqueeze(-2)

        init_features = self.init_proposal_features.weight.unsqueeze(0).tile(
            [1, bs, 1])
        proposal_features = init_features.clone()

        inter_class_logits = []
        inter_pred_bboxes = []

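        # Iterative refinement: each RCNNHead stage consumes the boxes and
        # proposal features produced by the previous stage; boxes are detached
        # between stages so gradients do not propagate through the predicted
        # box coordinates.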
        for rcnn_head in self.head_series:
            class_logits, pred_bboxes, proposal_features = rcnn_head(
                features, bboxes, proposal_features, self.box_pooler)

            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            bboxes = pred_bboxes.detach()

        output = {
            'pred_logits': inter_class_logits[-1],
            'pred_boxes': inter_pred_bboxes[-1]
        }
        if self.return_intermediate:
            output['aux_outputs'] = [{
                'pred_logits': a,
                'pred_boxes': b
            } for a, b in zip(inter_class_logits[:-1], inter_pred_bboxes[:-1])]

        return output

    def get_loss(self, outputs, targets):
        losses = self.lossfunc(outputs, targets)
        weight_dict = self.lossfunc.weight_dict

        for k in losses.keys():
            if k in weight_dict:
                losses[k] *= weight_dict[k]

        return losses


def box_cxcywh_to_xyxy(x):
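    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last axis.

    Example (illustrative values):
        box_cxcywh_to_xyxy(paddle.to_tensor([[10., 10., 4., 6.]]))
        # -> [[8., 7., 12., 13.]]
    """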
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return paddle.stack(b, axis=-1)