#!/usr/bin/env python
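"""Optical flow benchmarking script.

Runs an optical flow evaluation executable over the MPI Sintel and/or
Middlebury datasets, stores intermediate results in a JSON state file so that
interrupted runs can be resumed, and plots accuracy-vs-runtime charts for the
evaluated algorithms.
"""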

from __future__ import print_function

import os, sys
import argparse
import json, re
from subprocess import check_output
import datetime
import matplotlib.pyplot as plt


def load_json(path):
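    """Read a JSON file and return the parsed object."""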
    with open(path, "r") as f:
        return json.load(f)


def save_json(obj, path):
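    """Write obj as indented JSON to a temporary file, flush it to disk and
    then rename it over path, so that an interrupted run cannot leave a
    truncated state file behind."""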
    tmp_file = path + ".bak"
    with open(tmp_file, "w") as f:
        json.dump(obj, f, indent=2)
        f.flush()
        os.fsync(f.fileno())
    try:
        os.rename(tmp_file, path)
    except OSError:
        # os.rename cannot overwrite an existing file on some platforms
        # (e.g. Windows), so remove the destination and retry.
        os.remove(path)
        os.rename(tmp_file, path)


def parse_evaluation_result(input_str, i):
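    """Parse the textual output produced by the evaluation executable for
    frame pair i into a dict holding the measured runtime and the error
    statistics (average, standard deviation and any other reported values)."""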
    res = {}
    res['frame_number'] = i + 1
    res['error'] = {}
    # Match "<label>: <number>" pairs in the executable's output.
    regex = r"([A-Za-z. \[\].0-9]+):[ ]*([0-9]*\.[0-9]+|[0-9]+)"
    for elem in re.findall(regex, input_str):
        if "Time" in elem[0]:
            res['time'] = float(elem[1])
        elif "Average" in elem[0]:
            res['error']['average'] = float(elem[1])
        elif "deviation" in elem[0]:
            res['error']['std'] = float(elem[1])
        else:
            res['error'][elem[0]] = float(elem[1])
    return res


def evaluate_sequence(sequence, algorithm, dataset, executable, img_files, gt_files,
                      state, state_path):
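    """Run the evaluation executable on every consecutive frame pair of the
    given sequence and record the parsed results in the shared state dict.
    Sequences that already have results in the state file are skipped, which
    makes interrupted evaluations resumable; the state file is saved once the
    sequence is finished."""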
    if "eval_results" not in state[dataset][algorithm][-1].keys():
        state[dataset][algorithm][-1]["eval_results"] = {}
    elif sequence in state[dataset][algorithm][-1]["eval_results"].keys():
        return

    res = []
    for i in range(len(img_files) - 1):
        sys.stdout.write("Algorithm: %-20s Sequence: %-10s Done: [%3d/%3d]\r" %
                         (algorithm, sequence, i, len(img_files) - 1))
        sys.stdout.flush()

        # check_output returns bytes on Python 3; decode so that the regex in
        # parse_evaluation_result can be applied to a string.
        res_string = check_output([executable, img_files[i], img_files[i + 1],
                                   algorithm, gt_files[i]]).decode("utf-8")
        res.append(parse_evaluation_result(res_string, i))
    state[dataset][algorithm][-1]["eval_results"][sequence] = res
    save_json(state, state_path)

############################ DATASET DEFINITIONS ###############################

def evaluate_mpi_sintel(source_dir, algorithm, evaluation_executable, state, state_path):
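    """Evaluate the algorithm on all sequences of the MPI Sintel training set
    (the 'final' rendering pass), using the .flo ground-truth flow fields."""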
    evaluation_result = {}
    img_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'final')
    gt_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'flow')
    sequences = [f for f in os.listdir(img_dir)
                 if os.path.isdir(os.path.join(img_dir, f))]
    for seq in sequences:
        img_files = sorted([os.path.join(img_dir, seq, f)
                            for f in os.listdir(os.path.join(img_dir, seq))
                            if f.endswith(".png")])
        gt_files = sorted([os.path.join(gt_dir, seq, f)
                           for f in os.listdir(os.path.join(gt_dir, seq))
                           if f.endswith(".flo")])
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'mpi_sintel',
            evaluation_executable, img_files, gt_files, state, state_path)
    return evaluation_result


def evaluate_middlebury(source_dir, algorithm, evaluation_executable, state, state_path):
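    """Evaluate the algorithm on the Middlebury 'other' sequences that provide
    ground-truth flow (.flo) files."""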
    evaluation_result = {}
    img_dir = os.path.join(source_dir, 'middlebury', 'other-data')
    gt_dir = os.path.join(source_dir, 'middlebury', 'other-gt-flow')
    sequences = [f for f in os.listdir(gt_dir)
                 if os.path.isdir(os.path.join(gt_dir, f))]
    for seq in sequences:
        img_files = sorted([os.path.join(img_dir, seq, f)
                            for f in os.listdir(os.path.join(img_dir, seq))
                            if f.endswith(".png")])
        gt_files = sorted([os.path.join(gt_dir, seq, f)
                           for f in os.listdir(os.path.join(gt_dir, seq))
                           if f.endswith(".flo")])
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'middlebury',
            evaluation_executable, img_files, gt_files, state, state_path)
    return evaluation_result


# Map each supported dataset name (as accepted by the -d option) to its
# evaluation function.
dataset_eval_functions = {
    "mpi_sintel": evaluate_mpi_sintel,
    "middlebury": evaluate_middlebury
}

###############################################################################

def create_dir(dir):
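    """Create the directory (and its parents) if it does not exist yet."""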
    if not os.path.exists(dir):
        os.makedirs(dir)


def parse_sequence(input_str):
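    """Split a comma-separated command-line value into a list of stripped,
    non-empty items; an empty string yields an empty list."""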
    if len(input_str) == 0:
        return []
    else:
        return [o.strip() for o in input_str.split(",") if o]


def build_chart(dst_folder, state, dataset):
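    """Plot average endpoint error against average per-frame runtime for every
    stored evaluation instance of the given dataset and save the chart as a
    PNG file in dst_folder."""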
    fig = plt.figure(figsize=(16, 10))
    markers = ["o", "s", "h", "^", "D"]
    marker_idx = 0
    colors = ["b", "g", "r"]
    color_idx = 0
    for algo in state[dataset].keys():
        for eval_instance in state[dataset][algo]:
            name = algo + "--" + eval_instance["timestamp"]
            average_time = 0.0
            average_error = 0.0
            num_elem = 0
            for seq in eval_instance["eval_results"].keys():
                for frame in eval_instance["eval_results"][seq]:
                    average_time += frame["time"]
                    average_error += frame["error"]["average"]
                    num_elem += 1
            average_time /= num_elem
            average_error /= num_elem

            # Cycle through the colors first and advance the marker once the
            # colors wrap around, so every evaluation instance gets a distinct
            # color/marker combination.
            marker_style = colors[color_idx] + markers[marker_idx]
            color_idx += 1
            if color_idx >= len(colors):
                color_idx = 0
                marker_idx += 1
                if marker_idx >= len(markers):
                    marker_idx = 0
            plt.gca().plot([average_time], [average_error],
                           marker_style,
                           markersize=14,
                           label=name)

    plt.gca().set_ylabel('Average Endpoint Error (EPE)', fontsize=20)
    plt.gca().set_xlabel('Average Runtime (seconds per frame)', fontsize=20)
    plt.gca().set_xscale("log")
    plt.gca().set_title('Evaluation on ' + dataset, fontsize=20)

    plt.gca().legend()
    fig.savefig(os.path.join(dst_folder, "evaluation_results_" + dataset + ".png"),
                bbox_inches='tight')
    plt.close()


if __name__ == '__main__':
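    # Overall flow: parse the command line, validate the executable and the
    # dataset folders, load any previous evaluation state, run the requested
    # algorithm/dataset combinations (skipping already finished ones), and
    # finally plot one chart per dataset.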
    parser = argparse.ArgumentParser(
        description='Optical flow benchmarking script',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        "bin_path",
        nargs='?',
        default="./optflow-example-optical_flow_evaluation",
        help="Path to the optical flow evaluation executable")
    parser.add_argument(
        "-a",
        "--algorithms",
        metavar="ALGORITHMS",
        default="",
        help=("Comma-separated list of optical-flow algorithms to evaluate "
              "(example: -a farneback,tvl1,deepflow). Note that previously "
              "evaluated algorithms are also included in the output charts"))
    parser.add_argument(
        "-d",
        "--datasets",
        metavar="DATASETS",
        default="mpi_sintel",
        help=("Comma-separated list of datasets for evaluation (currently only "
              "'mpi_sintel' and 'middlebury' are supported)"))
    parser.add_argument(
        "-f",
        "--dataset_folder",
        metavar="DATASET_FOLDER",
        default="./OF_datasets",
        help=("Path to the folder containing the datasets. To enable evaluation on "
              "the MPI Sintel dataset, download it using the following links: "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_images.zip and "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_extras.zip, and "
              "unzip both archives into the 'mpi_sintel' folder. To enable evaluation "
              "on the Middlebury dataset, use the following links: "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-color-twoframes.zip and "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-gt-flow.zip. "
              "These should be unzipped into the 'middlebury' folder"))
    parser.add_argument(
        "-o",
        "--out",
        metavar="OUT_DIR",
        default="./OF_evaluation_results",
        help="Output directory in which to store the benchmark results")
    parser.add_argument(
        "-s",
        "--state",
        metavar="STATE_JSON",
        default="./OF_evaluation_state.json",
        help=("Path to a JSON file that stores the current evaluation state and "
              "previous evaluation results"))
    args, other_args = parser.parse_known_args()

    if not os.path.isfile(args.bin_path):
        print("Error: " + args.bin_path + " does not exist")
        sys.exit(1)

    if not os.path.exists(args.dataset_folder):
        print("Error: " + args.dataset_folder + (" does not exist. Please "
              "specify the -f parameter correctly"))
        sys.exit(1)

    state = {}
    if os.path.isfile(args.state):
        state = load_json(args.state)

    algorithm_list = parse_sequence(args.algorithms)
    dataset_list = parse_sequence(args.datasets)
    for dataset in dataset_list:
        if dataset not in dataset_eval_functions.keys():
            print("Error: unsupported dataset " + dataset)
            sys.exit(1)
        if dataset not in os.listdir(args.dataset_folder):
            print("Error: " + os.path.join(args.dataset_folder, dataset) + (" does not exist. "
                  "Please download the dataset and follow the naming conventions "
                  "(use -h for more information)"))
            sys.exit(1)

    for dataset in dataset_list:
        if dataset not in state.keys():
            state[dataset] = {}
        for algorithm in algorithm_list:
            if algorithm in state[dataset].keys():
                last_eval_instance = state[dataset][algorithm][-1]
                if "finished" not in last_eval_instance.keys():
                    print(("Continuing an unfinished evaluation of " +
                           algorithm + " started at " + last_eval_instance["timestamp"]))
                else:
                    state[dataset][algorithm].append({"timestamp":
                        datetime.datetime.now().strftime("%Y-%m-%d--%H-%M")})
            else:
                state[dataset][algorithm] = [{"timestamp":
                    datetime.datetime.now().strftime("%Y-%m-%d--%H-%M")}]
            save_json(state, args.state)
            dataset_eval_functions[dataset](args.dataset_folder, algorithm, args.bin_path,
                                            state, args.state)
            state[dataset][algorithm][-1]["finished"] = True
            save_json(state, args.state)
        save_json(state, args.state)

    create_dir(args.out)
    for dataset in dataset_list:
        build_chart(args.out, state, dataset)