|
|
|
@ -6,7 +6,7 @@ import torch |
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, |
|
|
|
|
MaxPool2d) |
|
|
|
|
MaxPool2d, MaxPool3d) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@patch('torch.__version__', '1.1') |
|
|
|
@ -186,6 +186,32 @@ def test_max_pool_2d(): |
|
|
|
|
assert torch.equal(wrapper(x_normal), ref_out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@patch('torch.__version__', '1.1') |
|
|
|
|
def test_max_pool_3d(): |
|
|
|
|
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]), |
|
|
|
|
('in_t', [10, 20]), ('in_channel', [1, 3]), |
|
|
|
|
('out_channel', [1, 3]), ('kernel_size', [3, 5]), |
|
|
|
|
('stride', [1, 2]), ('padding', [0, 1]), |
|
|
|
|
('dilation', [1, 2])]) |
|
|
|
|
|
|
|
|
|
for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product( |
|
|
|
|
*list(test_cases.values())): |
|
|
|
|
# wrapper op with 0-dim input |
|
|
|
|
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True) |
|
|
|
|
wrapper = MaxPool3d(k, stride=s, padding=p, dilation=d) |
|
|
|
|
wrapper_out = wrapper(x_empty) |
|
|
|
|
|
|
|
|
|
# torch op with 3-dim input as shape reference |
|
|
|
|
x_normal = torch.randn(3, in_cha, in_t, in_h, in_w) |
|
|
|
|
ref = nn.MaxPool3d(k, stride=s, padding=p, dilation=d) |
|
|
|
|
ref_out = ref(x_normal) |
|
|
|
|
|
|
|
|
|
assert wrapper_out.shape[0] == 0 |
|
|
|
|
assert wrapper_out.shape[1:] == ref_out.shape[1:] |
|
|
|
|
|
|
|
|
|
assert torch.equal(wrapper(x_normal), ref_out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@patch('torch.__version__', '1.1') |
|
|
|
|
def test_linear(): |
|
|
|
|
test_cases = OrderedDict([ |
|
|
|
|