diff --git a/tools/upgrade_model_version.py b/tools/upgrade_model_version.py index 5d17d59e9..68d2fb2b5 100644 --- a/tools/upgrade_model_version.py +++ b/tools/upgrade_model_version.py @@ -57,7 +57,7 @@ def reorder_cls_channel(val, num_classes=81): # fc_cls elif out_channels == num_classes: new_val = torch.cat((val[1:], val[:1]), dim=0) - # agnostic | retina_cls | rpn_cls + # agnostic | retina_cls else: new_val = val @@ -89,7 +89,7 @@ def truncate_cls_channel(val, num_classes=81): def truncate_reg_channel(val, num_classes=81): # bias if val.dim() == 1: - # fc_reg|rpn_reg + # fc_reg if val.size(0) % num_classes == 0: new_val = val.reshape(num_classes, -1)[:num_classes - 1] new_val = new_val.reshape(-1) @@ -99,7 +99,7 @@ def truncate_reg_channel(val, num_classes=81): # weight else: out_channels, in_channels = val.shape[:2] - # fc_reg|rpn_reg + # fc_reg if out_channels % num_classes == 0: new_val = val.reshape(num_classes, -1, in_channels, *val.shape[2:])[1:] @@ -137,14 +137,14 @@ def convert(in_file, out_file, num_classes): # classification m = re.search( - r'(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|' + r'(conv_cls|retina_cls|fc_cls|fcos_cls|' r'fovea_cls).(weight|bias)', new_key) if m is not None: print(f'reorder cls channels of {new_key}') new_val = reorder_cls_channel(val, num_classes) # regression - m = re.search(r'(fc_reg|rpn_reg).(weight|bias)', new_key) + m = re.search(r'(fc_reg).(weight|bias)', new_key) if m is not None and not reg_cls_agnostic: print(f'truncate regression channels of {new_key}') new_val = truncate_reg_channel(val, num_classes)