refactor python ROC script and add axis ticks

pull/322/head
marina.kolpakova 12 years ago
parent 990ca86de6
commit 922de414ef
  1. 60
      apps/sft/misc/sft.py

@ -1,10 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
import cv2, re, glob import cv2, re, glob
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
""" Convert numpy matrices with rectangles and confidences to sorted list of detections.""" """ Convert numPy matrices with rectangles and confidences to sorted list of detections."""
def convert2detections(rects, confs, crop_factor = 0.125): def convert2detections(rects, confs, crop_factor = 0.125):
if rects is None: if rects is None:
return [] return []
@ -14,11 +14,13 @@ def convert2detections(rects, confs, crop_factor = 0.125):
dts = [Detection(r,c) for r, c in dts] dts = [Detection(r,c) for r, c in dts]
dts.sort(lambda x, y : -1 if (x.conf - y.conf) > 0 else 1) dts.sort(lambda x, y : -1 if (x.conf - y.conf) > 0 else 1)
for dt in dts: for dt in dts:
dt.crop(crop_factor) dt.crop(crop_factor)
return dts return dts
""" Create new instance of soft cascade."""
def cascade(min_scale, max_scale, nscales, f): def cascade(min_scale, max_scale, nscales, f):
# where we use nms cv::SCascade::DOLLAR == 2 # where we use nms cv::SCascade::DOLLAR == 2
c = cv2.SCascade(min_scale, max_scale, nscales, 2) c = cv2.SCascade(min_scale, max_scale, nscales, 2)
@ -27,6 +29,7 @@ def cascade(min_scale, max_scale, nscales, f):
assert c.load(dom) assert c.load(dom)
return c return c
""" Compute prefix sum for en array"""
def cumsum(n): def cumsum(n):
cum = [] cum = []
y = 0 y = 0
@ -35,6 +38,7 @@ def cumsum(n):
cum.append(y) cum.append(y)
return cum return cum
""" Compute x and y arrays for ROC plot"""
def computeROC(confidenses, tp, nannotated, nframes): def computeROC(confidenses, tp, nannotated, nframes):
confidenses, tp = zip(*sorted(zip(confidenses, tp), reverse = True)) confidenses, tp = zip(*sorted(zip(confidenses, tp), reverse = True))
@ -46,34 +50,52 @@ def computeROC(confidenses, tp, nannotated, nframes):
return fppi, miss_rate return fppi, miss_rate
""" Crop rectangle by factor"""
def crop_rect(rect, factor): def crop_rect(rect, factor):
val_x = factor * float(rect[2]) val_x = factor * float(rect[2])
val_y = factor * float(rect[3]) val_y = factor * float(rect[3])
x = [int(rect[0] + val_x), int(rect[1] + val_y), int(rect[2] - 2.0 * val_x), int(rect[3] - 2.0 * val_y)] x = [int(rect[0] + val_x), int(rect[1] + val_y), int(rect[2] - 2.0 * val_x), int(rect[3] - 2.0 * val_y)]
return x return x
# """Initialize plot axises"""
def initPlot(name = "ROC curve Bahnhof"):
def initPlot():
fig, ax = plt.subplots() fig, ax = plt.subplots()
fig.canvas.draw() fig.canvas.draw()
plt.xlabel("fppi") plt.xlabel("fppi")
plt.ylabel("miss rate") plt.ylabel("miss rate")
plt.title("ROC curve Bahnhof") plt.title(name)
plt.grid(True) plt.grid(True)
plt.xscale('log') plt.xscale('log')
plt.yscale('log') plt.yscale('log')
def showPlot(name): """Show resulted plot"""
plt.savefig(name) def showPlot(file_name):
# plt.savefig(file_name)
plt.axis((pow(10, -3), pow(10, 1), 0.0, 1))
plt.yticks( [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.64, 0.8, 1], ['.05', '.10', '.20', '.30', '.40', '.50', '.64', '.80', '1'] )
plt.show() plt.show()
def match(gts, dts):
# Cartesian product for each detection BB_dt with each BB_gt
overlaps = [[dt.overlap(gt) for gt in gts]for dt in dts]
matches_gt = [0]*len(gts)
matches_dt = [0]*len(dts)
for idx, row in enumerate(overlaps):
imax = row.index(max(row))
if (matches_gt[imax] == 0 and row[imax] > 0.5):
matches_gt[imax] = 1
matches_dt[idx] = 1
return matches_dt
def plotLogLog(fppi, miss_rate, c): def plotLogLog(fppi, miss_rate, c):
plt.semilogy(fppi, miss_rate, color = c, linewidth = 2) print
plt.loglog(fppi, miss_rate, color = c, linewidth = 2)
def draw_rects(img, rects, color, l = lambda x, y : x + y): def draw_rects(img, rects, color, l = lambda x, y : x + y):
@ -102,7 +124,7 @@ class Detection:
def crop(self, factor): def crop(self, factor):
self.bb = crop_rect(self.bb, factor) self.bb = crop_rect(self.bb, factor)
# we use rect-stype for dt and box style for gt. ToDo: fix it # we use rect-style for dt and box style for gt. ToDo: fix it
def overlap(self, b): def overlap(self, b):
a = self.bb a = self.bb
@ -155,19 +177,3 @@ def norm_box(box, ratio):
def norm_acpect_ratio(boxes, ratio): def norm_acpect_ratio(boxes, ratio):
return [ norm_box(box, ratio) for box in boxes] return [ norm_box(box, ratio) for box in boxes]
def match(gts, dts):
# Cartesian product for each detection BB_dt with each BB_gt
overlaps = [[dt.overlap(gt) for gt in gts]for dt in dts]
matches_gt = [0]*len(gts)
matches_dt = [0]*len(dts)
for idx, row in enumerate(overlaps):
imax = row.index(max(row))
if (matches_gt[imax] == 0 and row[imax] > 0.5):
matches_gt[imax] = 1
matches_dt[idx] = 1
return matches_dt
Loading…
Cancel
Save