parent
af2c5527e1
commit
9d964c5e9d
8 changed files with 379 additions and 16 deletions
@ -0,0 +1,96 @@ |
||||
# -*- coding: utf-8 -*- |
||||
''' |
||||
@Author: CaptainHu |
||||
@Date: 2022年 10月 27日 星期二 10:32:12 CST |
||||
@Description: 张量调试自用 |
||||
''' |
||||
from typing import Optional,Union |
||||
import math |
||||
from copy import deepcopy |
||||
|
||||
import matplotlib.pyplot as plt |
||||
import numpy as np |
||||
import torch |
||||
from torchvision.utils import make_grid |
||||
|
||||
|
||||
|
||||
def normalization(data): |
||||
_range = np.max(data) - np.min(data) |
||||
return (data - np.min(data)) / _range *255 |
||||
|
||||
|
||||
def show_img(img_ori,text:Optional[str]=None,cvreader:bool=True,delay:float=0): |
||||
img=deepcopy(img_ori) |
||||
if isinstance(img, list) or isinstance(img, tuple): |
||||
img_num = len(img) |
||||
row_n = math.ceil(math.sqrt(img_num)) |
||||
col_n = max(math.ceil(img_num / row_n), 1) |
||||
fig, axs = plt.subplots(row_n, col_n, figsize=(15 * row_n, 15 * col_n),layout="constrained") |
||||
for idx, img_ in enumerate(img): |
||||
if isinstance(img_,torch.Tensor) or isinstance(img_,np.ndarray): |
||||
img_=show_tensor(img_,cvreader) |
||||
if 2 == len(axs.shape): |
||||
axs[idx % row_n][idx // row_n].imshow(img_) |
||||
axs[idx % row_n][idx // row_n].set_title(str(idx)) |
||||
else: |
||||
axs[idx % row_n].imshow(img_) |
||||
axs[idx % row_n].set_title(str(idx)) |
||||
if text: |
||||
plt.text(0,0,text,fontsize=15) |
||||
if delay <=0: |
||||
plt.show() |
||||
else: |
||||
plt.draw() |
||||
plt.pause(delay) |
||||
plt.close() |
||||
elif isinstance(img,torch.Tensor) or isinstance(img,np.ndarray): |
||||
img=show_tensor(img,cvreader) |
||||
plt.rcParams['figure.constrained_layout.use'] = True |
||||
plt.imshow(img) |
||||
if text: |
||||
plt.text(0,0,text,fontsize=15) |
||||
if delay <=0: |
||||
plt.show() |
||||
else: |
||||
plt.draw() |
||||
plt.pause(delay) |
||||
plt.close() |
||||
else: |
||||
if hasattr(img,'show'): |
||||
img.show() |
||||
|
||||
def _2grid(img:Union[torch.Tensor,np.ndarray]): |
||||
if len(img.shape) ==3 and img.shape[0] not in (1,3): |
||||
img=img.unsqueeze(1) if isinstance(img,torch.Tensor) else np.expand_dims(img,axis=1) |
||||
r=math.floor(math.sqrt(img.shape[0])) |
||||
if isinstance(img,np.ndarray): |
||||
img=torch.Tensor(img) |
||||
img=make_grid(img,normalize=True,nrow=r) |
||||
else: |
||||
img=make_grid(img,normalize=True,nrow=r) |
||||
img: np.ndarray = img.detach().cpu().numpy().squeeze() |
||||
return img |
||||
|
||||
|
||||
def show_tensor(img,cvreader): |
||||
if img.shape[-1] < 4 and len(img.shape) == 3: |
||||
img=np.transpose(img,(2,0,1)) if isinstance(img,np.ndarray) else torch.permute(img,(2,0,1)) |
||||
if (len(img.shape) == 4 and img.shape[0] !=1) or (len(img.shape) ==3 and img.shape[0] not in (1,3)): |
||||
img=_2grid(img) |
||||
if isinstance(img,torch.Tensor): |
||||
img=img.detach().cpu().numpy().squeeze() |
||||
if 1==img.max(): |
||||
img=img.astype(np.float) |
||||
else: |
||||
img=img.astype(np.uint8) |
||||
if img.shape[0] == 1 or img.shape[0] == 3: |
||||
img = np.transpose(img, (1, 2, 0)) |
||||
if img.min() <0 or img.max() >255: |
||||
img=normalization(img) |
||||
print("img have norm") |
||||
if img.shape[-1] == 3: |
||||
img=img.astype(np.uint8) |
||||
if cvreader and len(img.shape)==3: |
||||
img=img[:,:,::-1] |
||||
return img |
@ -0,0 +1,247 @@ |
||||
# -*- coding: utf-8 -*- |
||||
''' |
||||
@Author: captainfffsama |
||||
@Date: 2022-10-27 14:22:09 |
||||
@LastEditors: captainfffsama tuanzhangsama@outlook.com |
||||
@LastEditTime: 2022-10-27 17:07:33 |
||||
@FilePath: /BasesHomo/infer.py |
||||
@Description: |
||||
''' |
||||
"""Evaluates the model""" |
||||
|
||||
import argparse |
||||
import logging |
||||
import os |
||||
import cv2, imageio |
||||
|
||||
import numpy as np |
||||
import torch |
||||
from torch.autograd import Variable |
||||
|
||||
import dataset.data_loader as data_loader |
||||
import model.net as net |
||||
from common import utils |
||||
from loss.losses import compute_losses, compute_eval_results |
||||
from common.manager import Manager |
||||
import debug_tools as D |
||||
|
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--model_dir', default='experiments/base_model', help="Directory containing params.json") |
||||
parser.add_argument('--restore_file', default='experiments/base_model/best_0.5012.pth.tar', help="name of the file in --model_dir containing weights to load") |
||||
|
||||
|
||||
|
||||
def evaluate(model, manager): |
||||
"""Evaluate the model on `num_steps` batches. |
||||
|
||||
Args: |
||||
model: (torch.nn.Module) the neural network |
||||
manager: a class instance that contains objects related to train and evaluate. |
||||
""" |
||||
print("eval begin!") |
||||
|
||||
# loss status and eval status initial |
||||
manager.reset_loss_status() |
||||
manager.reset_metric_status(manager.params.eval_type) |
||||
model.eval() |
||||
|
||||
RE = ['0000011', '0000016', '00000147', '00000155', '00000158', '00000107', '00000239', '0000030'] |
||||
LT = ['0000038', '0000044', '0000046', '0000047', '00000238', '00000177', '00000188', '00000181'] |
||||
LL = ['0000085', '00000100', '0000091', '0000092', '00000216', '00000226'] |
||||
SF = ['00000244', '00000251', '0000026', '0000030', '0000034', '00000115'] |
||||
LF = ['00000104', '0000031', '0000035', '00000129', '00000141', '00000200'] |
||||
MSE_RE = [] |
||||
MSE_LT = [] |
||||
MSE_LL = [] |
||||
MSE_SF = [] |
||||
MSE_LF = [] |
||||
|
||||
k = 0 |
||||
with torch.no_grad(): |
||||
# compute metrics over the dataset |
||||
|
||||
for data_batch in manager.dataloaders[manager.params.eval_type]: |
||||
# data parse |
||||
imgs_full = data_batch["imgs_ori"] |
||||
video_name = data_batch["video_name"] |
||||
gray_patches = data_batch["imgs_gray_patch"] |
||||
npy_name = data_batch["npy_name"] |
||||
# move to GPU if available |
||||
data_batch = utils.tensor_gpu(data_batch) |
||||
# compute model output |
||||
output_batch = model(data_batch) |
||||
# compute all metrics on this batch |
||||
eval_results = compute_eval_results(data_batch, output_batch, manager.params) |
||||
img1s_full_warp = eval_results["img1_full_warp"] |
||||
err_avg = eval_results["errs"] |
||||
|
||||
for j in range(len(err_avg)): |
||||
k += 1 |
||||
img2_full = imgs_full[j, 3:, ...].permute(1, 2, 0).cpu().numpy().astype(np.uint8) |
||||
img1_full_warp = img1s_full_warp[j].permute(1, 2, 0).cpu().numpy().astype(np.uint8) |
||||
gray_patch = gray_patches[j].permute(1, 2, 0).cpu().numpy().astype(np.uint8) |
||||
img2_full = cv2.cvtColor(img2_full, cv2.COLOR_BGR2RGB) |
||||
img1_full_warp = cv2.cvtColor(img1_full_warp, cv2.COLOR_BGR2RGB) |
||||
|
||||
if video_name[j] in RE: |
||||
MSE_RE.append(err_avg[j]) |
||||
elif video_name[j] in LT: |
||||
MSE_LT.append(err_avg[j]) |
||||
elif video_name[j] in LL: |
||||
MSE_LL.append(err_avg[j]) |
||||
elif video_name[j] in SF: |
||||
MSE_SF.append(err_avg[j]) |
||||
elif video_name[j] in LF: |
||||
MSE_LF.append(err_avg[j]) |
||||
|
||||
if k % 200 == 0: |
||||
print(k) |
||||
|
||||
print("{}:{}".format(k, err_avg[j])) |
||||
eval_save_result([img2_full, img1_full_warp], npy_name[j] + "_" + str(err_avg[j]) + ".gif", manager) |
||||
|
||||
MSE_RE_avg = sum(MSE_RE) / len(MSE_RE) |
||||
MSE_LT_avg = sum(MSE_LT) / len(MSE_LT) |
||||
MSE_LL_avg = sum(MSE_LL) / len(MSE_LL) |
||||
MSE_SF_avg = sum(MSE_SF) / len(MSE_SF) |
||||
MSE_LF_avg = sum(MSE_LF) / len(MSE_LF) |
||||
MSE_avg = (MSE_RE_avg + MSE_LT_avg + MSE_LL_avg + MSE_SF_avg + MSE_LF_avg) / 5 |
||||
|
||||
Metric = {"MSE_RE_avg":MSE_RE_avg, "MSE_LT_avg":MSE_LT_avg, "MSE_LL_avg":MSE_LL_avg, "MSE_SF_avg":MSE_SF_avg, "MSE_LF_avg":MSE_LF_avg, "AVG":MSE_avg} |
||||
manager.update_metric_status(metrics=Metric, split=manager.params.eval_type, batch_size=1) |
||||
|
||||
# update data to logger |
||||
manager.logger.info("Loss/valid epoch_{} {}: {:.2f}. RE:{:.4f} LT:{:.4f} LL:{:.4f} SF:{:.4f} LF:{:.4f} ".format(manager.params.eval_type, manager.epoch_val, |
||||
MSE_avg, MSE_RE_avg, MSE_LT_avg, MSE_LL_avg, MSE_SF_avg, MSE_LF_avg)) |
||||
|
||||
# For each epoch, print the metric |
||||
manager.print_metrics(manager.params.eval_type, title=manager.params.eval_type, color="green") |
||||
|
||||
# manager.epoch_val += 1 |
||||
|
||||
model.train() |
||||
|
||||
def eval_save_result(save_file, save_name, manager): |
||||
|
||||
# save dir: model_dir |
||||
save_dir_gif = os.path.join(manager.params.model_dir, 'gif') |
||||
if not os.path.exists(save_dir_gif): |
||||
os.makedirs(save_dir_gif) |
||||
|
||||
save_dir_gif_epoch = os.path.join(save_dir_gif, str(manager.epoch_val)) |
||||
if not os.path.exists(save_dir_gif_epoch): |
||||
os.makedirs(save_dir_gif_epoch) |
||||
|
||||
print("save_dir is: {}".format(os.path.join(save_dir_gif_epoch, save_name))) |
||||
|
||||
if type(save_file)==list: # save gif |
||||
utils.create_gif(save_file, os.path.join(save_dir_gif_epoch, save_name)) |
||||
elif type(save_file)==str: # save string information |
||||
f = open(os.path.join(save_dir_gif_epoch, save_name), 'w') |
||||
f.write(save_file) |
||||
f.close() |
||||
elif manager.val_img_save: # save single image |
||||
cv2.imwrite(os.path.join(save_dir_gif_epoch, save_name), save_file) |
||||
|
||||
def infer(model,img1p,img2p): |
||||
src_img1,img1_g,img1_rs=get_img_postprocess(img1p) |
||||
src_img2,img2_g,img2_rs=get_img_postprocess(img2p) |
||||
imgs_ori = torch.tensor(np.concatenate([src_img1, src_img2], axis=2)).permute(2, 0, 1).float() |
||||
imgs_rs = torch.tensor(np.concatenate([img1_rs, img2_rs], axis=2)).permute(2, 0, 1).float().cuda() |
||||
imgs_g = torch.tensor(np.concatenate([img1_g, img2_g], axis=2)).permute(2, 0, 1).float().cuda() |
||||
data_dict = {} |
||||
data_dict["imgs_gray_full"] = imgs_g.unsqueeze_(0) |
||||
data_dict["imgs_gray_patch"] =imgs_rs.unsqueeze_(0) |
||||
data_dict["start"] = torch.tensor([0, 0]).reshape(2, 1, 1).float().cuda().unsqueeze_(0) |
||||
data_dict["imgs_ori"] = imgs_ori.unsqueeze_(0).cuda() |
||||
|
||||
model.eval() |
||||
with torch.no_grad(): |
||||
output_batch = model(data_dict) |
||||
eval(data_dict, output_batch) |
||||
|
||||
|
||||
def eval(in_dict,out_dict): |
||||
imgs_full =in_dict["imgs_ori"] |
||||
batch_size, _, grid_h, grid_w = imgs_full.shape |
||||
|
||||
H_flow_f, H_flow_b =out_dict['H_flow'] |
||||
H_flow_f = net.upsample2d_flow_as(H_flow_f, imgs_full, mode="bilinear", if_rate=True) # scale |
||||
H_flow_b = net.upsample2d_flow_as(H_flow_b, imgs_full, mode="bilinear", if_rate=True) |
||||
img1_full_warp = net.get_warp_flow(imgs_full[:, :3, ...], H_flow_b, start=0) |
||||
img2_full_warp = net.get_warp_flow(imgs_full[:, 3:, ... ], H_flow_f, start=0) |
||||
# D.show_img([img1_full_warp[0],imgs_full[0, :3, ...], imgs_full[0, 3:, ...],img2_full_warp[0]]) |
||||
D.show_img(H_flow_f[0]) |
||||
|
||||
def test_data_aug(img): |
||||
mean_I = np.array([118.93, 113.97, 102.60]).reshape(1, 1, 3) |
||||
std_I = np.array([69.85, 68.81, 72.45]).reshape(1, 1, 3) |
||||
img = (img - mean_I) / std_I |
||||
img = np.mean(img, axis=2, keepdims=True) |
||||
return img |
||||
|
||||
def get_img_postprocess(img_path): |
||||
src_img=cv2.imread(img_path) |
||||
img=cv2.resize(src_img, (640, 360)) |
||||
img_rs=cv2.resize(img, (576,320)) |
||||
img_rs=test_data_aug(img_rs) |
||||
img=test_data_aug(img) |
||||
return src_img,img,img_rs |
||||
|
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
""" |
||||
Evaluate the model on the test set. |
||||
""" |
||||
# Load the parameters |
||||
args = parser.parse_args() |
||||
json_path = os.path.join(args.model_dir, 'params.json') |
||||
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path) |
||||
params = utils.Params(json_path) |
||||
# Only load model weights |
||||
params.only_weights = True |
||||
|
||||
# Update args into params |
||||
params.update(vars(args)) |
||||
|
||||
# Use GPU if available |
||||
params.cuda = torch.cuda.is_available() |
||||
|
||||
# Set the random seed for reproducible experiments |
||||
torch.manual_seed(230) |
||||
if params.cuda: |
||||
torch.cuda.manual_seed(230) |
||||
|
||||
# Get the logger |
||||
logger = utils.set_logger(os.path.join(args.model_dir, 'evaluate.log')) |
||||
|
||||
# Create the input data pipeline |
||||
logging.info("Creating the dataset...") |
||||
|
||||
# Fetch dataloaders |
||||
dataloaders = data_loader.fetch_dataloader(params) |
||||
|
||||
# Define the model and optimizer |
||||
if params.cuda: |
||||
model = net.fetch_net(params).cuda() |
||||
model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) |
||||
else: |
||||
model = net.fetch_net(params) |
||||
|
||||
|
||||
# Initial status for checkpoint manager |
||||
manager = Manager(model=model, optimizer=None, scheduler=None, params=params, dataloaders=dataloaders, writer=None, logger=logger) |
||||
|
||||
# Reload weights from the saved file |
||||
manager.load_checkpoints() |
||||
|
||||
|
||||
# Test the model |
||||
logger.info("Starting test") |
||||
img1="/home/chiebotgpuhq/Pictures/test_tmp/1.jpeg" |
||||
img2="/home/chiebotgpuhq/Pictures/test_tmp/3.jpeg" |
||||
infer(manager.model,img1,img2) |
||||
|
||||
# Evaluate |
||||
# evaluate(model, manager) |
Loading…
Reference in new issue