@@ -311,45 +311,37 @@ def test_resnest_stem():
     assert model.conv1.out_channels == 64
     assert model.norm1.num_features == 64
 
-    # Test default stem_channels, with base_channels=32
-    model = ResNet(50, base_channels=32)
-    assert model.stem_channels == 32
-    assert model.conv1.out_channels == 32
-    assert model.norm1.num_features == 32
-    assert model.layer1[0].conv1.in_channels == 32
-
-    # Test stem_channels=64
-    model = ResNet(50, stem_channels=64)
-    assert model.stem_channels == 64
-    assert model.conv1.out_channels == 64
-    assert model.norm1.num_features == 64
-    assert model.layer1[0].conv1.in_channels == 64
-
-    # Test stem_channels=64, with base_channels=32
-    model = ResNet(50, stem_channels=64, base_channels=32)
-    assert model.stem_channels == 64
-    assert model.conv1.out_channels == 64
-    assert model.norm1.num_features == 64
-    assert model.layer1[0].conv1.in_channels == 64
-
-    # Test stem_channels=128
-    model = ResNet(depth=50, stem_channels=128)
-    model.init_weights()
-    model.train()
-    assert model.conv1.out_channels == 128
-    assert model.layer1[0].conv1.in_channels == 128
+    # Test default stem_channels, with base_channels=3
+    model = ResNet(50, base_channels=3)
+    assert model.stem_channels == 3
+    assert model.conv1.out_channels == 3
+    assert model.norm1.num_features == 3
+    assert model.layer1[0].conv1.in_channels == 3
+
+    # Test stem_channels=3
+    model = ResNet(50, stem_channels=3)
+    assert model.stem_channels == 3
+    assert model.conv1.out_channels == 3
+    assert model.norm1.num_features == 3
+    assert model.layer1[0].conv1.in_channels == 3
+
+    # Test stem_channels=3, with base_channels=2
+    model = ResNet(50, stem_channels=3, base_channels=2)
+    assert model.stem_channels == 3
+    assert model.conv1.out_channels == 3
+    assert model.norm1.num_features == 3
+    assert model.layer1[0].conv1.in_channels == 3
 
     # Test V1d stem_channels
-    model = ResNetV1d(depth=50, stem_channels=128)
-    model.init_weights()
+    model = ResNetV1d(depth=50, stem_channels=6)
     model.train()
-    assert model.stem[0].out_channels == 64
-    assert model.stem[1].num_features == 64
-    assert model.stem[3].out_channels == 64
-    assert model.stem[4].num_features == 64
-    assert model.stem[6].out_channels == 128
-    assert model.stem[7].num_features == 128
-    assert model.layer1[0].conv1.in_channels == 128
+    assert model.stem[0].out_channels == 3
+    assert model.stem[1].num_features == 3
+    assert model.stem[3].out_channels == 3
+    assert model.stem[4].num_features == 3
+    assert model.stem[6].out_channels == 6
+    assert model.stem[7].num_features == 6
+    assert model.layer1[0].conv1.in_channels == 6
 
 
 def test_resnet_backbone():
@@ -388,29 +380,25 @@ def test_resnet_backbone():
     with pytest.raises(TypeError):
         # pretrained must be a string path
         model = ResNet(50, pretrained=0)
-        model.init_weights()
 
     with pytest.raises(AssertionError):
         # Style must be in ['pytorch', 'caffe']
        ResNet(50, style='tensorflow')
 
     # Test ResNet50 norm_eval=True
-    model = ResNet(50, norm_eval=True)
-    model.init_weights()
+    model = ResNet(50, norm_eval=True, base_channels=1)
     model.train()
     assert check_norm_state(model.modules(), False)
 
     # Test ResNet50 with torchvision pretrained weight
     model = ResNet(
         depth=50, norm_eval=True, pretrained='torchvision://resnet50')
-    model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), False)
 
     # Test ResNet50 with first stage frozen
     frozen_stages = 1
-    model = ResNet(50, frozen_stages=frozen_stages)
-    model.init_weights()
+    model = ResNet(50, frozen_stages=frozen_stages, base_channels=1)
     model.train()
     assert model.norm1.training is False
     for layer in [model.conv1, model.norm1]:
@@ -425,9 +413,8 @@ def test_resnet_backbone():
             assert param.requires_grad is False
 
     # Test ResNet50V1d with first stage frozen
-    model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
+    model = ResNetV1d(depth=50, frozen_stages=frozen_stages, base_channels=2)
     assert len(model.stem) == 9
-    model.init_weights()
     model.train()
     assert check_norm_state(model.stem, False)
     for param in model.stem.parameters():
@@ -442,16 +429,15 @@ def test_resnet_backbone():
 
     # Test ResNet18 forward
     model = ResNet(18)
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 64, 56, 56])
-    assert feat[1].shape == torch.Size([1, 128, 28, 28])
-    assert feat[2].shape == torch.Size([1, 256, 14, 14])
-    assert feat[3].shape == torch.Size([1, 512, 7, 7])
+    assert feat[0].shape == torch.Size([1, 64, 8, 8])
+    assert feat[1].shape == torch.Size([1, 128, 4, 4])
+    assert feat[2].shape == torch.Size([1, 256, 2, 2])
+    assert feat[3].shape == torch.Size([1, 512, 1, 1])
 
     # Test ResNet18 with checkpoint forward
     model = ResNet(18, with_cp=True)
@@ -460,65 +446,63 @@ def test_resnet_backbone():
             assert m.with_cp
 
     # Test ResNet50 with BatchNorm forward
-    model = ResNet(50)
+    model = ResNet(50, base_channels=1)
     for m in model.modules():
         if is_norm(m):
             assert isinstance(m, _BatchNorm)
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 4, 8, 8])
+    assert feat[1].shape == torch.Size([1, 8, 4, 4])
+    assert feat[2].shape == torch.Size([1, 16, 2, 2])
+    assert feat[3].shape == torch.Size([1, 32, 1, 1])
 
     # Test ResNet50 with layers 1, 2, 3 out forward
-    model = ResNet(50, out_indices=(0, 1, 2))
-    model.init_weights()
+    model = ResNet(50, out_indices=(0, 1, 2), base_channels=1)
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 3
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[0].shape == torch.Size([1, 4, 8, 8])
+    assert feat[1].shape == torch.Size([1, 8, 4, 4])
+    assert feat[2].shape == torch.Size([1, 16, 2, 2])
 
     # Test ResNet50 with checkpoint forward
-    model = ResNet(50, with_cp=True)
+    model = ResNet(50, with_cp=True, base_channels=1)
     for m in model.modules():
         if is_block(m):
             assert m.with_cp
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 4, 8, 8])
+    assert feat[1].shape == torch.Size([1, 8, 4, 4])
+    assert feat[2].shape == torch.Size([1, 16, 2, 2])
+    assert feat[3].shape == torch.Size([1, 32, 1, 1])
 
     # Test ResNet50 with GroupNorm forward
     model = ResNet(
-        50, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
+        50,
+        base_channels=4,
+        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
     for m in model.modules():
         if is_norm(m):
             assert isinstance(m, GroupNorm)
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 16, 8, 8])
+    assert feat[1].shape == torch.Size([1, 32, 4, 4])
+    assert feat[2].shape == torch.Size([1, 64, 2, 2])
+    assert feat[3].shape == torch.Size([1, 128, 1, 1])
 
     # Test ResNet50 with 1 GeneralizedAttention after conv2, 1 NonLocal2D
     # after conv2, 1 ContextBlock after conv3 in layers 2, 3, 4
@@ -538,39 +522,38 @@ def test_resnet_backbone():
             stages=(False, True, True, False),
             position='after_conv3')
     ]
-    model = ResNet(50, plugins=plugins)
+    model = ResNet(50, plugins=plugins, base_channels=8)
     for m in model.layer1.modules():
         if is_block(m):
             assert not hasattr(m, 'context_block')
             assert not hasattr(m, 'gen_attention_block')
-            assert m.nonlocal_block.in_channels == 64
+            assert m.nonlocal_block.in_channels == 8
     for m in model.layer2.modules():
         if is_block(m):
-            assert m.nonlocal_block.in_channels == 128
-            assert m.gen_attention_block.in_channels == 128
-            assert m.context_block.in_channels == 512
+            assert m.nonlocal_block.in_channels == 16
+            assert m.gen_attention_block.in_channels == 16
+            assert m.context_block.in_channels == 64
 
     for m in model.layer3.modules():
         if is_block(m):
-            assert m.nonlocal_block.in_channels == 256
-            assert m.gen_attention_block.in_channels == 256
-            assert m.context_block.in_channels == 1024
+            assert m.nonlocal_block.in_channels == 32
+            assert m.gen_attention_block.in_channels == 32
+            assert m.context_block.in_channels == 128
 
     for m in model.layer4.modules():
         if is_block(m):
-            assert m.nonlocal_block.in_channels == 512
-            assert m.gen_attention_block.in_channels == 512
+            assert m.nonlocal_block.in_channels == 64
+            assert m.gen_attention_block.in_channels == 64
             assert not hasattr(m, 'context_block')
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 32, 8, 8])
+    assert feat[1].shape == torch.Size([1, 64, 4, 4])
+    assert feat[2].shape == torch.Size([1, 128, 2, 2])
+    assert feat[3].shape == torch.Size([1, 256, 1, 1])
 
     # Test ResNet50 with 1 ContextBlock after conv2, 1 ContextBlock after
     # conv3 in layers 2, 3, 4
@@ -585,7 +568,7 @@ def test_resnet_backbone():
             position='after_conv3')
     ]
 
-    model = ResNet(50, plugins=plugins)
+    model = ResNet(50, plugins=plugins, base_channels=8)
     for m in model.layer1.modules():
         if is_block(m):
             assert not hasattr(m, 'context_block')
@@ -594,33 +577,32 @@ def test_resnet_backbone():
     for m in model.layer2.modules():
         if is_block(m):
             assert not hasattr(m, 'context_block')
-            assert m.context_block1.in_channels == 512
-            assert m.context_block2.in_channels == 512
+            assert m.context_block1.in_channels == 64
+            assert m.context_block2.in_channels == 64
 
     for m in model.layer3.modules():
         if is_block(m):
             assert not hasattr(m, 'context_block')
-            assert m.context_block1.in_channels == 1024
-            assert m.context_block2.in_channels == 1024
+            assert m.context_block1.in_channels == 128
+            assert m.context_block2.in_channels == 128
 
     for m in model.layer4.modules():
         if is_block(m):
            assert not hasattr(m, 'context_block')
             assert not hasattr(m, 'context_block1')
             assert not hasattr(m, 'context_block2')
-    model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 32, 8, 8])
+    assert feat[1].shape == torch.Size([1, 64, 4, 4])
+    assert feat[2].shape == torch.Size([1, 128, 2, 2])
+    assert feat[3].shape == torch.Size([1, 256, 1, 1])
 
     # Test ResNet50 zero initialization of residual
-    model = ResNet(50, zero_init_residual=True)
+    model = ResNet(50, zero_init_residual=True, base_channels=1)
     model.init_weights()
     for m in model.modules():
         if isinstance(m, Bottleneck):
@@ -629,39 +611,22 @@ def test_resnet_backbone():
             assert assert_params_all_zeros(m.norm2)
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 4, 8, 8])
+    assert feat[1].shape == torch.Size([1, 8, 4, 4])
+    assert feat[2].shape == torch.Size([1, 16, 2, 2])
+    assert feat[3].shape == torch.Size([1, 32, 1, 1])
 
     # Test ResNetV1d forward
-    model = ResNetV1d(depth=50)
-    model.init_weights()
+    model = ResNetV1d(depth=50, base_channels=2)
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
-
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 256, 56, 56])
-    assert feat[1].shape == torch.Size([1, 512, 28, 28])
-    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
-    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+    assert feat[0].shape == torch.Size([1, 8, 8, 8])
+    assert feat[1].shape == torch.Size([1, 16, 4, 4])
+    assert feat[2].shape == torch.Size([1, 32, 2, 2])
+    assert feat[3].shape == torch.Size([1, 64, 1, 1])