# OpenMMLab Detection Toolbox and Benchmark
# https://mmdetection.readthedocs.io/
#
# Repository page notes: You cannot select more than 25 topics. Topics must
# start with a letter or number, can include dashes ('-'), and can be up to
# 35 characters long.
# 72 lines, 2.5 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
from torch.autograd import Function |
|
from torch.nn import functional as F |
|
|
|
|
|
class SigmoidGeometricMean(Function):
    """Geometric mean of two sigmoids, ``sqrt(sigmoid(x) * sigmoid(y))``.

    The backward pass uses the closed-form derivative instead of letting
    autograd differentiate ``(x.sigmoid() * y.sigmoid()).sqrt()``. When both
    ``x`` and ``y`` are very negative, that autograd graph can produce NaN
    gradients: the product of sigmoids underflows to zero, and the
    derivative of ``sqrt`` at zero is infinite.

    For ``z = sqrt(sigmoid(x) * sigmoid(y))`` the partial derivatives
    simplify to::

        dz/dx = z * (1 - sigmoid(x)) / 2
        dz/dy = z * (1 - sigmoid(y)) / 2

    Neither involves a division by ``z``, so both stay finite.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Return ``sqrt(sigmoid(x) * sigmoid(y))`` element-wise."""
        sig_x, sig_y = x.sigmoid(), y.sigmoid()
        out = (sig_x * sig_y).sqrt()
        # The two sigmoids and the output are all backward() needs.
        ctx.save_for_backward(sig_x, sig_y, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Apply the closed-form partial derivatives to ``grad_output``."""
        sig_x, sig_y, out = ctx.saved_tensors
        # Common factor grad * z; each partial then multiplies by
        # (1 - sigmoid) and halves it.
        scaled = grad_output * out
        grad_x = scaled * (1 - sig_x) / 2
        grad_y = scaled * (1 - sig_y) / 2
        return grad_x, grad_y


# Functional entry point: call like a plain function,
# e.g. ``sigmoid_geometric_mean(x, y)``.
sigmoid_geometric_mean = SigmoidGeometricMean.apply
|
|
|
|
|
def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the spatial shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w). Only the last two
    entries of ``target.shape`` are read.

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            F.interpolate() rejects any non-None value for the
            ``'nearest'``, ``'area'`` and ``'nearest-exact'`` modes. For
            those modes a falsy value, including the default, is forwarded
            as ``None``; an explicit ``True`` is still passed through and
            rejected by F.interpolate(). Default: False.

    Returns:
        Tensor: The source resized to (..., target_h, target_w), with the
        same number of dimensions as `source`. When the spatial size already
        matches, `source` is returned without calling F.interpolate().

    Raises:
        AssertionError: If `target` has fewer than two dimensions.
    """
    assert len(target.shape) >= 2
    target_h, target_w = target.shape[-2:]
    source_h, source_w = source.shape[-2:]
    if source_h == target_h and source_w == target_w:
        # Already the right size; nothing to resize.
        return source

    if mode in ('nearest', 'area', 'nearest-exact') and not align_corners:
        # F.interpolate() raises ValueError for these modes unless
        # align_corners is None, so the default False would always fail.
        align_corners = None

    # With a 2-element `size`, F.interpolate() expects 4D (N, C, H, W)
    # input, so a 3D (N, H, W) source gets a temporary channel dim.
    add_channel = len(source.shape) == 3
    if add_channel:
        source = source[:, None, :, :]
    resized = F.interpolate(
        source,
        size=(target_h, target_w),
        mode=mode,
        align_corners=align_corners)
    return resized[:, 0, :, :] if add_channel else resized
|
|
|