add detection to ground truth matching

according to Piotr Dollar paper
pull/322/head
marina.kolpakova 12 years ago
parent d1952f28d9
commit 4c4c878b1b
  1. 21
      apps/sft/misk/roc_test.py
  2. 95
      apps/sft/misk/sft.py

@ -7,6 +7,11 @@ import sys, os, os.path, glob, math, cv2
from datetime import datetime from datetime import datetime
import numpy import numpy
# "key" : ( b, g, r)
bgr = { "red" : ( 0, 0, 255),
"green" : ( 0, 255, 0),
"blue" : (255, 0 , 0)}
def call_parser(f, a): def call_parser(f, a):
return eval( "sft.parse_" + f + "('" + a + "')") return eval( "sft.parse_" + f + "('" + a + "')")
@ -37,10 +42,10 @@ if __name__ == "__main__":
dom = xml.getFirstTopLevelNode() dom = xml.getFirstTopLevelNode()
assert cascade.load(dom) assert cascade.load(dom)
frame = 0
pattern = args.input pattern = args.input
camera = cv2.VideoCapture(args.input) camera = cv2.VideoCapture(pattern)
frame = 0
while True: while True:
ret, img = camera.read() ret, img = camera.read()
if not ret: if not ret:
@ -53,17 +58,17 @@ if __name__ == "__main__":
boxes = samples[tail] boxes = samples[tail]
boxes = sft.norm_acpect_ratio(boxes, 0.5) boxes = sft.norm_acpect_ratio(boxes, 0.5)
if boxes is not None:
sft.draw_rects(img, boxes, (255, 0, 0), lambda x, y : y)
frame = frame + 1 frame = frame + 1
rects, confs = cascade.detect(img, rois = None) rects, confs = cascade.detect(img, rois = None)
dt_old = sft.match(boxes, rects, confs) dts = sft.convert2detections(rects, confs)
sft.draw_dt(img, dts, bgr["green"])
fp, fn = sft.match(boxes, dts)
print "fp and fn", fp, fn
if dt_old is not None:
sft.draw_dt(img, dt_old, (0, 255, 0))
sft.draw_rects(img, boxes, bgr["blue"], lambda x, y : y)
cv2.imshow("result", img); cv2.imshow("result", img);
if (cv2.waitKey (0) == 27): if (cv2.waitKey (0) == 27):
break; break;

@ -4,6 +4,29 @@ import cv2, re, glob
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
""" Convert numpy matrices with rectangles and confidences to sorted list of detections."""
def convert2detections(rects, confs, crop_factor = 0.125):
if rects is None:
return []
dts = zip(*[rects.tolist(), confs.tolist()])
dts = zip(dts[0][0], dts[0][1])
dts = [Detection(r,c) for r, c in dts]
dts.sort(lambda x, y : -1 if (x.conf - y.conf) > 0 else 1)
for dt in dts:
dt.crop(crop_factor)
return dts
def crop_rect(rect, factor):
val_x = factor * float(rect[2])
val_y = factor * float(rect[3])
x = [int(rect[0] + val_x), int(rect[1] + val_y), int(rect[2] - 2.0 * val_x), int(rect[3] - 2.0 * val_y)]
return x
#
def plot_curve(): def plot_curve():
fig, ax = plt.subplots() fig, ax = plt.subplots()
@ -29,12 +52,6 @@ def plot_curve():
plt.xscale('log') plt.xscale('log')
plt.show() plt.show()
def crop_rect(rect, factor):
val_x = factor * float(rect[2])
val_y = factor * float(rect[3])
x = [int(rect[0] + val_x), int(rect[1] + val_y), int(rect[2] - 2.0 * val_x), int(rect[3] - 2.0 * val_y)]
return x
def draw_rects(img, rects, color, l = lambda x, y : x + y): def draw_rects(img, rects, color, l = lambda x, y : x + y):
if rects is not None: if rects is not None:
for x1, y1, x2, y2 in rects: for x1, y1, x2, y2 in rects:
@ -58,16 +75,13 @@ class Detection:
self.conf = conf self.conf = conf
self.matched = False self.matched = False
# def crop(self):
# rel_scale = self.bb[1] / 128
def crop(self, factor): def crop(self, factor):
print "was", self.bb
self.bb = crop_rect(self.bb, factor) self.bb = crop_rect(self.bb, factor)
print "bec", self.bb
# we use rect-stype for dt and box style for gt. ToDo: fix it # we use rect-stype for dt and box style for gt. ToDo: fix it
def overlap(self, b): def overlap(self, b):
print self.bb, "vs", b
a = self.bb a = self.bb
w = min( a[0] + a[2], b[2]) - max(a[0], b[0]); w = min( a[0] + a[2], b[2]) - max(a[0], b[0]);
h = min( a[1] + a[3], b[3]) - max(a[1], b[1]); h = min( a[1] + a[3], b[3]) - max(a[1], b[1]);
@ -120,47 +134,40 @@ def norm_acpect_ratio(boxes, ratio):
return [ norm_box(box, ratio) for box in boxes] return [ norm_box(box, ratio) for box in boxes]
def match(gts, rects, confs): def match(gts, dts):
if rects is None:
return 0
fp = 0
fn = 0
dts = zip(*[rects.tolist(), confs.tolist()])
dts = zip(dts[0][0], dts[0][1])
dts = [Detection(r,c) for r, c in dts]
factor = 1.0 / 8.0
dt_old = dts
for dt in dts: for dt in dts:
dt.crop(factor) print dt.bb,
print
for gt in gts: for gt in gts:
print gt
# exclude small
if gt[2] - gt[0] < 27:
continue
matched = False # Cartesian product for each detection BB_dt with each BB_gt
overlaps = [[dt.overlap(gt) for gt in gts]for dt in dts]
print overlaps
for dt in dts: matches_gt = [0]*len(gts)
# dt.crop() print matches_gt
overlap = dt.overlap(gt)
print dt.bb, "vs", gt, overlap
if overlap > 0.5:
dt.mark_matched()
matched = True
print "matched ", dt.bb, gt
if not matched: matches_dt = [0]*len(dts)
fn = fn + 1 print matches_dt
print "fn", fn for idx, row in enumerate(overlaps):
print idx, row
for dt in dts: imax = row.index(max(row))
if not dt.matched:
fp = fp + 1 if (matches_gt[imax] == 0 and row[imax] > 0.5):
matches_gt[imax] = 1
matches_dt[idx] = 1
print matches_gt
print matches_dt
fp = sum(1 for x in matches_dt if x == 0)
fn = sum(1 for x in matches_gt if x == 0)
print "fp", fp return fp, fn
return dt_old
Loading…
Cancel
Save