You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
3.8 KiB
109 lines
3.8 KiB
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from model import net |
|
|
|
def triplet_loss(a, p, n, margin=1.0, exp=1, reduce=False, size_average=False):
    """Triplet margin loss for anchor ``a``, positive ``p`` and negative ``n``.

    Thin wrapper around ``torch.nn.TripletMarginLoss``. The legacy
    ``reduce`` / ``size_average`` flags are kept for backward compatibility,
    but PyTorch has deprecated them. They are translated here to the
    equivalent ``reduction`` string, using the same mapping PyTorch applies
    internally:

    * ``reduce`` false                 -> ``"none"`` (one loss per sample)
    * ``reduce`` and ``size_average``  -> ``"mean"``
    * ``reduce`` only                  -> ``"sum"``

    ``None`` for either flag is treated as ``True``, as in PyTorch.

    Args:
        a: anchor tensor.
        p: positive tensor.
        n: negative tensor.
        margin: hinge margin.
        exp: norm degree of the pairwise distance (PyTorch's ``p`` argument).
        reduce: legacy reduction flag (see above).
        size_average: legacy reduction flag (see above).

    Returns:
        The loss tensor. With the defaults (``reduce=False``) no reduction is
        applied.
    """
    if reduce is None:
        reduce = True
    if size_average is None:
        size_average = True

    if not reduce:
        reduction = "none"
    elif size_average:
        reduction = "mean"
    else:
        reduction = "sum"

    criterion = nn.TripletMarginLoss(margin=margin, p=exp, reduction=reduction)
    return criterion(a, p, n)
|
|
|
def photo_loss_function(diff, q, averge=True):
    """Offset L_q penalty, reduced to a scalar: ``(|diff| + 0.01) ** q``.

    Args:
        diff: residual tensor of any shape.
        q: exponent applied element-wise after adding the 0.01 offset.
        averge: if truthy, average over all elements; otherwise sum them.
            The (misspelled) name is kept because callers pass it by keyword.

    Returns:
        Scalar tensor holding the mean or sum of the penalty.
    """
    shifted = torch.abs(diff) + 0.01
    penalty = shifted.pow(q)
    return penalty.mean() if averge else penalty.sum()
|
|
|
def geometricDistance(correspondence, flow):
    """Endpoint error of one point correspondence under a dense flow field.

    ``flow`` is a 3-D tensor. It is permuted to (H, W, C), copied to a NumPy
    array and sampled as ``flow[y, x]``, so each point is ordered (x, y).
    C is presumably 2 (dx, dy); TODO confirm.

    ``correspondence[0]`` and ``correspondence[1]`` are the two points
    (``src`` and ``dst`` below). They must support element-wise subtraction
    with a NumPy array, e.g. lists or ndarrays.

    * If ``dst``'s x coordinate is a Python ``float`` (``numpy.float64``
      counts), ``src`` is truncated with ``int()`` for the lookup. Only one
      direction is scored: ``||dst - (src - flow[src])||``.
    * Otherwise the coordinates are used as indices directly. The smaller of
      ``||dst - (src - flow[src])||`` and ``||src - (dst - flow[dst])||`` is
      returned. Both directions sample the same ``flow``.

    Returns:
        The Euclidean norm of the residual (a NumPy scalar).
    """
    flow_hwc = flow.permute(1, 2, 0).cpu().detach().numpy()
    src = correspondence[0]
    dst = correspondence[1]

    if not isinstance(dst[0], float):
        # Integer coordinates: index directly and keep the better direction.
        forward = np.linalg.norm(dst - (src - flow_hwc[src[1], src[0]]))
        backward = np.linalg.norm(src - (dst - flow_hwc[dst[1], dst[0]]))
        return min(forward, backward)

    # Float coordinates: truncate the source point and score one direction.
    row, col = int(src[1]), int(src[0])
    return np.linalg.norm(dst - (src - flow_hwc[row, col]))
|
|
|
def compute_losses(output, train_batch, params):
    """Compute the training losses for one batch.

    Only ``params.loss_type == "basic"`` is supported.

    Args:
        output: model outputs. Reads the pair-valued entries ``"H_flow"``,
            ``"fea_full"``, ``"fea_patch"``, ``"img_warp"`` and
            ``"fea_patch_warp"``. Each is unpacked into a (first, second)
            pair.
        train_batch: batch dict. Reads ``"imgs_gray_patch"`` and ``"start"``.
        params: config object. Reads ``loss_type`` and ``weight_fil``.

    Returns:
        Dict with ``"photo_loss_l1"``, ``"fea_loss_l1"``, ``"triplet_loss"``
        and ``"total"``. ``"total"`` is the term to backpropagate: the
        triplet loss plus ``params.weight_fil`` times the feature loss.
        ``"photo_loss_l1"`` is computed but is not part of ``"total"``.

    Raises:
        NotImplementedError: if ``params.loss_type`` is not ``"basic"``.
    """
    if params.loss_type != "basic":
        raise NotImplementedError(f"Unsupported loss_type: {params.loss_type!r}")

    imgs_patch = train_batch['imgs_gray_patch']
    start = train_batch['start']

    H_flow_f, H_flow_b = output['H_flow']
    fea1_full, fea2_full = output["fea_full"]
    fea1_patch, fea2_patch = output["fea_patch"]
    img1_warp, img2_warp = output["img_warp"]
    fea1_patch_warp, fea2_patch_warp = output["fea_patch_warp"]

    # Warp each full-size feature map with the flow of the matching direction.
    fea2_warp = net.get_warp_flow(fea2_full, H_flow_f, start=start)
    fea1_warp = net.get_warp_flow(fea1_full, H_flow_b, start=start)

    # First channel of the patch stack vs. img2_warp; the remaining channel(s)
    # vs. img1_warp. Presumably a 2-channel (img1, img2) stack; TODO confirm.
    im_diff_fw = imgs_patch[:, :1, ...] - img2_warp
    im_diff_bw = imgs_patch[:, 1:, ...] - img1_warp

    fea_diff_fw = fea1_warp - fea1_patch_warp
    fea_diff_bw = fea2_warp - fea2_patch_warp

    losses = {}
    losses["photo_loss_l1"] = (photo_loss_function(diff=im_diff_fw, q=1, averge=True)
                               + photo_loss_function(diff=im_diff_bw, q=1, averge=True))
    losses["fea_loss_l1"] = (photo_loss_function(diff=fea_diff_fw, q=1, averge=True)
                             + photo_loss_function(diff=fea_diff_bw, q=1, averge=True))
    # triplet_loss returns unreduced losses by default, hence the explicit .mean().
    losses["triplet_loss"] = (triplet_loss(fea1_patch, fea2_warp, fea2_patch).mean()
                              + triplet_loss(fea2_patch, fea1_warp, fea1_patch).mean())

    # Total loss: the only entry that needs backward.
    losses["total"] = losses["triplet_loss"] + params.weight_fil * losses["fea_loss_l1"]
    return losses
|
|
|
def compute_eval_results(data_batch, output_batch, manager, num_matches=6):
    """Score predicted flows against annotated point correspondences.

    Args:
        data_batch: dict. Reads ``"imgs_ori"`` (4-D stack; its first 3
            channels are warped with the backward flow) and ``"points"``.
            ``"points"`` has one entry per sample. Each entry is passed to
            ``eval`` and must yield a mapping with a ``'matche_pts'``
            sequence of correspondences accepted by ``geometricDistance``.
        output_batch: dict. Reads ``"H_flow"`` as a (forward, backward) flow
            pair.
        manager: unused; kept for interface compatibility.
        num_matches: number of leading correspondences scored per sample.
            Defaults to 6, the previously hard-coded value.

    Returns:
        Dict with:

        * ``"img1_full_warp"``: the first 3 image channels warped by the
          upsampled backward flow.
        * ``"errs"``: per-sample mean error over the scored correspondences.
        * ``"errs_p"``: per-sample list of the individual errors.

    Raises:
        ValueError: if ``num_matches`` is smaller than 1.
        IndexError: propagated if a sample has fewer than ``num_matches``
            correspondences.
    """
    if num_matches < 1:
        raise ValueError(f"num_matches must be >= 1, got {num_matches}")

    imgs_full = data_batch["imgs_ori"]
    points = data_batch["points"]

    H_flow_f, H_flow_b = output_batch['H_flow']

    # Resize both flows to the full image size. if_rate=True presumably also
    # rescales the flow magnitudes; TODO confirm in net.upsample2d_flow_as.
    H_flow_f = net.upsample2d_flow_as(H_flow_f, imgs_full, mode="bilinear", if_rate=True)
    H_flow_b = net.upsample2d_flow_as(H_flow_b, imgs_full, mode="bilinear", if_rate=True)

    img1_full_warp = net.get_warp_flow(imgs_full[:, :3, ...], H_flow_b, start=0)

    # Copy the forward flow to host memory once for the whole batch.
    # geometricDistance calls .cpu() on its argument, so passing a device
    # tensor would repeat the transfer for every correspondence. For a CPU
    # tensor, .cpu() is a no-op.
    flow_f_host = H_flow_f.detach().cpu()

    errs = []
    errs_p = []
    for i, point_repr in enumerate(points):
        # NOTE(review): eval() executes arbitrary code taken from the dataset.
        # If these strings are plain Python literals, ast.literal_eval would be
        # a safe replacement. Confirm the data format before switching.
        point = eval(point_repr)
        matches = point['matche_pts']
        per_point = [geometricDistance(matches[j], flow_f_host[i])
                     for j in range(num_matches)]
        errs.append(sum(per_point) / num_matches)
        errs_p.append(per_point)

    return {
        "img1_full_warp": img1_full_warp,
        "errs": errs,
        "errs_p": errs_p,
    }
|
|
|