mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
2.4 KiB
71 lines
2.4 KiB
#!/usr/bin/env python |
|
''' |
|
=============================================================================== |
|
Interactive Image Segmentation using GrabCut algorithm. |
|
=============================================================================== |
|
''' |
|
|
|
# Python 2/3 compatibility |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
import cv2 |
|
import sys |
|
|
|
from tests_common import NewOpenCVTests |
|
|
|
class grabcut_test(NewOpenCVTests):
    '''Regression test for cv2.grabCut, initialised from a rectangle and from a mask.'''

    def verify(self, mask, exp):
        '''Tell whether `mask` is close enough to the expected mask `exp`.

        The number of pixels where the two masks differ is divided by the
        number of non-zero pixels in `exp`; the masks match when that ratio
        is below 2%.

        mask -- mask produced by the test (same shape as `exp`).
        exp  -- reference mask.

        Returns True on a match, False otherwise.
        '''
        maxDiffRatio = 0.02
        expArea = np.count_nonzero(exp)
        nonIntersectArea = np.count_nonzero(mask != exp)

        if expArea == 0:
            # An empty reference leaves nothing to normalise by; only an
            # identical (empty) mask can be accepted. Avoids ZeroDivisionError.
            return nonIntersectArea == 0

        curRatio = float(nonIntersectArea) / expArea
        return curRatio < maxDiffRatio

    def scaleMask(self, mask):
        '''Convert a GrabCut label mask into a binary uint8 image.

        Pixels labelled cv2.GC_FGD or cv2.GC_PR_FGD become 255, all other
        pixels become 0.
        '''
        isForeground = (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD)
        return np.where(isForeground, 255, 0).astype('uint8')

    def _runGrabCut(self, img, mask, rect, initMode, iterCount):
        '''Initialise GrabCut with `initMode`, then run `iterCount` GC_EVAL iterations.

        Fresh zero-filled 1x65 float64 background/foreground models are
        allocated for every run. `mask` is updated in place by cv2.grabCut.
        '''
        bgdModel = np.zeros((1, 65), np.float64)
        fgdModel = np.zeros((1, 65), np.float64)
        # First call: zero iterations, initialisation only.
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 0, initMode)
        # Second call: continue from the state set up above.
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, iterCount, cv2.GC_EVAL)

    def test_grabcut(self):
        '''Segment the airplane sample twice and compare with the stored masks.

        Pass 1 starts from a rectangle, pass 2 starts from the stored label
        mask. Any reference image that is missing is generated from the
        current output and written under `self.extraTestDataPath`.
        '''
        img = self.get_sample('cv/shared/airplane.png')
        mask_prob = self.get_sample("cv/grabcut/mask_probpy.png", 0)
        exp_mask1 = self.get_sample("cv/grabcut/exp_mask1py.png", 0)
        exp_mask2 = self.get_sample("cv/grabcut/exp_mask2py.png", 0)

        if img is None:
            self.fail('Missing test data')

        # --- Pass 1: initialise from a rectangle -------------------------
        rect = (24, 126, 459, 168)
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        self._runGrabCut(img, mask, rect, cv2.GC_INIT_WITH_RECT, 2)

        if mask_prob is None:
            mask_prob = mask.copy()
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/mask_probpy.png', mask_prob)
        if exp_mask1 is None:
            exp_mask1 = self.scaleMask(mask)
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask1py.png', exp_mask1)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask1))

        # --- Pass 2: initialise from the label mask ----------------------
        # Work on a copy: cv2.grabCut rewrites the mask in place, which would
        # otherwise modify the array handed back by get_sample().
        mask = mask_prob.copy()
        self._runGrabCut(img, mask, rect, cv2.GC_INIT_WITH_MASK, 1)

        if exp_mask2 is None:
            exp_mask2 = self.scaleMask(mask)
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask2py.png', exp_mask2)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask2))
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly as a script: delegate to the shared test base
    # class' bootstrap() entry point.
    NewOpenCVTests.bootstrap()
|
|
|