mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
3.4 KiB
77 lines
3.4 KiB
import numpy as np |
|
import sys |
|
import os |
|
import argparse |
|
import tensorflow as tf |
|
from tensorflow.python.platform import gfile |
|
from imagenet_cls_test_alexnet import MeanValueFetch, DnnCaffeModel, Framework, ClsAccEvaluation |
|
try: |
|
import cv2 as cv |
|
except ImportError: |
|
raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, ' |
|
'configure environemnt variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)') |
|
|
|
# If you've got an exception "Cannot load libmkl_avx.so or libmkl_def.so" or similar, try to export next variable |
|
# before runnigng the script: |
|
# LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_sequential.so |
|
|
|
|
|
class TensorflowModel(Framework): |
|
sess = tf.Session |
|
output = tf.Graph |
|
|
|
def __init__(self, model_file, in_blob_name, out_blob_name): |
|
self.in_blob_name = in_blob_name |
|
self.sess = tf.Session() |
|
with gfile.FastGFile(model_file, 'rb') as f: |
|
graph_def = tf.GraphDef() |
|
graph_def.ParseFromString(f.read()) |
|
self.sess.graph.as_default() |
|
tf.import_graph_def(graph_def, name='') |
|
self.output = self.sess.graph.get_tensor_by_name(out_blob_name + ":0") |
|
|
|
def get_name(self): |
|
return 'Tensorflow' |
|
|
|
def get_output(self, input_blob): |
|
assert len(input_blob.shape) == 4 |
|
batch_tf = input_blob.transpose(0, 2, 3, 1) |
|
out = self.sess.run(self.output, |
|
{self.in_blob_name+':0': batch_tf}) |
|
out = out[..., 1:1001] |
|
return out |
|
|
|
|
|
class DnnTfInceptionModel(DnnCaffeModel): |
|
net = cv.dnn.Net() |
|
|
|
def __init__(self, model_file, in_blob_name, out_blob_name): |
|
self.net = cv.dnn.readNetFromTensorflow(model_file) |
|
self.in_blob_name = in_blob_name |
|
self.out_blob_name = out_blob_name |
|
|
|
def get_output(self, input_blob): |
|
return super(DnnTfInceptionModel, self).get_output(input_blob)[..., 1:1001] |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir") |
|
parser.add_argument("--img_cls_file", help="path to file with classes ids for images, download it here:" |
|
"https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/img_classes_inception.txt") |
|
parser.add_argument("--model", help="path to tensorflow model, download it here:" |
|
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip") |
|
parser.add_argument("--log", help="path to logging file") |
|
parser.add_argument("--batch_size", help="size of images in batch", default=1) |
|
parser.add_argument("--frame_size", help="size of input image", default=224) |
|
parser.add_argument("--in_blob", help="name for input blob", default='input') |
|
parser.add_argument("--out_blob", help="name for output blob", default='softmax2') |
|
args = parser.parse_args() |
|
|
|
data_fetcher = MeanValueFetch(args.frame_size, args.imgs_dir, True) |
|
|
|
frameworks = [TensorflowModel(args.model, args.in_blob, args.out_blob), |
|
DnnTfInceptionModel(args.model, '', args.out_blob)] |
|
|
|
acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size) |
|
acc_eval.process(frameworks, data_fetcher)
|
|
|