You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

110 lines
3.8 KiB

3 years ago
import numpy as np
import torch
import torch.nn as nn
from model import net
def triplet_loss(a, p, n, margin=1.0, exp=1, reduce=False, size_average=False):
    """Compute the triplet margin loss for anchor/positive/negative inputs.

    Thin wrapper around ``nn.TripletMarginLoss``. The legacy ``reduce`` /
    ``size_average`` flags are kept in the signature for existing callers but
    are translated into the ``reduction`` argument, so PyTorch no longer emits
    a deprecation warning on every call.

    Args:
        a: Anchor tensor.
        p: Positive tensor.
        n: Negative tensor.
        margin: Margin passed to ``nn.TripletMarginLoss``.
        exp: Norm degree used for the pairwise distance (the ``p`` argument
            of ``nn.TripletMarginLoss``).
        reduce: Legacy flag; when false the per-sample losses are returned.
        size_average: Legacy flag; when ``reduce`` is true, selects the mean
            (true) or the sum (false) of the per-sample losses.

    Returns:
        The loss tensor: per-sample values with the default flags, otherwise
        a scalar sum or mean.
    """
    # Same mapping PyTorch applies to the legacy flags: None counts as True.
    reduce = True if reduce is None else reduce
    size_average = True if size_average is None else size_average
    if not reduce:
        reduction = "none"
    elif size_average:
        reduction = "mean"
    else:
        reduction = "sum"

    criterion = nn.TripletMarginLoss(margin=margin, p=exp, reduction=reduction)
    return criterion(a, p, n)
def photo_loss_function(diff, q, averge=True):
    """Reduce ``(|diff| + 0.01) ** q`` to a single scalar.

    Args:
        diff: Tensor of residuals; it is not modified.
        q: Exponent applied element-wise after the 0.01 offset is added.
        averge: When truthy the powered values are averaged, otherwise they
            are summed. (The spelling is part of the public signature.)

    Returns:
        A zero-dimensional tensor holding the mean or the sum.
    """
    powered = torch.pow(diff.abs() + 0.01, q)
    return powered.mean() if averge else powered.sum()
def geometricDistance(correspondence, flow):
    """Return the residual distance of one point pair under a flow field.

    Args:
        correspondence: Indexable holding two points; each point is indexed
            as ``[0]`` and ``[1]``, and ``[1]`` selects the first axis of the
            permuted flow while ``[0]`` selects the second.
        flow: 3-D tensor; it is permuted with ``(1, 2, 0)`` and converted to
            a NumPy array before lookup (presumably a ``(2, H, W)`` flow
            field -- confirm against the caller).

    Returns:
        The norm of ``p2 - (p1 - flow[p1])`` when the first coordinate of the
        second point is a ``float`` instance (coordinates are truncated with
        ``int`` for the lookup). Otherwise the smaller of that norm and the
        norm of the mirrored residual ``p1 - (p2 - flow[p2])``, with the
        coordinates used directly as indices.
    """
    flow_np = flow.permute(1, 2, 0).cpu().detach().numpy()
    first_pt = correspondence[0]
    second_pt = correspondence[1]

    if isinstance(correspondence[1][0], float):
        # Float coordinates: truncate to integers for the flow lookup.
        sampled = flow_np[int(first_pt[1]), int(first_pt[0])]
        return np.linalg.norm(second_pt - (first_pt - sampled))

    # Non-float coordinates index the flow directly; both directions are
    # evaluated against the same flow and the smaller residual is kept.
    forward = second_pt - (first_pt - flow_np[first_pt[1], first_pt[0]])
    backward = first_pt - (second_pt - flow_np[second_pt[1], second_pt[0]])
    return min(np.linalg.norm(forward), np.linalg.norm(backward))
def compute_losses(output, train_batch, params):
    """Assemble the training losses for one batch.

    Only ``params.loss_type == "basic"`` is supported; any other value raises
    ``NotImplementedError``.

    Args:
        output: Mapping with the keys ``'H_flow'``, ``"fea_full"``,
            ``"fea_patch"``, ``"img_warp"`` and ``"fea_patch_warp"``; each
            value unpacks into a pair.
        train_batch: Mapping with ``'imgs_gray_patch'`` (4-D tensor) and
            ``'start'``.
        params: Object exposing ``loss_type`` and ``weight_fil``.

    Returns:
        Dict with ``"photo_loss_l1"``, ``"fea_loss_l1"``, ``"triplet_loss"``
        and ``"total"``. ``"total"`` is the triplet loss plus
        ``params.weight_fil`` times the feature loss; the photometric loss is
        reported but does not contribute to it.

    Raises:
        NotImplementedError: If ``params.loss_type`` is not ``"basic"``.
    """
    if params.loss_type != "basic":
        raise NotImplementedError

    patches = train_batch['imgs_gray_patch']
    offset = train_batch['start']

    flow_f, flow_b = output['H_flow']
    feat1_full, feat2_full = output["fea_full"]
    feat1_patch, feat2_patch = output["fea_patch"]
    warped_img1, warped_img2 = output["img_warp"]
    feat1_patch_warp, feat2_patch_warp = output["fea_patch_warp"]

    # The unpacked sizes are unused; the unpack still requires a 4-D tensor.
    _batch, _channels, _patch_h, _patch_w = patches.size()

    feat2_warp = net.get_warp_flow(feat2_full, flow_f, start=offset)
    feat1_warp = net.get_warp_flow(feat1_full, flow_b, start=offset)

    # First channel against the warped second image, remaining channels
    # against the warped first image.
    img_residual_fw = patches[:, :1, ...] - warped_img2
    img_residual_bw = patches[:, 1:, ...] - warped_img1
    feat_residual_fw = feat1_warp - feat1_patch_warp
    feat_residual_bw = feat2_warp - feat2_patch_warp

    def _l1(residual):
        return photo_loss_function(diff=residual, q=1, averge=True)

    losses = {}
    losses["photo_loss_l1"] = _l1(img_residual_fw) + _l1(img_residual_bw)
    losses["fea_loss_l1"] = _l1(feat_residual_fw) + _l1(feat_residual_bw)

    triplet_fw = triplet_loss(feat1_patch, feat2_warp, feat2_patch).mean()
    triplet_bw = triplet_loss(feat2_patch, feat1_warp, feat1_patch).mean()
    losses["triplet_loss"] = triplet_fw + triplet_bw

    # Total loss: this is the value that needs backward().
    weighted_feat = params.weight_fil * losses["fea_loss_l1"]
    losses["total"] = losses["triplet_loss"] + weighted_feat
    return losses
def compute_eval_results(data_batch, output_batch, manager):
    """Compute per-sample point-matching errors for an evaluation batch.

    Args:
        data_batch: Mapping with ``"imgs_ori"`` (4-D tensor; channels ``:3``
            and ``3:`` are warped separately) and ``"points"``, a sequence of
            strings that each evaluate to a mapping with a ``'matche_pts'``
            entry.
        output_batch: Mapping whose ``'H_flow'`` entry unpacks into two
            flows.
        manager: Unused by this function.

    Returns:
        Dict with ``"img1_full_warp"`` (first image warped with the second
        flow), ``"errs"`` (mean error over the first six matched points of
        each sample) and ``"errs_p"`` (the six individual errors per sample).
    """
    imgs_full = data_batch["imgs_ori"]
    points = data_batch["points"]
    # The unpacked sizes are unused; the unpack still requires a 4-D batch.
    _batch, _channels, _grid_h, _grid_w = imgs_full.shape

    flow_fwd, flow_bwd = output_batch['H_flow']
    # Resize both flows to the full image size.
    flow_fwd = net.upsample2d_flow_as(flow_fwd, imgs_full, mode="bilinear", if_rate=True)
    flow_bwd = net.upsample2d_flow_as(flow_bwd, imgs_full, mode="bilinear", if_rate=True)

    img1_full_warp = net.get_warp_flow(imgs_full[:, :3, ...], flow_bwd, start=0)
    # The warped second image is computed but not returned.
    _img2_full_warp = net.get_warp_flow(imgs_full[:, 3:, ...], flow_fwd, start=0)

    pts_per_sample = 6  # only the first six matched points are scored
    errs = []
    errs_p = []
    for idx, raw_points in enumerate(points):
        # NOTE(review): eval() executes arbitrary code from the dataset
        # string; only safe while the annotation files are trusted.
        annotation = eval(raw_points)
        per_point = [
            geometricDistance(annotation['matche_pts'][k], flow_fwd[idx])
            for k in range(pts_per_sample)
        ]
        errs.append(sum(per_point) / pts_per_sample)
        errs_p.append(per_point)

    return {
        "img1_full_warp": img1_full_warp,
        "errs": errs,
        "errs_p": errs_p,
    }