mirror of https://github.com/opencv/opencv.git
parent
44572fac44
commit
bf87a43185
8 changed files with 457 additions and 5 deletions
@ -0,0 +1,108 @@ |
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
|
||||
namespace cv { namespace dnn { |
||||
|
||||
class CropAndResizeLayerImpl CV_FINAL : public CropAndResizeLayer |
||||
{ |
||||
public: |
||||
CropAndResizeLayerImpl(const LayerParams& params) |
||||
{ |
||||
CV_Assert(params.has("width"), params.has("height")); |
||||
outWidth = params.get<float>("width"); |
||||
outHeight = params.get<float>("height"); |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int requiredOutputs, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(inputs.size() == 2, inputs[0].size() == 4); |
||||
if (inputs[0][0] != 1) |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
outputs.resize(1, MatShape(4)); |
||||
outputs[0][0] = inputs[1][2]; // Number of bounding boxes.
|
||||
outputs[0][1] = inputs[0][1]; // Number of channels.
|
||||
outputs[0][2] = outHeight; |
||||
outputs[0][3] = outWidth; |
||||
return false; |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr); |
||||
} |
||||
|
||||
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Mat& inp = *inputs[0]; |
||||
Mat& out = outputs[0]; |
||||
Mat boxes = inputs[1]->reshape(1, inputs[1]->total() / 7); |
||||
const int numChannels = inp.size[1]; |
||||
const int inpHeight = inp.size[2]; |
||||
const int inpWidth = inp.size[3]; |
||||
const int inpSpatialSize = inpHeight * inpWidth; |
||||
const int outSpatialSize = outHeight * outWidth; |
||||
CV_Assert(inp.isContinuous(), out.isContinuous()); |
||||
|
||||
for (int b = 0; b < boxes.rows; ++b) |
||||
{ |
||||
float* outDataBox = out.ptr<float>(b); |
||||
float left = boxes.at<float>(b, 3); |
||||
float top = boxes.at<float>(b, 4); |
||||
float right = boxes.at<float>(b, 5); |
||||
float bottom = boxes.at<float>(b, 6); |
||||
float boxWidth = right - left; |
||||
float boxHeight = bottom - top; |
||||
|
||||
float heightScale = boxHeight * static_cast<float>(inpHeight - 1) / (outHeight - 1); |
||||
float widthScale = boxWidth * static_cast<float>(inpWidth - 1) / (outWidth - 1); |
||||
for (int y = 0; y < outHeight; ++y) |
||||
{ |
||||
float input_y = top * (inpHeight - 1) + y * heightScale; |
||||
int y0 = static_cast<int>(input_y); |
||||
const float* inpData_row0 = (float*)inp.data + y0 * inpWidth; |
||||
const float* inpData_row1 = (y0 + 1 < inpHeight) ? (inpData_row0 + inpWidth) : inpData_row0; |
||||
for (int x = 0; x < outWidth; ++x) |
||||
{ |
||||
float input_x = left * (inpWidth - 1) + x * widthScale; |
||||
int x0 = static_cast<int>(input_x); |
||||
int x1 = std::min(x0 + 1, inpWidth - 1); |
||||
|
||||
float* outData = outDataBox + y * outWidth + x; |
||||
const float* inpData_row0_c = inpData_row0; |
||||
const float* inpData_row1_c = inpData_row1; |
||||
for (int c = 0; c < numChannels; ++c) |
||||
{ |
||||
*outData = inpData_row0_c[x0] + |
||||
(input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) + |
||||
(input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] + |
||||
(input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0])); |
||||
|
||||
inpData_row0_c += inpSpatialSize; |
||||
inpData_row1_c += inpSpatialSize; |
||||
outData += outSpatialSize; |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
private: |
||||
int outWidth, outHeight; |
||||
}; |
||||
|
||||
Ptr<Layer> CropAndResizeLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<CropAndResizeLayer>(new CropAndResizeLayerImpl(params)); |
||||
} |
||||
|
||||
} // namespace dnn
|
||||
} // namespace cv
|
@ -0,0 +1,291 @@ |
||||
import argparse |
||||
import numpy as np |
||||
import tensorflow as tf |
||||
|
||||
from tensorflow.core.framework.node_def_pb2 import NodeDef |
||||
from tensorflow.tools.graph_transforms import TransformGraph |
||||
from google.protobuf import text_format |
||||
|
||||
parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' |
||||
'SSD model from TensorFlow Object Detection API. ' |
||||
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') |
||||
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.') |
||||
parser.add_argument('--output', required=True, help='Path to output text graph.') |
||||
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.') |
||||
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+', |
||||
help='Hyper-parameter of grid_anchor_generator from a config file.') |
||||
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+', |
||||
help='Hyper-parameter of grid_anchor_generator from a config file.') |
||||
parser.add_argument('--features_stride', default=16, type=float, nargs='+', |
||||
help='Hyper-parameter from a config file.') |
||||
args = parser.parse_args() |
||||
|
||||
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv', |
||||
'FirstStageBoxPredictor/BoxEncodingPredictor', |
||||
'FirstStageBoxPredictor/ClassPredictor', |
||||
'CropAndResize', |
||||
'MaxPool2D', |
||||
'SecondStageFeatureExtractor', |
||||
'SecondStageBoxPredictor', |
||||
'image_tensor') |
||||
|
||||
scopesToIgnore = ('FirstStageFeatureExtractor/Assert', |
||||
'FirstStageFeatureExtractor/Shape', |
||||
'FirstStageFeatureExtractor/strided_slice', |
||||
'FirstStageFeatureExtractor/GreaterEqual', |
||||
'FirstStageFeatureExtractor/LogicalAnd') |
||||
|
||||
unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu', |
||||
'Index', 'Tperm', 'is_training', 'Tpaddings'] |
||||
|
||||
# Read the graph. |
||||
with tf.gfile.FastGFile(args.input, 'rb') as f: |
||||
graph_def = tf.GraphDef() |
||||
graph_def.ParseFromString(f.read()) |
||||
|
||||
# Removes Identity nodes |
||||
def removeIdentity(): |
||||
identities = {} |
||||
for node in graph_def.node: |
||||
if node.op == 'Identity': |
||||
identities[node.name] = node.input[0] |
||||
graph_def.node.remove(node) |
||||
|
||||
for node in graph_def.node: |
||||
for i in range(len(node.input)): |
||||
if node.input[i] in identities: |
||||
node.input[i] = identities[node.input[i]] |
||||
|
||||
removeIdentity() |
||||
|
||||
removedNodes = [] |
||||
|
||||
for i in reversed(range(len(graph_def.node))): |
||||
op = graph_def.node[i].op |
||||
name = graph_def.node[i].name |
||||
|
||||
if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep): |
||||
if op != 'Const': |
||||
removedNodes.append(name) |
||||
|
||||
del graph_def.node[i] |
||||
else: |
||||
for attr in unusedAttrs: |
||||
if attr in graph_def.node[i].attr: |
||||
del graph_def.node[i].attr[attr] |
||||
|
||||
# Remove references to removed nodes except Const nodes. |
||||
for node in graph_def.node: |
||||
for i in reversed(range(len(node.input))): |
||||
if node.input[i] in removedNodes: |
||||
del node.input[i] |
||||
|
||||
|
||||
# Connect input node to the first layer |
||||
assert(graph_def.node[0].op == 'Placeholder') |
||||
graph_def.node[1].input.insert(0, graph_def.node[0].name) |
||||
|
||||
# Temporarily remove top nodes. |
||||
topNodes = [] |
||||
while True: |
||||
node = graph_def.node.pop() |
||||
topNodes.append(node) |
||||
if node.op == 'CropAndResize': |
||||
break |
||||
|
||||
def tensorMsg(values): |
||||
if all([isinstance(v, float) for v in values]): |
||||
dtype = 'DT_FLOAT' |
||||
field = 'float_val' |
||||
elif all([isinstance(v, int) for v in values]): |
||||
dtype = 'DT_INT32' |
||||
field = 'int_val' |
||||
else: |
||||
raise Exception('Wrong values types') |
||||
|
||||
msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values) |
||||
for value in values: |
||||
msg += '%s: %s ' % (field, str(value)) |
||||
return msg + '}' |
||||
|
||||
def addSlice(inp, out, begins, sizes): |
||||
beginsNode = NodeDef() |
||||
beginsNode.name = out + '/begins' |
||||
beginsNode.op = 'Const' |
||||
text_format.Merge(tensorMsg(begins), beginsNode.attr["value"]) |
||||
graph_def.node.extend([beginsNode]) |
||||
|
||||
sizesNode = NodeDef() |
||||
sizesNode.name = out + '/sizes' |
||||
sizesNode.op = 'Const' |
||||
text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"]) |
||||
graph_def.node.extend([sizesNode]) |
||||
|
||||
sliced = NodeDef() |
||||
sliced.name = out |
||||
sliced.op = 'Slice' |
||||
sliced.input.append(inp) |
||||
sliced.input.append(beginsNode.name) |
||||
sliced.input.append(sizesNode.name) |
||||
graph_def.node.extend([sliced]) |
||||
|
||||
def addReshape(inp, out, shape): |
||||
shapeNode = NodeDef() |
||||
shapeNode.name = out + '/shape' |
||||
shapeNode.op = 'Const' |
||||
text_format.Merge(tensorMsg(shape), shapeNode.attr["value"]) |
||||
graph_def.node.extend([shapeNode]) |
||||
|
||||
reshape = NodeDef() |
||||
reshape.name = out |
||||
reshape.op = 'Reshape' |
||||
reshape.input.append(inp) |
||||
reshape.input.append(shapeNode.name) |
||||
graph_def.node.extend([reshape]) |
||||
|
||||
def addSoftMax(inp, out): |
||||
softmax = NodeDef() |
||||
softmax.name = out |
||||
softmax.op = 'Softmax' |
||||
text_format.Merge('i: -1', softmax.attr['axis']) |
||||
softmax.input.append(inp) |
||||
graph_def.node.extend([softmax]) |
||||
|
||||
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd', |
||||
'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2]) |
||||
|
||||
addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1', |
||||
'FirstStageBoxPredictor/ClassPredictor/softmax') # Compare with Reshape_4 |
||||
|
||||
flatten = NodeDef() |
||||
flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten' # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd |
||||
flatten.op = 'Flatten' |
||||
flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd') |
||||
graph_def.node.extend([flatten]) |
||||
|
||||
proposals = NodeDef() |
||||
proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized) |
||||
proposals.op = 'PriorBox' |
||||
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd') |
||||
proposals.input.append(graph_def.node[0].name) # image_tensor |
||||
|
||||
text_format.Merge('b: false', proposals.attr["flip"]) |
||||
text_format.Merge('b: true', proposals.attr["clip"]) |
||||
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"]) |
||||
text_format.Merge('f: 0.0', proposals.attr["offset"]) |
||||
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"]) |
||||
|
||||
widths = [] |
||||
heights = [] |
||||
for a in args.aspect_ratios: |
||||
for s in args.scales: |
||||
ar = np.sqrt(a) |
||||
heights.append((args.features_stride**2) * s / ar) |
||||
widths.append((args.features_stride**2) * s * ar) |
||||
|
||||
text_format.Merge(tensorMsg(widths), proposals.attr["width"]) |
||||
text_format.Merge(tensorMsg(heights), proposals.attr["height"]) |
||||
|
||||
graph_def.node.extend([proposals]) |
||||
|
||||
# Compare with Reshape_5 |
||||
detectionOut = NodeDef() |
||||
detectionOut.name = 'detection_out' |
||||
detectionOut.op = 'DetectionOutput' |
||||
|
||||
detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten') |
||||
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax') |
||||
detectionOut.input.append('proposals') |
||||
|
||||
text_format.Merge('i: 2', detectionOut.attr['num_classes']) |
||||
text_format.Merge('b: true', detectionOut.attr['share_location']) |
||||
text_format.Merge('i: 0', detectionOut.attr['background_label_id']) |
||||
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold']) |
||||
text_format.Merge('i: 6000', detectionOut.attr['top_k']) |
||||
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
||||
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
||||
text_format.Merge('b: true', detectionOut.attr['clip']) |
||||
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed']) |
||||
|
||||
graph_def.node.extend([detectionOut]) |
||||
|
||||
# Save as text. |
||||
for node in reversed(topNodes): |
||||
graph_def.node.extend([node]) |
||||
|
||||
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax') |
||||
|
||||
addSlice('SecondStageBoxPredictor/Reshape_1/softmax', |
||||
'SecondStageBoxPredictor/Reshape_1/slice', |
||||
[0, 0, 1], [-1, -1, -1]) |
||||
|
||||
addReshape('SecondStageBoxPredictor/Reshape_1/slice', |
||||
'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1]) |
||||
|
||||
# Replace Flatten subgraph onto a single node. |
||||
for i in reversed(range(len(graph_def.node))): |
||||
if graph_def.node[i].op == 'CropAndResize': |
||||
graph_def.node[i].input.insert(1, 'detection_out') |
||||
|
||||
if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape': |
||||
shapeNode = NodeDef() |
||||
shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2' |
||||
shapeNode.op = 'Const' |
||||
text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"]) |
||||
graph_def.node.extend([shapeNode]) |
||||
|
||||
graph_def.node[i].input.pop() |
||||
graph_def.node[i].input.append(shapeNode.name) |
||||
|
||||
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', |
||||
'SecondStageBoxPredictor/Flatten/flatten/strided_slice', |
||||
'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']: |
||||
del graph_def.node[i] |
||||
|
||||
for node in graph_def.node: |
||||
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape': |
||||
node.op = 'Flatten' |
||||
node.input.pop() |
||||
break |
||||
|
||||
################################################################################ |
||||
### Postprocessing |
||||
################################################################################ |
||||
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4]) |
||||
|
||||
variance = NodeDef() |
||||
variance.name = 'proposals/variance' |
||||
variance.op = 'Const' |
||||
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"]) |
||||
graph_def.node.extend([variance]) |
||||
|
||||
varianceEncoder = NodeDef() |
||||
varianceEncoder.name = 'variance_encoded' |
||||
varianceEncoder.op = 'Mul' |
||||
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape') |
||||
varianceEncoder.input.append(variance.name) |
||||
text_format.Merge('i: 2', varianceEncoder.attr["axis"]) |
||||
graph_def.node.extend([varianceEncoder]) |
||||
|
||||
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1]) |
||||
|
||||
detectionOut = NodeDef() |
||||
detectionOut.name = 'detection_out_final' |
||||
detectionOut.op = 'DetectionOutput' |
||||
|
||||
detectionOut.input.append('variance_encoded') |
||||
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape') |
||||
detectionOut.input.append('detection_out/slice/reshape') |
||||
|
||||
text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes']) |
||||
text_format.Merge('b: false', detectionOut.attr['share_location']) |
||||
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id']) |
||||
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold']) |
||||
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) |
||||
text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) |
||||
text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed']) |
||||
text_format.Merge('b: true', detectionOut.attr['clip']) |
||||
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target']) |
||||
graph_def.node.extend([detectionOut]) |
||||
|
||||
tf.train.write_graph(graph_def, "", args.output, as_text=True) |
Loading…
Reference in new issue