mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
327 lines
9.6 KiB
327 lines
9.6 KiB
def tokenize(s):
    """Split a text message into a flat list of tokens.

    Whitespace, ':', ';' and ',' separate tokens and are dropped. Each of
    '{', '}', '[' and ']' becomes a token of its own. Text from '#' up to the
    end of the line is skipped as a comment unless the '#' is inside a quoted
    string.

    A quoted string may be delimited by single or double quotes; its content
    (which may be empty) becomes one token without the surrounding quotes.
    Only the quote character that opened a string closes it, so the other
    quote character, separators and brackets are kept verbatim inside it.

    Args:
        s: The text to tokenize.

    Returns:
        A list of token strings in the order they appear in `s`.
    """
    separators = ' \t\r\n:;,'
    brackets = '{}[]'
    quotes = '\"\''

    tokens = []
    token = ""
    quote = None  # Quote character that opened the current string, if any.
    isComment = False
    for symbol in s:
        # A comment runs until the end of the line; '#' inside a string is data.
        isComment = (isComment and symbol != '\n') or (quote is None and symbol == '#')
        if isComment:
            continue

        if quote is not None:
            if symbol == quote:
                # Closing quote: emit the string even if it is empty.
                tokens.append(token)
                token = ""
                quote = None
            else:
                token += symbol
        elif symbol in quotes:
            if token:
                tokens.append(token)
                token = ""
            quote = symbol
        elif symbol in separators:
            if token:
                tokens.append(token)
                token = ""
        elif symbol in brackets:
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens
|
|
|
|
|
def parseMessage(tokens, idx):
    """Parse one brace-delimited message from a token list.

    Every field value is stored in a list keyed by the field name, so a field
    that occurs several times (or is written as a '[' ... ']' array)
    accumulates all of its values in order. Nested messages are parsed
    recursively and stored as dictionaries.

    Args:
        tokens: Token list, e.g. as produced by tokenize().
        idx: Index of the '{' token that opens the message.

    Returns:
        A tuple (msg, idx) where msg is the parsed dictionary and idx is the
        index of the '}' token that closes the message, or None if the tokens
        run out before the message (or any message nested in it) is closed.
    """
    assert tokens[idx] == '{'

    msg = {}
    numTokens = len(tokens)
    fieldName = None
    isArray = False
    while True:
        # Inside an array the current field name is reused for every element.
        if not isArray:
            idx += 1
            if idx >= numTokens:
                return None
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        if idx >= numTokens:
            # A field name (or an open array) without a following value.
            return None
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                # The nested message was truncated; so is this one.
                return None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx
|
|
|
|
|
def readTextMessage(filePath):
    """Read a text message file and return its content as a dictionary.

    The file content is wrapped in a pair of braces so that parseMessage()
    can treat the whole file as one top-level message.

    Args:
        filePath: Path of the text file to read.

    Returns:
        The dictionary produced by parseMessage(), or an empty dictionary
        when parseMessage() returns a falsy result.
    """
    with open(filePath, 'rt') as source:
        text = source.read()

    parsed = parseMessage(tokenize('{' + text + '}'), 0)
    if not parsed:
        return {}
    return parsed[0]
|
|
|
|
|
def listToTensor(values):
    """Wrap a list of numbers into a one-dimensional tensor description.

    Args:
        values: A list whose items are all floats or all ints.

    Returns:
        A dictionary of the form {'tensor': {...}} holding the dtype
        ('DT_FLOAT' or 'DT_INT32'), a shape with a single dimension of
        len(values), and the values under 'float_val' or 'int_val'.

    Raises:
        Exception: If the items are neither all floats nor all ints.
    """
    if all(isinstance(item, float) for item in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(item, int) for item in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    tensor = {
        'dtype': dtype,
        'tensor_shape': {'dim': {'size': len(values)}},
        field: values,
    }
    return {'tensor': tensor}
|
|
|
|
|
def addConstNode(name, values, graph_def):
    """Append a 'Const' node to `graph_def`.

    Args:
        name: Name of the new node.
        values: Stored as the node's 'value' attribute via NodeDef.addAttr().
        graph_def: Graph whose `node` collection receives the new node.
    """
    constNode = NodeDef()
    constNode.op = 'Const'
    constNode.name = name
    constNode.addAttr('value', values)
    graph_def.node.extend([constNode])
|
|
|
|
|
def addSlice(inp, out, begins, sizes, graph_def):
    """Append a 'Slice' node and its two 'Const' inputs to `graph_def`.

    Three nodes are added in this order: '<out>/begins' holding `begins`,
    '<out>/sizes' holding `sizes`, and the 'Slice' node named `out` whose
    inputs are `inp` followed by the two constants.

    Args:
        inp: Name of the node to slice.
        out: Name of the resulting 'Slice' node.
        begins: Value of the '<out>/begins' constant.
        sizes: Value of the '<out>/sizes' constant.
        graph_def: Graph whose `node` collection receives the new nodes.
    """
    beginsName = out + '/begins'
    addConstNode(beginsName, begins, graph_def)

    sizesName = out + '/sizes'
    addConstNode(sizesName, sizes, graph_def)

    sliceNode = NodeDef()
    sliceNode.name = out
    sliceNode.op = 'Slice'
    for inputName in (inp, beginsName, sizesName):
        sliceNode.input.append(inputName)
    graph_def.node.extend([sliceNode])
|
|
|
|
|
def addReshape(inp, out, shape, graph_def):
    """Append a 'Reshape' node and its shape constant to `graph_def`.

    Two nodes are added in this order: a 'Const' node '<out>/shape' holding
    `shape`, and the 'Reshape' node named `out` whose inputs are `inp` and
    the shape constant.

    Args:
        inp: Name of the node to reshape.
        out: Name of the resulting 'Reshape' node.
        shape: Value of the '<out>/shape' constant.
        graph_def: Graph whose `node` collection receives the new nodes.
    """
    shapeName = out + '/shape'
    addConstNode(shapeName, shape, graph_def)

    reshapeNode = NodeDef()
    reshapeNode.name = out
    reshapeNode.op = 'Reshape'
    for inputName in (inp, shapeName):
        reshapeNode.input.append(inputName)
    graph_def.node.extend([reshapeNode])
|
|
|
|
|
def addSoftMax(inp, out, graph_def):
    """Append a 'Softmax' node with attribute axis = -1 to `graph_def`.

    Args:
        inp: Name of the node feeding the softmax.
        out: Name of the new 'Softmax' node.
        graph_def: Graph whose `node` collection receives the new node.
    """
    node = NodeDef()
    node.op = 'Softmax'
    node.name = out
    node.addAttr('axis', -1)
    node.input.append(inp)
    graph_def.node.extend([node])
|
|
|
|
|
def addFlatten(inp, out, graph_def):
    """Append a 'Flatten' node to `graph_def`.

    Args:
        inp: Name of the node feeding the flatten operation.
        out: Name of the new 'Flatten' node.
        graph_def: Graph whose `node` collection receives the new node.
    """
    node = NodeDef()
    node.op = 'Flatten'
    node.name = out
    node.input.append(inp)
    graph_def.node.extend([node])
|
|
|
|
|
class NodeDef:
    """A graph node: a name, an operation, input names and attributes."""

    def __init__(self):
        # A new node starts in the same state that Clear() produces.
        self.Clear()

    def addAttr(self, key, value):
        """Store `value` as attribute `key`, tagged according to its type.

        bool, int, float and str values are stored as {'b': value},
        {'i': value}, {'f': value} and {'s': value} respectively; a list is
        converted with listToTensor().

        Raises:
            AssertionError: If the attribute `key` is already set.
            Exception: If `value` has none of the supported types.
        """
        assert key not in self.attr

        # bool is checked before int because bool is a subclass of int.
        scalarTags = ((bool, 'b'), (int, 'i'), (float, 'f'), (str, 's'))
        for valueType, tag in scalarTags:
            if isinstance(value, valueType):
                self.attr[key] = {tag: value}
                return

        if not isinstance(value, list):
            raise Exception('Unknown type of attribute ' + key)
        self.attr[key] = listToTensor(value)

    def Clear(self):
        """Reset the node to an empty state."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}
|
|
|
|
|
class GraphDef:
    """A graph stored as an ordered list of nodes."""

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write the graph to `filePath` as text.

        Each node is written as a 'node { ... }' entry with its name, op,
        inputs and attributes. Attributes, and the keys of nested attribute
        dictionaries, are written sorted case-insensitively by key. A list
        value is written as one entry per element.

        Args:
            filePath: Path of the text file to create or overwrite.
        """
        def sortedByKey(d):
            return sorted(d.items(), key=lambda item: item[0].lower())

        def formatScalar(v):
            """Return the text written after 'key: ' for a non-dict value."""
            if isinstance(v, bool):
                return 'true' if v else 'false'
            if v == 'true' or v == 'false':
                return v
            if isinstance(v, str) and not v.startswith('DT_'):
                # Strings that do not parse as a number are quoted; numeric
                # strings and 'DT_*' names are written bare.
                try:
                    float(v)
                except ValueError:
                    return '\"%s\"' % v
            return str(v)

        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                pad = ' ' * indent
                for key, value in sortedByKey(d):
                    entries = value if isinstance(value, list) else [value]
                    for v in entries:
                        if isinstance(v, dict):
                            f.write(pad + key + ' {\n')
                            printAttr(v, indent + 2)
                            f.write(pad + '}\n')
                        else:
                            f.write(pad + key + ': ' + formatScalar(v) + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write(' name: \"%s\"\n' % node.name)
                f.write(' op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write(' input: \"%s\"\n' % inp)
                for key, value in sortedByKey(node.attr):
                    f.write(' attr {\n')
                    f.write(' key: \"%s\"\n' % key)
                    f.write(' value {\n')
                    printAttr(value, 6)
                    f.write(' }\n')
                    f.write(' }\n')
                f.write('}\n')
|
|
|
|
|
def parseTextGraph(filePath):
    """Load a text graph file into a GraphDef of NodeDef objects.

    For every 'node' entry of the parsed file, the first 'name' and 'op'
    values become the node's name and op, all 'input' values become its
    inputs, and every 'attr' entry is stored under its first 'key' with its
    first 'value'.

    Args:
        filePath: Path of the text graph file.

    Returns:
        A GraphDef with one NodeDef per 'node' entry, in file order. The
        graph is empty when the parsed file contains no 'node' entries.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    # readTextMessage() may return {} (no nodes at all); treat that as an
    # empty graph instead of failing with a KeyError.
    for node in msg.get('node', []):
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])

        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph
|
|
|
|
|
# Removes Identity nodes
def removeIdentity(graph_def):
    """Remove all 'Identity' nodes from `graph_def` and bypass them.

    Every input that refers to a removed Identity node is replaced by that
    node's first input. Chains of Identity nodes are followed to the end, so
    no remaining node keeps referring to a removed one.

    Args:
        graph_def: Graph whose `node` collection is modified in place.
    """
    identityIdx = [i for i, node in enumerate(graph_def.node) if node.op == 'Identity']

    identities = {}
    for i in identityIdx:
        node = graph_def.node[i]
        identities[node.name] = node.input[0]

    # Delete from the back so earlier indices stay valid. Removing nodes
    # while iterating forward over the list would skip the node that follows
    # each removed one.
    for i in reversed(identityIdx):
        del graph_def.node[i]

    def resolve(name):
        # Follow Identity -> Identity chains; `seen` guards against a cycle.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i in range(len(node.input)):
            if node.input[i] in identities:
                node.input[i] = resolve(node.input[i])
|
|
|
|
|
def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Drop 'Const' nodes, unwanted nodes and a fixed set of attributes.

    A node is deleted when its op is 'Const' or when to_remove(name, op) is
    true. Nodes that stay lose the attributes listed in `unusedAttrs`.
    Afterwards, inputs that refer to deleted non-Const nodes are removed;
    inputs that refer to deleted 'Const' nodes are kept.

    Args:
        to_remove: Callable taking (name, op) and returning a truthy value
            for nodes that must be deleted.
        graph_def: Graph whose `node` collection is modified in place.
    """
    unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings']

    # A set keeps the membership tests in the second pass constant-time.
    removedNodes = set()

    # Iterate backwards so deleting a node does not shift unvisited indices.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        op = node.op
        name = node.name

        if op == 'Const' or to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in node.attr:
                    del node.attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]
|
|
|
|
|
def writeTextGraph(modelPath, outputPath, outNodes):
    """Write the graph stored in `modelPath` to `outputPath` as text.

    OpenCV's cv.dnn.writeTextGraph() is tried first. If importing cv2 or the
    call itself raises, TensorFlow is used instead: the binary graph is
    loaded, passed through TransformGraph with 'sort_by_execution_order'
    using 'image_tensor' as input and `outNodes` as outputs, the 'value'
    attribute of every 'Const' node is dropped, and the result is written in
    text form.

    Args:
        modelPath: Path of the binary graph file.
        outputPath: Path of the text graph to write.
        outNodes: Output node names; only used by the TensorFlow fallback.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not swallowed by the fallback.
        # NOTE(review): tf.gfile.FastGFile and tf.GraphDef look like
        # TensorFlow 1.x APIs - confirm the supported TensorFlow versions.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

            graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

            for node in graph_def.node:
                if node.op == 'Const':
                    if 'value' in node.attr:
                        del node.attr['value']

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
|
|
|