# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr

from paddlers.models.ppdet.core.workspace import register, create, load_config
from paddlers.models.ppdet.modeling import ops
from paddlers.models.ppdet.utils.checkpoint import load_pretrain_weight
from paddlers.models.ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)


class DistillModel(nn.Layer):
    """
    Build a teacher-student distillation model.

    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill-loss config.
    """

    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()

        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights: {}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        logger.debug('Load teacher model pretrain_weights: {}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        # The teacher is frozen; only the student is optimized.
        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self):
        # Expose only the student parameters to the optimizer.
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            teacher_loss = self.teacher_model(inputs)
            student_loss = self.student_model(inputs)
            loss = self.distill_loss(self.teacher_model, self.student_model)
            student_loss['distill_loss'] = loss
            student_loss['teacher_loss'] = teacher_loss['loss']
            student_loss['loss'] += student_loss['distill_loss']
            return student_loss
        else:
            return self.student_model(inputs)
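

# Usage sketch for `DistillModel` above (a minimal example; the config paths
# are hypothetical and assume the usual ppdet YAML layout with `architecture`
# and `pretrain_weights` fields):
#
#     cfg = load_config('configs/yolov3_student.yml')
#     model = DistillModel(cfg, slim_cfg='configs/slim/yolov3_distill.yml')
#     out = model(inputs)  # training: student losses plus 'distill_loss' and
#                          # 'teacher_loss' entries; eval: student predictions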


class FGDDistillModel(nn.Layer):
    """
    Build FGD distill model.

    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(FGDDistillModel, self).__init__()

        self.is_inherit = True
        # Build the student model before loading the slim config.
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        stu_pretrain = cfg['pretrain_weights']
        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg

        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        if 'pretrain_weights' in cfg and stu_pretrain:
            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
                # Initialize the student from the teacher weights first,
                # then overwrite with the student's own pretrained weights.
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        self.fgd_loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=[
                       'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                       'neck_f_0'
                   ]):
        # One independent distill-loss instance per named FPN level.
        loss_func = dict()
        for k in name_list:
            loss_func[k] = create(cfg)
        return loss_func
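
    # Note: `fgd_loss_dic` preserves insertion order (Python 3.7+ dicts), so
    # in `forward` the idx-th entry of `name_list` is paired with the idx-th
    # neck feature map. With the default names this looks like:
    #     {'neck_f_4': <FGDFeatureLoss>, ..., 'neck_f_0': <FGDFeatureLoss>}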

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)

            # Teacher features are used as distillation targets only.
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            loss_dict = {}
            for idx, k in enumerate(self.fgd_loss_dic):
                loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
                                                    t_neck_feats[idx], inputs)
            if self.arch == "RetinaNet":
                loss = self.student_model.head(s_neck_feats, inputs)
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    s_neck_feats, self.student_model.export_post_process)
                loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
                total_loss = paddle.add_n(list(loss_gfl.values()))
                loss = {}
                loss.update(loss_gfl)
                loss.update({'loss': total_loss})
            else:
                raise ValueError(f"Unsupported model {self.arch}")
            # Total loss = detection loss + per-level FGD feature losses.
            for k in loss_dict:
                loss['loss'] += loss_dict[k]
                loss[k] = loss_dict[k]
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "PicoDet":
                # Recompute head outputs with the export post-process flag.
                head_outs = self.student_model.head(
                    neck_feats, self.student_model.export_post_process)
                scale_factor = inputs['scale_factor']
                bboxes, bbox_num = self.student_model.head.post_process(
                    head_outs,
                    scale_factor,
                    export_nms=self.student_model.export_nms)
                return {'bbox': bboxes, 'bbox_num': bbox_num}
            else:
                raise ValueError(f"Unsupported model {self.arch}")
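

# Usage sketch for `FGDDistillModel` above (hypothetical config paths; the
# slim config is expected to define `distill_loss` and `distill_loss_name`):
#
#     cfg = load_config('configs/retinanet_student.yml')
#     model = FGDDistillModel(cfg, 'configs/slim/distill/retinanet_fgd.yml')
#     losses = model(inputs)  # training: det loss + one entry per FPN level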


@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        # Box regression loss, weighted by the teacher's objectness.
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        # Classification loss, weighted by the teacher's objectness.
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        # Objectness loss against a binarized teacher objectness mask.
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
        student_distill_pairs = student_model.yolo_head.loss.distill_pairs
        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
            # Each pair holds (x, y, w, h, obj, cls) predictions per level.
            distill_reg_loss.append(
                self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
                    3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
            distill_cls_loss.append(
                self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
            distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.weight
        return loss
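

# Summary of the objective implemented by DistillYOLOv3Loss above, per
# feature level (summed over levels, then scaled by `weight`):
#     L_reg = mean[ sigmoid(t_obj) * (BCE(s_x, sig(t_x)) + BCE(s_y, sig(t_y))
#                                     + |s_w - t_w| + |s_h - t_h|) ]
#     L_cls = mean[ sigmoid(t_obj) * BCE(s_cls, sig(t_cls)) ]
#     L_obj = mean[ BCE(s_obj, 1[t_obj > 0]) ]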


def parameter_init(mode="kaiming", value=0.):
    if mode == "kaiming":
        weight_attr = paddle.nn.initializer.KaimingUniform()
    elif mode == "constant":
        weight_attr = paddle.nn.initializer.Constant(value=value)
    else:
        # Unknown modes fall back to Kaiming-uniform initialization.
        weight_attr = paddle.nn.initializer.KaimingUniform()

    weight_init = ParamAttr(initializer=weight_attr)
    return weight_init
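

# For example, `parameter_init("constant", 0.)` yields a ParamAttr whose
# initializer writes zeros; FGDFeatureLoss below uses it to zero-initialize
# its 1x1 conv blocks.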


@register
class FGDFeatureLoss(nn.Layer):
    """
    Paddle version of `Focal and Global Knowledge Distillation for Detectors`.

    The code is referenced from
    https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py

    Args:
        student_channels (int): The number of channels in the student's FPN feature map. Defaults to 256.
        teacher_channels (int): The number of channels in the teacher's FPN feature map. Defaults to 256.
        temp (float, optional): The temperature coefficient. Defaults to 0.5.
        alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001.
        beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005.
        gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001.
        lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005.
    """

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005):
        super(FGDFeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        kaiming_init = parameter_init("kaiming")
        zeros_init = parameter_init("constant", 0.0)

        if student_channels != teacher_channels:
            # 1x1 conv to project student features to the teacher's width.
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init)
            student_channels = teacher_channels
        else:
            self.align = None

        self.conv_mask_s = nn.Conv2D(
            student_channels, 1, kernel_size=1, weight_attr=kaiming_init)
        self.conv_mask_t = nn.Conv2D(
            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)

        self.stu_conv_block = nn.Sequential(
            nn.Conv2D(
                student_channels,
                student_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([student_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                student_channels // 2,
                student_channels,
                kernel_size=1,
                weight_attr=zeros_init))
        self.tea_conv_block = nn.Sequential(
            nn.Conv2D(
                teacher_channels,
                teacher_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels // 2,
                teacher_channels,
                kernel_size=1,
                weight_attr=zeros_init))
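
    # FGD attention (computed by `spatial_channel_attention` below) for a
    # feature map F with temperature t:
    #     spatial:  A_s = H * W * softmax_over_HW(mean_C |F| / t)
    #     channel:  A_c = C * softmax_over_C(mean_HW |F| / t)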

    def spatial_channel_attention(self, x, t=0.5):
        shape = paddle.shape(x)
        N, C, H, W = shape

        _f = paddle.abs(x)
        spatial_map = paddle.reshape(
            paddle.mean(
                _f, axis=1, keepdim=True) / t, [N, -1])
        spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W
        spatial_att = paddle.reshape(spatial_map, [N, H, W])

        channel_map = paddle.mean(
            paddle.mean(
                _f, axis=2, keepdim=False), axis=2, keepdim=False)
        channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C
        return [spatial_att, channel_att]
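
    # `spatial_pool` implements the GCNet-style context modeling used for the
    # relation loss: a 1x1 conv produces a per-position weight, and the
    # softmax-weighted sum over all H*W positions yields one C-dim context
    # vector per image.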

    def spatial_pool(self, x, mode="teacher"):
        batch, channel, height, width = x.shape
        x_copy = x
        x_copy = paddle.reshape(x_copy, [batch, channel, height * width])
        x_copy = x_copy.unsqueeze(1)
        if mode.lower() == "student":
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)

        context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
        context_mask = F.softmax(context_mask, axis=2)
        context_mask = context_mask.unsqueeze(-1)
        context = paddle.matmul(x_copy, context_mask)
        context = paddle.reshape(context, [batch, channel, 1, 1])

        return context

    def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
                  tea_spatial_att):
        def _func(a, b):
            return paddle.sum(paddle.abs(a - b)) / len(a)

        mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
            stu_spatial_att, tea_spatial_att)

        return mask_loss
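
    # `feature_loss` below is the focal part of FGD: features are rescaled by
    # the square roots of the teacher's spatial and channel attention, split
    # into foreground/background with the GT-box masks, and compared with a
    # sum-reduced MSE averaged over the batch, i.e. roughly
    #     L_fg = || sqrt(M_fg * A_s * A_c) * (F_s - F_t) ||^2 / N
    # and similarly for L_bg with M_bg.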

    def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
                     tea_channel_att, tea_spatial_att):

        Mask_fg = Mask_fg.unsqueeze(axis=1)
        Mask_bg = Mask_bg.unsqueeze(axis=1)

        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)

        tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)

        fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
        fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))

        fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
        fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))

        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)

        return fg_loss, bg_loss

    def relation_loss(self, stu_feature, tea_feature):
        context_s = self.spatial_pool(stu_feature, "student")
        context_t = self.spatial_pool(tea_feature, "teacher")

        out_s = stu_feature + self.stu_conv_block(context_s)
        out_t = tea_feature + self.tea_conv_block(context_t)

        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)

        return rela_loss

    def mask_value(self, mask, xl, xr, yl, yr, value):
        mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
        return mask

    def forward(self, stu_feature, tea_feature, inputs):
        """Forward function.

        Args:
            stu_feature (Tensor): Bs*C*H*W, student's feature map.
            tea_feature (Tensor): Bs*C*H*W, teacher's feature map.
            inputs: The inputs with gt bbox and input shape info.
        """
        assert stu_feature.shape[-2:] == tea_feature.shape[-2:], \
            f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
        assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
        ), "ERROR! FGDFeatureLoss needs gt_bbox and im_shape as inputs."
        gt_bboxes = inputs['gt_bbox']
        ins_shape = [
            inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
        ]

        index_gt = []
        for i in range(len(gt_bboxes)):
            if gt_bboxes[i].size > 2:
                index_gt.append(i)
        # Only distill features for samples with labeled GT boxes.
        if len(index_gt) != len(gt_bboxes):
            index_gt_t = paddle.to_tensor(index_gt)
            stu_feature = paddle.index_select(stu_feature, index_gt_t)
            tea_feature = paddle.index_select(tea_feature, index_gt_t)

            ins_shape = [ins_shape[c] for c in index_gt]
            gt_bboxes = [gt_bboxes[c] for c in index_gt]
            assert len(gt_bboxes) == tea_feature.shape[
                0], f"The number of selected GT boxes [{len(gt_bboxes)}] should match the first dim of the input tensor [{tea_feature.shape[0]}]."

        if self.align is not None:
            stu_feature = self.align(stu_feature)

        N, C, H, W = stu_feature.shape

        tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
            tea_feature, self.temp)
        stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
            stu_feature, self.temp)

        Mask_fg = paddle.zeros(tea_spatial_att.shape)
        Mask_bg = paddle.ones_like(tea_spatial_att)
        one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
        zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
        Mask_fg.stop_gradient = True
        Mask_bg.stop_gradient = True
        one_tmp.stop_gradient = True
        zero_tmp.stop_gradient = True

        wmin, wmax, hmin, hmax, area = [], [], [], [], []

        for i in range(N):
            # Rescale GT boxes from image coordinates to feature-map cells.
            tmp_box = paddle.ones_like(gt_bboxes[i])
            tmp_box.stop_gradient = True
            tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W
            tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W
            tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H
            tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H

            zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32")
            ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
            zero.stop_gradient = True
            ones.stop_gradient = True

            wmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
            wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
            hmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero))
            hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32"))

            area_recip = 1.0 / (
                hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
                    wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))

            # Foreground mask: each cell inside a box gets the reciprocal of
            # that box's area; background is everything left over.
            for j in range(len(gt_bboxes[i])):
                Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
                                             hmax[i][j] + 1, wmin[i][j],
                                             wmax[i][j] + 1, area_recip[0][j])

            Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)

            if paddle.sum(Mask_bg[i]):
                Mask_bg[i] /= paddle.sum(Mask_bg[i])

        fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
                                             Mask_bg, tea_channel_att,
                                             tea_spatial_att)
        mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
                                   stu_spatial_att, tea_spatial_att)
        rela_loss = self.relation_loss(stu_feature, tea_feature)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
            + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

        return loss
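

# Minimal smoke-test sketch for FGDFeatureLoss (hypothetical shapes and boxes,
# assuming a working Paddle install; the dict keys mirror what `forward`
# asserts on):
#
#     loss_fn = FGDFeatureLoss(student_channels=256, teacher_channels=256)
#     feat_s = paddle.rand([2, 256, 32, 32])
#     feat_t = paddle.rand([2, 256, 32, 32])
#     inputs = {
#         'gt_bbox': [paddle.to_tensor([[10., 10., 100., 120.]]),
#                     paddle.to_tensor([[30., 40., 200., 220.]])],
#         'im_shape': paddle.to_tensor([[512., 512.], [512., 512.]]),
#     }
#     loss = loss_fn(feat_s, feat_t, inputs)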