You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

490 lines
20 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddlers.models.ppdet.core.workspace import register, create, load_config
from paddlers.models.ppdet.modeling import ops
from paddlers.models.ppdet.utils.checkpoint import load_pretrain_weight
from paddlers.models.ppdet.utils.logger import setup_logger
# Module-level logger, named after this module.
logger = setup_logger(__name__)
class DistillModel(nn.Layer):
    """
    Wrap a student detector and a frozen teacher detector for distillation.

    Args:
        cfg: The student config; provides `architecture` and
            `pretrain_weights`.
        slim_cfg: Path/handle passed to `load_config`; the loaded config
            provides the teacher `architecture`, `pretrain_weights` and
            the `distill_loss` to create.
    """

    def __init__(self, cfg, slim_cfg):
        super().__init__()

        # Student: built and initialised from `cfg` before the slim config
        # is loaded.
        self.student_model = create(cfg.architecture)
        student_weights = cfg.pretrain_weights
        logger.debug(
            'Load student model pretrain_weights:{}'.format(student_weights))
        load_pretrain_weight(self.student_model, student_weights)

        # Teacher and distillation loss: both come from the slim config.
        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        teacher_weights = slim_cfg.pretrain_weights
        logger.debug(
            'Load teacher model pretrain_weights:{}'.format(teacher_weights))
        load_pretrain_weight(self.teacher_model, teacher_weights)

        # The teacher is never updated.
        for teacher_param in self.teacher_model.parameters():
            teacher_param.trainable = False

    def parameters(self):
        """Return only the student's parameters."""
        return self.student_model.parameters()

    def forward(self, inputs):
        """
        Run the wrapped models on `inputs`.

        In eval mode this returns the student's output unchanged. In
        training mode the teacher runs first, then the student, and the
        student's loss dict is returned with two extra entries
        ('distill_loss', 'teacher_loss') and with 'distill_loss' added
        into 'loss'.
        """
        if not self.training:
            return self.student_model(inputs)

        teacher_out = self.teacher_model(inputs)
        outputs = self.student_model(inputs)

        # The distillation loss receives the models themselves rather than
        # their outputs.
        outputs['distill_loss'] = self.distill_loss(self.teacher_model,
                                                    self.student_model)
        outputs['teacher_loss'] = teacher_out['loss']
        outputs['loss'] += outputs['distill_loss']
        return outputs
class FGDDistillModel(nn.Layer):
    """
    Build FGD distill model.

    The student is trained with its own head loss plus one distillation
    loss per name in `distill_loss_name`; each of those losses compares the
    student and teacher neck features at the matching index. The teacher is
    put in eval mode and its parameters are marked non-trainable.

    Only the "RetinaNet" and "PicoDet" architectures are handled; any other
    `cfg.architecture` makes `forward` raise ValueError.

    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(FGDDistillModel, self).__init__()

        self.is_inherit = True
        # build student model before load slim config
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        # Captured before the slim config is loaded, as the original code
        # did (presumably because load_config can alter the active config;
        # TODO confirm). A missing key is treated as "no student weights"
        # instead of raising KeyError.
        stu_pretrain = (cfg['pretrain_weights']
                        if 'pretrain_weights' in cfg else None)

        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg

        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        has_tea_pretrain = bool('pretrain_weights' in self.teacher_cfg and
                                self.teacher_cfg.pretrain_weights)

        if stu_pretrain:
            if self.is_inherit and has_tea_pretrain:
                # Teacher weights go into the student first; the student's
                # own weights are loaded on top right after.
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if has_tea_pretrain:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        # NOTE(review): the losses are kept in a plain dict, so they are
        # presumably not registered as sublayers of this model -- confirm
        # whether their parameters are meant to be optimised.
        self.fgd_loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=('neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                              'neck_f_0')):
        """
        Create one distillation loss per entry of `name_list`.

        Args:
            cfg: Passed unchanged to `create` for every loss.
            name_list: Iterable of keys for the returned dict. The default
                is an immutable tuple so it cannot be shared-and-mutated
                across calls.

        Returns:
            dict: name -> a separately created loss object.
        """
        return {name: create(cfg) for name in name_list}

    def forward(self, inputs):
        """
        Training: return the student's loss dict, with every distillation
        loss stored under its own name and also added into 'loss'.
        Eval: return {'bbox': ..., 'bbox_num': ...} from the student.

        Raises:
            ValueError: if `self.arch` is neither "RetinaNet" nor "PicoDet".
        """
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)

            # Teacher features are references only; no gradient is tracked.
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            # The i-th loss name is paired with the i-th neck feature.
            loss_dict = {}
            for idx, k in enumerate(self.fgd_loss_dic):
                loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
                                                    t_neck_feats[idx], inputs)

            if self.arch == "RetinaNet":
                loss = self.student_model.head(s_neck_feats, inputs)
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    s_neck_feats, self.student_model.export_post_process)
                loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
                total_loss = paddle.add_n(list(loss_gfl.values()))
                loss = {}
                loss.update(loss_gfl)
                loss.update({'loss': total_loss})
            else:
                raise ValueError(f"Unsupported model {self.arch}")

            for k, distill_loss in loss_dict.items():
                loss['loss'] += distill_loss
                loss[k] = distill_loss
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            if self.arch == "RetinaNet":
                # The head is invoked only inside the branch that uses its
                # output; previously it also ran (and was discarded) before
                # the PicoDet branch called it again.
                head_outs = self.student_model.head(neck_feats)
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    neck_feats, self.student_model.export_post_process)
                scale_factor = inputs['scale_factor']
                bboxes, bbox_num = self.student_model.head.post_process(
                    head_outs,
                    scale_factor,
                    export_nms=self.student_model.export_nms)
                return {'bbox': bboxes, 'bbox_num': bbox_num}
            else:
                raise ValueError(f"Unsupported model {self.arch}")
@register
class DistillYOLOv3Loss(nn.Layer):
    """
    Distillation loss built from the `distill_pairs` that the teacher's and
    student's `yolo_head.loss` expose.

    Args:
        weight (int|float): Factor applied to the summed loss. Default 1000.
    """

    def __init__(self, weight=1000):
        super().__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        """Box term: x/y use sigmoid cross-entropy against sigmoid(teacher),
        w/h use absolute difference; the sum is weighted by sigmoid(tobj)
        and averaged."""
        terms = [
            ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx)),
            ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty)),
            paddle.abs(sw - tw),
            paddle.abs(sh - th),
        ]
        box_loss = paddle.add_n(terms)
        return paddle.mean(box_loss * F.sigmoid(tobj))

    def obj_weighted_cls(self, scls, tcls, tobj):
        """Class term: sigmoid cross-entropy against sigmoid(teacher),
        weighted by sigmoid(tobj) and averaged."""
        cls_loss = ops.sigmoid_cross_entropy_with_logits(scls,
                                                         F.sigmoid(tcls))
        teacher_obj = F.sigmoid(tobj)
        return paddle.mean(paddle.multiply(cls_loss, teacher_obj))

    def obj_loss(self, sobj, tobj):
        """Objectness term: the target is a fixed 0/1 mask of `tobj > 0`."""
        target = paddle.cast(tobj > 0., dtype="float32")
        target.stop_gradient = True
        per_elem = ops.sigmoid_cross_entropy_with_logits(sobj, target)
        return paddle.mean(per_elem)

    def forward(self, teacher_model, student_model):
        """
        Sum the box, class and objectness terms over all student/teacher
        pairs and scale the total by `self.weight`.

        Each pair is indexed as [0..3] box values, [4] objectness,
        [5] class values.
        """
        t_pairs = teacher_model.yolo_head.loss.distill_pairs
        s_pairs = student_model.yolo_head.loss.distill_pairs

        reg_terms = []
        cls_terms = []
        obj_terms = []
        for stu, tea in zip(s_pairs, t_pairs):
            tea_obj = tea[4]
            reg_terms.append(
                self.obj_weighted_reg(stu[0], stu[1], stu[2], stu[3], tea[0],
                                      tea[1], tea[2], tea[3], tea_obj))
            cls_terms.append(self.obj_weighted_cls(stu[5], tea[5], tea_obj))
            obj_terms.append(self.obj_loss(stu[4], tea_obj))

        reg_total = paddle.add_n(reg_terms)
        cls_total = paddle.add_n(cls_terms)
        obj_total = paddle.add_n(obj_terms)
        return (reg_total + cls_total + obj_total) * self.weight
def parameter_init(mode="kaiming", value=0.):
    """
    Build a `ParamAttr` carrying a weight initializer.

    Args:
        mode (str): "constant" selects a Constant initializer; "kaiming"
            and every other value select KaimingUniform.
        value (float): Fill value, used only when mode is "constant".

    Returns:
        ParamAttr: attribute wrapping the chosen initializer.
    """
    if mode == "constant":
        initializer = paddle.nn.initializer.Constant(value=value)
    else:
        # Covers "kaiming" as well as any unrecognised mode.
        initializer = paddle.nn.initializer.KaimingUniform()
    return ParamAttr(initializer=initializer)
@register
class FGDFeatureLoss(nn.Layer):
    """
    Adapted from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
    Paddle version of `Focal and Global Knowledge Distillation for Detectors`

    Args:
        student_channels(int): The number of channels in the student's FPN feature map. Default to 256.
        teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256.
        temp (float, optional): The temperature coefficient. Defaults to 0.5.
        alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001
        lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005
    """

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005):
        super(FGDFeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        kaiming_init = parameter_init("kaiming")
        zeros_init = parameter_init("constant", 0.0)

        # A 1x1 conv maps student features onto the teacher's channel count
        # when the two differ; afterwards both sides use teacher_channels.
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init)
            student_channels = teacher_channels
        else:
            self.align = None

        # Single-channel 1x1 convs producing the context masks used by
        # spatial_pool.
        self.conv_mask_s = nn.Conv2D(
            student_channels, 1, kernel_size=1, weight_attr=kaiming_init)
        self.conv_mask_t = nn.Conv2D(
            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)

        # Bottleneck blocks applied to the pooled context in relation_loss.
        # Their conv weights start at zero.
        self.stu_conv_block = nn.Sequential(
            nn.Conv2D(
                student_channels,
                student_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([student_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                student_channels // 2,
                student_channels,
                kernel_size=1,
                weight_attr=zeros_init))
        self.tea_conv_block = nn.Sequential(
            nn.Conv2D(
                teacher_channels,
                teacher_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels // 2,
                teacher_channels,
                kernel_size=1,
                weight_attr=zeros_init))

    def spatial_channel_attention(self, x, t=0.5):
        """
        Compute spatial and channel attention from the absolute values of x.

        Args:
            x (Tensor): 4-D feature map, unpacked here as N, C, H, W.
            t (float): Temperature dividing the values before softmax.

        Returns:
            list: [spatial_att of shape [N, H, W], channel_att with the
            channel axis at position 1]. The softmax outputs are scaled by
            H * W and by C respectively.
        """
        shape = paddle.shape(x)
        N, C, H, W = shape
        _f = paddle.abs(x)

        # Spatial: channel-mean, softmax over all H*W positions.
        spatial_map = paddle.reshape(
            paddle.mean(
                _f, axis=1, keepdim=True) / t, [N, -1])
        spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W
        spatial_att = paddle.reshape(spatial_map, [N, H, W])

        # Channel: mean over both spatial axes, softmax over channels.
        channel_map = paddle.mean(
            paddle.mean(
                _f, axis=2, keepdim=False), axis=2, keepdim=False)
        channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C
        return [spatial_att, channel_att]

    def spatial_pool(self, x, mode="teacher"):
        """
        Pool x over its spatial positions with a learned softmax mask.

        Args:
            x (Tensor): 4-D feature map (batch, channel, height, width).
            mode (str): "student" (case-insensitive) uses conv_mask_s; any
                other value uses conv_mask_t.

        Returns:
            Tensor: context of shape [batch, channel, 1, 1].
        """
        batch, channel, height, width = x.shape
        x_copy = x
        x_copy = paddle.reshape(x_copy, [batch, channel, height * width])
        x_copy = x_copy.unsqueeze(1)
        if mode.lower() == "student":
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)

        context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
        context_mask = F.softmax(context_mask, axis=2)
        context_mask = context_mask.unsqueeze(-1)
        # [batch, 1, channel, HW] x [batch, 1, HW, 1] -> [batch, 1, channel, 1]
        context = paddle.matmul(x_copy, context_mask)
        context = paddle.reshape(context, [batch, channel, 1, 1])
        return context

    def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
                  tea_spatial_att):
        """Sum of absolute student/teacher attention differences, each
        divided by the length of its first axis."""

        def _func(a, b):
            return paddle.sum(paddle.abs(a - b)) / len(a)

        mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
            stu_spatial_att, tea_spatial_att)
        return mask_loss

    def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
                     tea_channel_att, tea_spatial_att):
        """
        Foreground and background feature losses.

        Both feature maps are scaled by the square roots of the teacher's
        spatial and channel attention and of the fg/bg mask, then compared
        with a summed MSE divided by the length of the mask's first axis.

        Returns:
            tuple: (fg_loss, bg_loss)
        """
        # Add singleton axes so masks and attentions broadcast against the
        # 4-D features.
        Mask_fg = Mask_fg.unsqueeze(axis=1)
        Mask_bg = Mask_bg.unsqueeze(axis=1)
        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
        tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)

        fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
        fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))

        # The student is weighted with the teacher's attention as well.
        fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
        fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))

        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)
        return fg_loss, bg_loss

    def relation_loss(self, stu_feature, tea_feature):
        """Summed MSE between student and teacher features after each has
        its pooled, conv-transformed context added; divided by the length
        of the first axis."""
        context_s = self.spatial_pool(stu_feature, "student")
        context_t = self.spatial_pool(tea_feature, "teacher")
        out_s = stu_feature + self.stu_conv_block(context_s)
        out_t = tea_feature + self.tea_conv_block(context_t)
        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
        return rela_loss

    def mask_value(self, mask, xl, xr, yl, yr, value):
        """Raise mask[xl:xr, yl:yr] to at least `value` (elementwise
        maximum) and return the mask."""
        mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
        return mask

    def forward(self, stu_feature, tea_feature, inputs):
        """Forward function.

        Args:
            stu_feature(Tensor): Bs*C*H*W, student's feature map
            tea_feature(Tensor): Bs*C*H*W, teacher's feature map
            inputs: The inputs with gt bbox and input shape info.

        Returns:
            Tensor: alpha_fgd * fg_loss + beta_fgd * bg_loss
            + gamma_fgd * mask_loss + lambda_fgd * rela_loss.
        """
        # Compare the student's spatial size with the teacher's (the
        # previous check compared the student with itself).
        assert stu_feature.shape[-2:] == tea_feature.shape[-2:], \
            f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
        assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
        ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs."
        gt_bboxes = inputs['gt_bbox']
        ins_shape = [
            inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
        ]

        index_gt = [
            i for i in range(len(gt_bboxes)) if gt_bboxes[i].size > 2
        ]
        # only distill feature with labeled GTbox
        # NOTE(review): a batch in which no image passes the filter
        # (empty index_gt) is not handled specially here.
        if len(index_gt) != len(gt_bboxes):
            index_gt_t = paddle.to_tensor(index_gt)
            # Filter the feature maps themselves so that their batch axis
            # stays aligned with the filtered gt_bboxes / ins_shape.
            stu_feature = paddle.index_select(stu_feature, index_gt_t)
            tea_feature = paddle.index_select(tea_feature, index_gt_t)

            ins_shape = [ins_shape[c] for c in index_gt]
            gt_bboxes = [gt_bboxes[c] for c in index_gt]
            assert len(gt_bboxes) == tea_feature.shape[
                0], f"The number of selected GT box [{len(gt_bboxes)}] should be same with first dim of input tensor [{tea_feature.shape[0]}]."

        if self.align is not None:
            stu_feature = self.align(stu_feature)

        N, C, H, W = stu_feature.shape

        tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
            tea_feature, self.temp)
        stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
            stu_feature, self.temp)

        # Per-image [H, W] masks; constants, so excluded from autograd.
        Mask_fg = paddle.zeros(tea_spatial_att.shape)
        Mask_bg = paddle.ones_like(tea_spatial_att)
        one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
        zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
        Mask_fg.stop_gradient = True
        Mask_bg.stop_gradient = True
        one_tmp.stop_gradient = True
        zero_tmp.stop_gradient = True

        wmin, wmax, hmin, hmax = [], [], [], []

        for i in range(N):
            # Rescale box coordinates to feature-map cells: columns 0/2 by
            # ins_shape[i][1] -> W, columns 1/3 by ins_shape[i][0] -> H
            # (assumes boxes are x1, y1, x2, y2 and im_shape is (h, w) --
            # TODO confirm).
            tmp_box = paddle.ones_like(gt_bboxes[i])
            tmp_box.stop_gradient = True
            tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W
            tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W
            tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H
            tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H

            zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32")
            zero.stop_gradient = True
            # Lower bounds are floored and clamped at 0; upper bounds are
            # ceiled.
            wmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
            wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
            hmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero))
            hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32"))

            # Reciprocal of each box's area in cells (inclusive bounds).
            area_recip = 1.0 / (
                hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
                    wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))

            # Where boxes overlap, the larger value is kept (see
            # mask_value).
            for j in range(len(gt_bboxes[i])):
                Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
                                             hmax[i][j] + 1, wmin[i][j],
                                             wmax[i][j] + 1, area_recip[0][j])

            # Background is every cell not covered by a box, normalised to
            # sum to 1 unless it is empty.
            Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)

            if paddle.sum(Mask_bg[i]):
                Mask_bg[i] /= paddle.sum(Mask_bg[i])

        fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
                                             Mask_bg, tea_channel_att,
                                             tea_spatial_att)
        mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
                                   stu_spatial_att, tea_spatial_att)
        rela_loss = self.relation_loss(stu_feature, tea_feature)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

        return loss