You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
62 lines
2.1 KiB
62 lines
2.1 KiB
import argparse |
|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the FCOS Detectron2 converter.

    Returns:
        argparse.ArgumentParser: parser exposing two optional arguments,
        ``--model`` (checkpoint to read) and ``--output`` (file to write),
        each with a default path under ``weights/``.
    """
    parser = argparse.ArgumentParser(description="FCOS Detectron2 Converter")
    parser.add_argument(
        "--model",
        default="weights/fcos_R_50_1x_official.pth",
        metavar="FILE",
        help="path to model weights",
    )
    parser.add_argument(
        "--output",
        default="weights/fcos_R_50_1x_converted.pth",
        metavar="FILE",
        # Previously a copy of the --model help text; describe the output.
        help="path to save the converted model weights",
    )
    return parser
|
|
|
|
|
def rename_resnet_param_names(ckpt_state_dict):
    """Return a copy of a state dict with its parameter names rewritten.

    Each key is passed through a fixed, ordered list of substring
    replacements (see ``rules`` below); values are carried over untouched
    and the original key order is preserved.

    Args:
        ckpt_state_dict: mapping from parameter name (str) to value.

    Returns:
        OrderedDict: the same values stored under the renamed keys. If two
        input keys end up with the same new name, the later one wins.
    """
    # Applied top to bottom to every key; later rules see the output of
    # earlier ones, so the order is significant.
    rules = (
        ("module.", ""),
        ("body", "bottom_up"),
        # The leading "." keeps these from matching "fpn_layerN", which is
        # handled by its own rules further down.
        (".layer1", ".res2"),
        (".layer2", ".res3"),
        (".layer3", ".res4"),
        (".layer4", ".res5"),
        ("downsample.0", "shortcut"),
        ("downsample.1", "shortcut.norm"),
        ("bn1", "conv1.norm"),
        ("bn2", "conv2.norm"),
        ("bn3", "conv3.norm"),
        ("fpn_inner2", "fpn_lateral3"),
        ("fpn_inner3", "fpn_lateral4"),
        ("fpn_inner4", "fpn_lateral5"),
        ("fpn_layer2", "fpn_output3"),
        ("fpn_layer3", "fpn_output4"),
        ("fpn_layer4", "fpn_output5"),
        ("top_blocks", "top_block"),
        ("fpn.", ""),
        ("rpn", "proposal_generator"),
        ("head", "fcos_head"),
    )

    converted_state_dict = OrderedDict()
    for old_key, value in ckpt_state_dict.items():
        new_key = old_key
        for pattern, replacement in rules:
            new_key = new_key.replace(pattern, replacement)
        converted_state_dict[new_key] = value
    return converted_state_dict
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the checkpoint named by --model, rename the
    # keys of its "model" entry, and write the result to --output.
    args = get_parser().parse_args()
    # NOTE: torch.load unpickles the file, so only run this on checkpoints
    # from a trusted source.
    # map_location="cpu" lets a checkpoint that was saved from GPU tensors
    # load on a machine without CUDA.
    ckpt = torch.load(args.model, map_location="cpu")
    model = rename_resnet_param_names(ckpt["model"])
    torch.save(model, args.output)
|
|
|