|
|
|
@ -57,7 +57,7 @@ def reorder_cls_channel(val, num_classes=81): |
|
|
|
|
# fc_cls |
|
|
|
|
elif out_channels == num_classes: |
|
|
|
|
new_val = torch.cat((val[1:], val[:1]), dim=0) |
|
|
|
|
# agnostic | retina_cls | rpn_cls |
|
|
|
|
# agnostic | retina_cls |
|
|
|
|
else: |
|
|
|
|
new_val = val |
|
|
|
|
|
|
|
|
@ -89,7 +89,7 @@ def truncate_cls_channel(val, num_classes=81): |
|
|
|
|
def truncate_reg_channel(val, num_classes=81): |
|
|
|
|
# bias |
|
|
|
|
if val.dim() == 1: |
|
|
|
|
# fc_reg|rpn_reg |
|
|
|
|
# fc_reg |
|
|
|
|
if val.size(0) % num_classes == 0: |
|
|
|
|
new_val = val.reshape(num_classes, -1)[:num_classes - 1] |
|
|
|
|
new_val = new_val.reshape(-1) |
|
|
|
@ -99,7 +99,7 @@ def truncate_reg_channel(val, num_classes=81): |
|
|
|
|
# weight |
|
|
|
|
else: |
|
|
|
|
out_channels, in_channels = val.shape[:2] |
|
|
|
|
# fc_reg|rpn_reg |
|
|
|
|
# fc_reg |
|
|
|
|
if out_channels % num_classes == 0: |
|
|
|
|
new_val = val.reshape(num_classes, -1, in_channels, |
|
|
|
|
*val.shape[2:])[1:] |
|
|
|
@ -137,14 +137,14 @@ def convert(in_file, out_file, num_classes): |
|
|
|
|
|
|
|
|
|
# classification |
|
|
|
|
m = re.search( |
|
|
|
|
r'(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|' |
|
|
|
|
r'(conv_cls|retina_cls|fc_cls|fcos_cls|' |
|
|
|
|
r'fovea_cls).(weight|bias)', new_key) |
|
|
|
|
if m is not None: |
|
|
|
|
print(f'reorder cls channels of {new_key}') |
|
|
|
|
new_val = reorder_cls_channel(val, num_classes) |
|
|
|
|
|
|
|
|
|
# regression |
|
|
|
|
m = re.search(r'(fc_reg|rpn_reg).(weight|bias)', new_key) |
|
|
|
|
m = re.search(r'(fc_reg).(weight|bias)', new_key) |
|
|
|
|
if m is not None and not reg_cls_agnostic: |
|
|
|
|
print(f'truncate regression channels of {new_key}') |
|
|
|
|
new_val = truncate_reg_channel(val, num_classes) |
|
|
|
|