OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.8 KiB
51 lines
1.8 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version |
|
|
|
# When torch.__version__ is the literal string 'parrots' it is kept verbatim;
# otherwise it is reduced to a numeric (major, minor) tuple.
if torch.__version__ != 'parrots':
    # The version string may look like '1.3.1+cu92'; only the first two
    # dot-separated fields are needed for comparison.
    TORCH_VERSION = tuple(map(int, torch.__version__.split('.')[:2]))
else:
    TORCH_VERSION = torch.__version__
|
|
|
|
|
def adaptive_avg_pool2d(input, output_size):
    """Handle empty batch dimension to adaptive_avg_pool2d.

    When ``input`` has zero elements and ``obsolete_torch_version`` reports
    the running PyTorch as obsolete relative to (1, 9), the native op is
    skipped and an empty tensor of the expected output shape is created via
    ``NewEmptyTensorOp``. Every other call is forwarded unchanged to
    ``F.adaptive_avg_pool2d``.

    Args:
        input (tensor): 4D tensor.
        output_size (int, tuple[int | None, int | None]): the target output
            size. A single int is used for both spatial dimensions. A
            ``None`` entry keeps the corresponding input dimension.

    Returns:
        tensor: The pooled result, or an empty tensor of shape
            ``(*input.shape[:2], *output_size)`` on the fallback path.
    """
    if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
        if isinstance(output_size, int):
            output_size = [output_size, output_size]
        else:
            # Mirror AdaptiveAvgPool2d.forward: a None entry means "keep the
            # input's size along that dimension", so it must be resolved
            # before being used as part of a concrete shape.
            output_size = [
                v if v is not None else d
                for v, d in zip(output_size, input.shape[-2:])
            ]
        # Leading two dimensions are carried over from the input as-is.
        output_size = [*input.shape[:2], *output_size]
        return NewEmptyTensorOp.apply(input, output_size)
    return F.adaptive_avg_pool2d(input, output_size)
|
|
|
|
|
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """AdaptiveAvgPool2d wrapper that handles an empty batch dimension."""

    def forward(self, x):
        """Pool ``x``, substituting an empty result for empty inputs.

        Args:
            x (tensor): Input tensor; its first two dimensions are copied
                into the shape of the fallback result.

        Returns:
            tensor: Output of the parent ``forward``, or an empty tensor
                built by ``NewEmptyTensorOp`` on the fallback path.
        """
        # PyTorch 1.9 does not support empty tensor inference yet, so the
        # parent implementation is only bypassed for zero-element inputs on
        # versions flagged by obsolete_torch_version.
        use_fallback = x.numel() == 0 and obsolete_torch_version(
            TORCH_VERSION, (1, 9))
        if not use_fallback:
            return super().forward(x)

        target = self.output_size
        if isinstance(target, int):
            spatial = [target, target]
        else:
            # A None entry means "keep the input's size on that dimension".
            spatial = [
                in_dim if out_dim is None else out_dim
                for out_dim, in_dim in zip(target, x.size()[-2:])
            ]
        full_shape = [*x.shape[:2], *spatial]
        return NewEmptyTensorOp.apply(x, full_shape)
|
|
|