import numpy as np
import torch
import torch.nn as nn

from model import net
import debug_tools as D


def triplet_loss(a, p, n, margin=1.0, exp=1, reduce=False, size_average=False):
    """Compute the triplet margin loss for anchor ``a``, positive ``p``, negative ``n``.

    Args:
        a: Anchor tensor.
        p: Positive tensor.
        n: Negative tensor.
        margin: Margin passed to ``nn.TripletMarginLoss``.
        exp: Norm degree, forwarded as the ``p`` argument of
            ``nn.TripletMarginLoss``.
        reduce: Legacy flag; a falsy value disables reduction.
        size_average: Legacy flag; when reducing, truthy selects the mean and
            falsy selects the sum.

    Returns:
        The value produced by ``nn.TripletMarginLoss`` with the selected
        reduction (unreduced with the default flags).
    """
    # The ``reduce`` / ``size_average`` constructor arguments are deprecated
    # and warn on every call, so translate them to the equivalent
    # ``reduction`` string. As in the legacy API, ``None`` counts as True.
    do_reduce = True if reduce is None else reduce
    do_average = True if size_average is None else size_average
    if not do_reduce:
        reduction = "none"
    elif do_average:
        reduction = "mean"
    else:
        reduction = "sum"

    criterion = nn.TripletMarginLoss(margin=margin, p=exp, reduction=reduction)
    return criterion(a, p, n)


def photo_loss_function(diff, q, averge=True):
    """Return the robust photometric loss ``(|diff| + 0.01) ** q``.

    Args:
        diff: Tensor of residuals.
        q: Exponent applied to the offset absolute residuals.
        averge: If truthy, return the mean over all elements; otherwise the
            sum. (The spelling is kept for caller compatibility.)

    Returns:
        A scalar tensor holding the mean or the sum of the penalised residuals.
    """
    penalised = (torch.abs(diff) + 0.01).pow(q)
    return penalised.mean() if averge else penalised.sum()


def geometricDistance(correspondence, flow):
    """Measure how well ``flow`` maps one matched point onto the other.

    Args:
        correspondence: Pair ``(p1, p2)`` of matched points. Each point is
            indexed as ``p[0]`` (column) and ``p[1]`` (row) when sampling the
            flow field.
        flow: 3-D tensor whose first axis is moved last before sampling, so
            the flow vector of a pixel is ``flow[:, row, col]``.
            NOTE(review): presumably shaped (2, H, W) -- confirm with caller.

    Returns:
        The Euclidean norm of the residual. For floating-point coordinates
        only the p1 -> p2 direction is evaluated (coordinates are truncated
        to integers for the lookup); otherwise both directions are evaluated
        and the smaller error is returned.
    """
    flow = flow.permute(1, 2, 0).cpu().detach().numpy()

    p1 = correspondence[0]
    p2 = correspondence[1]

    # NumPy floating scalars (e.g. np.float32) are not all ``float``
    # subclasses; without this they would reach the integer branch and fail
    # when used as array indices.
    if isinstance(correspondence[1][0], (float, np.floating)):
        residual = p2 - (p1 - flow[int(p1[1]), int(p1[0])])
        return np.linalg.norm(residual)

    forward = p2 - (p1 - flow[p1[1], p1[0]])
    backward = p1 - (p2 - flow[p2[1], p2[0]])
    return min(np.linalg.norm(forward), np.linalg.norm(backward))


def compute_losses(output, train_batch, params):
    """Assemble the training losses from the network output.

    Args:
        output: Mapping with the keys ``'H_flow'``, ``'fea_full'``,
            ``'fea_patch'``, ``'img_warp'`` and ``'fea_patch_warp'``, each
            holding a (first, second) pair.
        train_batch: Mapping with the keys ``'imgs_gray_patch'`` and
            ``'start'``.
        params: Object exposing ``loss_type`` and ``weight_fil``.

    Returns:
        Dict with ``'photo_loss_l1'``, ``'fea_loss_l1'``, ``'triplet_loss'``
        and ``'total'``. Only ``'total'`` combines terms: the triplet loss
        plus ``params.weight_fil`` times the feature L1 loss.

    Raises:
        NotImplementedError: If ``params.loss_type`` is not ``"basic"``.
    """
    if params.loss_type != "basic":
        raise NotImplementedError(
            "unsupported loss_type: {!r}".format(params.loss_type)
        )

    imgs_patch = train_batch['imgs_gray_patch']
    start = train_batch['start']

    H_flow_f, H_flow_b = output['H_flow']
    fea1_full, fea2_full = output["fea_full"]
    fea1_patch, fea2_patch = output["fea_patch"]
    img1_warp, img2_warp = output["img_warp"]
    fea1_patch_warp, fea2_patch_warp = output["fea_patch_warp"]

    fea2_warp = net.get_warp_flow(fea2_full, H_flow_f, start=start)
    fea1_warp = net.get_warp_flow(fea1_full, H_flow_b, start=start)

    # Residuals between the patches / patch features and their warped
    # counterparts, in both directions.
    im_diff_fw = imgs_patch[:, :1, ...] - img2_warp
    im_diff_bw = imgs_patch[:, 1:, ...] - img1_warp

    fea_diff_fw = fea1_warp - fea1_patch_warp
    fea_diff_bw = fea2_warp - fea2_patch_warp

    losses = {}
    losses["photo_loss_l1"] = (
        photo_loss_function(diff=im_diff_fw, q=1, averge=True)
        + photo_loss_function(diff=im_diff_bw, q=1, averge=True)
    )
    losses["fea_loss_l1"] = (
        photo_loss_function(diff=fea_diff_fw, q=1, averge=True)
        + photo_loss_function(diff=fea_diff_bw, q=1, averge=True)
    )
    losses["triplet_loss"] = (
        triplet_loss(fea1_patch, fea2_warp, fea2_patch).mean()
        + triplet_loss(fea2_patch, fea1_warp, fea1_patch).mean()
    )

    # Total loss: the value that needs backward. The photometric term above
    # is reported but not included here.
    losses["total"] = losses["triplet_loss"] + params.weight_fil * losses["fea_loss_l1"]

    return losses


def compute_eval_results(data_batch, output_batch, manager):
    """Compute per-sample point-matching errors for an evaluation batch.

    Args:
        data_batch: Mapping with ``'imgs_ori'`` (image tensor whose channel
            axis is split at index 3 into the two images) and ``'points'``
            (sequence of strings, each evaluating to a mapping with a
            ``'matche_pts'`` list of point pairs).
        output_batch: Mapping with ``'H_flow'`` holding the (forward,
            backward) flow pair.
        manager: Unused here; kept for interface compatibility.

    Returns:
        Dict with ``'img1_full_warp'``, ``'errs'`` (mean error per sample)
        and ``'errs_p'`` (list of the individual point errors per sample).
    """
    imgs_full = data_batch["imgs_ori"]
    points = data_batch["points"]

    H_flow_f, H_flow_b = output_batch['H_flow']
    # Resize the flows to the full image; ``if_rate=True`` is forwarded to
    # the helper (presumably rescales flow magnitudes -- confirm in net).
    H_flow_f = net.upsample2d_flow_as(H_flow_f, imgs_full, mode="bilinear", if_rate=True)
    H_flow_b = net.upsample2d_flow_as(H_flow_b, imgs_full, mode="bilinear", if_rate=True)

    img1_full_warp = net.get_warp_flow(imgs_full[:, :3, ...], H_flow_b, start=0)
    # NOTE(review): computed but never returned; kept to preserve the
    # original call sequence.
    img2_full_warp = net.get_warp_flow(imgs_full[:, 3:, ...], H_flow_f, start=0)

    # Only the first six correspondences of every sample are scored.
    num_pairs = 6

    errs = []
    errs_p = []
    for i, raw_points in enumerate(points):
        # SECURITY: eval() executes arbitrary code; this is only acceptable
        # for trusted annotation files. Consider ast.literal_eval if the
        # strings are plain Python literals.
        point = eval(raw_points)
        pair_errors = [
            geometricDistance(point['matche_pts'][j], H_flow_f[i])
            for j in range(num_pairs)
        ]
        errs.append(sum(pair_errors) / num_pairs)
        errs_p.append(pair_errors)

    eval_results = {}
    eval_results["img1_full_warp"] = img1_full_warp
    eval_results["errs"] = errs
    eval_results["errs_p"] = errs_p

    return eval_results