# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
# Users should be careful about adopting these functions for any commercial purpose.
# https://github.com/AliaksandrSiarohin/first-order-model/blob/master/LICENSE.md

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .first_order import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian


class DenseMotionNetwork(nn.Layer):
    """
    Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
    """

    def __init__(self,
                 block_expansion,
                 num_blocks,
                 max_features,
                 num_kp,
                 num_channels,
                 estimate_occlusion_map=False,
                 scale_factor=1,
                 kp_variance=0.01,
                 mobile_net=False):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(
            block_expansion=block_expansion,
            in_features=(num_kp + 1) * (num_channels + 1),
            max_features=max_features,
            num_blocks=num_blocks,
            mobile_net=mobile_net)

        if mobile_net:
            self.mask = nn.Sequential(
                nn.Conv2D(
                    self.hourglass.out_filters,
                    self.hourglass.out_filters,
                    kernel_size=3,
                    weight_attr=nn.initializer.KaimingUniform(),
                    padding=1),
                nn.ReLU(),
                nn.Conv2D(
                    self.hourglass.out_filters,
                    self.hourglass.out_filters,
                    kernel_size=3,
                    weight_attr=nn.initializer.KaimingUniform(),
                    padding=1),
                nn.ReLU(),
                nn.Conv2D(
                    self.hourglass.out_filters,
                    num_kp + 1,
                    kernel_size=3,
                    weight_attr=nn.initializer.KaimingUniform(),
                    padding=1))
        else:
            self.mask = nn.Conv2D(
                self.hourglass.out_filters,
                num_kp + 1,
                kernel_size=(7, 7),
                padding=(3, 3))

        if estimate_occlusion_map:
            if mobile_net:
                self.occlusion = nn.Sequential(
                    nn.Conv2D(
                        self.hourglass.out_filters,
                        self.hourglass.out_filters,
                        kernel_size=3,
                        padding=1,
                        weight_attr=nn.initializer.KaimingUniform()),
                    nn.ReLU(),
                    nn.Conv2D(
                        self.hourglass.out_filters,
                        self.hourglass.out_filters,
                        kernel_size=3,
                        weight_attr=nn.initializer.KaimingUniform(),
                        padding=1),
                    nn.ReLU(),
                    nn.Conv2D(
                        self.hourglass.out_filters,
                        1,
                        kernel_size=3,
                        padding=1,
                        weight_attr=nn.initializer.KaimingUniform()))
            else:
                self.occlusion = nn.Conv2D(
                    self.hourglass.out_filters,
                    1,
                    kernel_size=(7, 7),
                    padding=(3, 3))
        else:
            self.occlusion = None

        self.num_kp = num_kp
        self.scale_factor = scale_factor
        self.kp_variance = kp_variance

        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(
                num_channels, self.scale_factor, mobile_net=mobile_net)

    def create_heatmap_representations(self, source_image, kp_driving,
                                       kp_source):
        """
        Eq 6. in the paper H_k(z)
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(
            kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(
            kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        # adding background feature
        zeros = paddle.zeros(
            [heatmap.shape[0], 1, spatial_size[0], spatial_size[1]],
            heatmap.dtype)
        heatmap = paddle.concat([zeros, heatmap], axis=1)
        heatmap = heatmap.unsqueeze(2)
        return heatmap

    def create_sparse_motions(self, source_image, kp_driving, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid(
            (h, w), type=kp_source['value'].dtype)
        identity_grid = identity_grid.reshape([1, 1, h, w, 2])
        coordinate_grid = identity_grid - kp_driving['value'].reshape(
            [bs, self.num_kp, 1, 1, 2])
        if 'jacobian' in kp_driving:
            jacobian = paddle.matmul(kp_source['jacobian'],
                                     paddle.inverse(kp_driving['jacobian']))
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            # TODO: fix bug of paddle.tile
            p_jacobian = jacobian.reshape([bs, self.num_kp, 1, 1, 4])
            paddle_jacobian = paddle.tile(p_jacobian, [1, 1, h, w, 1])
            paddle_jacobian = paddle_jacobian.reshape(
                [bs, self.num_kp, h, w, 2, 2])

            coordinate_grid = paddle.matmul(paddle_jacobian,
                                            coordinate_grid.unsqueeze(-1))

            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].reshape(
            [bs, self.num_kp, 1, 1, 2])

        # adding background feature
        identity_grid = paddle.tile(identity_grid, (bs, 1, 1, 1, 1))
        sparse_motions = paddle.concat(
            [identity_grid, driving_to_source], axis=1)
        return sparse_motions

    def create_deformed_source_image(self, source_image, sparse_motions):
        r"""
        Eq 7. in the paper \hat{T}_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        source_repeat = paddle.tile(
            source_image.unsqueeze(1).unsqueeze(1),
            [1, self.num_kp + 1, 1, 1, 1, 1])
        source_repeat = source_repeat.reshape(
            [bs * (self.num_kp + 1), -1, h, w])
        sparse_motions = sparse_motions.reshape(
            (bs * (self.num_kp + 1), h, w, -1))
        sparse_deformed = F.grid_sample(
            source_repeat,
            sparse_motions,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True)
        sparse_deformed = sparse_deformed.reshape(
            (bs, self.num_kp + 1, -1, h, w))
        return sparse_deformed

    def forward(self, source_image, kp_driving, kp_source):
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(
            source_image, kp_driving, kp_source)
        sparse_motion = self.create_sparse_motions(source_image, kp_driving,
                                                   kp_source)
        deformed_source = self.create_deformed_source_image(source_image,
                                                            sparse_motion)
        out_dict['sparse_deformed'] = deformed_source

        temp = paddle.concat([heatmap_representation, deformed_source], axis=2)
        temp = temp.reshape([bs, -1, h, w])

        prediction = self.hourglass(temp)

        mask = self.mask(prediction)
        mask = F.softmax(mask, axis=1)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)
        sparse_motion = sparse_motion.transpose([0, 1, 4, 2, 3])
        # weight the sparse motions by the predicted mask and sum them into a
        # single dense motion field
        deformation = (sparse_motion * mask).sum(axis=1)
        deformation = deformation.transpose([0, 2, 3, 1])

        out_dict['deformation'] = deformation

        # Sec. 3.2 in the paper
        if self.occlusion is not None:
            occlusion_map = F.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
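

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed): shows the tensor shapes this
# module expects.  The constructor values below are example settings, not ones
# mandated by this file; the keypoint dicts follow the layout used above
# ('value': [bs, num_kp, 2], optional 'jacobian': [bs, num_kp, 2, 2]).
#
#   net = DenseMotionNetwork(block_expansion=64, num_blocks=5,
#                            max_features=1024, num_kp=10, num_channels=3,
#                            estimate_occlusion_map=True, scale_factor=0.25)
#   source = paddle.rand([1, 3, 256, 256])
#   kp = {'value': paddle.rand([1, 10, 2]) * 2 - 1,
#         'jacobian': paddle.tile(paddle.eye(2).reshape([1, 1, 2, 2]),
#                                 [1, 10, 1, 1])}
#   out = net(source, kp_driving=kp, kp_source=kp)
#
# Assuming AntiAliasInterpolation2d downscales by scale_factor, the outputs are
# at 64 x 64 here, e.g. out['deformation']: [1, 64, 64, 2],
# out['mask']: [1, 11, 64, 64], out['occlusion_map']: [1, 1, 64, 64].
# ---------------------------------------------------------------------------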