add MaxPool3d

pull/652/head
dreamerlin 5 years ago
parent 8ccea20234
commit 86d9f4684a
  1. 7
      mmcv/cnn/__init__.py
  2. 4
      mmcv/cnn/bricks/__init__.py
  3. 21
      mmcv/cnn/bricks/wrappers.py
  4. 28
      tests/test_cnn/test_wrappers.py

@ -6,8 +6,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
ContextBlock, Conv2d, ConvAWS2d, ConvModule, ContextBlock, Conv2d, ConvAWS2d, ConvModule,
ConvTranspose2d, ConvTranspose3d, ConvWS2d, ConvTranspose2d, ConvTranspose3d, ConvWS2d,
DepthwiseSeparableConvModule, GeneralizedAttention, DepthwiseSeparableConvModule, GeneralizedAttention,
HSigmoid, HSwish, Linear, MaxPool2d, NonLocal1d, HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
NonLocal2d, NonLocal3d, Scale, Swish, NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
build_activation_layer, build_conv_layer, build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer, build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm) build_upsample_layer, conv_ws_2d, is_norm)
@ -29,5 +29,6 @@ __all__ = [
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d', 'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d' 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d'
] ]

@ -18,7 +18,7 @@ from .scale import Scale
from .swish import Swish from .swish import Swish
from .upsample import build_upsample_layer from .upsample import build_upsample_layer
from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, from .wrappers import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
MaxPool2d) MaxPool2d, MaxPool3d)
__all__ = [ __all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer', 'ConvModule', 'build_activation_layer', 'build_conv_layer',
@ -29,5 +29,5 @@ __all__ = [
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d' 'ConvTranspose3d', 'MaxPool3d'
] ]

@ -8,7 +8,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair, _triple
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
@ -122,6 +122,25 @@ class MaxPool2d(nn.MaxPool2d):
return super().forward(x) return super().forward(x)
class MaxPool3d(nn.MaxPool3d):
def forward(self, x):
# PyTorch 1.7 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding),
_triple(self.stride),
_triple(self.dilation)):
o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
o = math.ceil(o) if self.ceil_mode else math.floor(o)
out_shape.append(o)
empty = NewEmptyTensorOp.apply(x, out_shape)
return empty
return super().forward(x)
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
def forward(self, x): def forward(self, x):

@ -6,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear, from mmcv.cnn.bricks import (Conv2d, ConvTranspose2d, ConvTranspose3d, Linear,
MaxPool2d) MaxPool2d, MaxPool3d)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
@ -186,6 +186,32 @@ def test_max_pool_2d():
assert torch.equal(wrapper(x_normal), ref_out) assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', '1.1')
def test_max_pool_3d():
test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
('in_t', [10, 20]), ('in_channel', [1, 3]),
('out_channel', [1, 3]), ('kernel_size', [3, 5]),
('stride', [1, 2]), ('padding', [0, 1]),
('dilation', [1, 2])])
for in_h, in_w, in_t, in_cha, out_cha, k, s, p, d in product(
*list(test_cases.values())):
# wrapper op with 0-dim input
x_empty = torch.randn(0, in_cha, in_t, in_h, in_w, requires_grad=True)
wrapper = MaxPool3d(k, stride=s, padding=p, dilation=d)
wrapper_out = wrapper(x_empty)
# torch op with 3-dim input as shape reference
x_normal = torch.randn(3, in_cha, in_t, in_h, in_w)
ref = nn.MaxPool3d(k, stride=s, padding=p, dilation=d)
ref_out = ref(x_normal)
assert wrapper_out.shape[0] == 0
assert wrapper_out.shape[1:] == ref_out.shape[1:]
assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', '1.1') @patch('torch.__version__', '1.1')
def test_linear(): def test_linear():
test_cases = OrderedDict([ test_cases = OrderedDict([

Loading…
Cancel
Save