mirror of https://github.com/opencv/opencv.git
Open Source Computer Vision Library
https://opencv.org/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
25 lines
802 B
25 lines
802 B
import tensorflow as tf |
|
from tensorflow.core.framework.node_def_pb2 import NodeDef |
|
from google.protobuf import text_format |
|
|
|
def tensorMsg(values):
    """Build the protobuf text-format string describing a 1-D tensor.

    The returned string has the form::

        tensor { dtype: <DTYPE> tensor_shape { dim { size: <N> } }<field>: <v> ... }

    where ``<N>`` is ``len(values)`` and each element is written as
    ``float_val: <v>`` or ``int_val: <v>``.

    Args:
        values: Sequence whose elements are either all ``float`` or all
            ``int``. An empty sequence is accepted and yields a ``DT_FLOAT``
            tensor of size 0, because the float check is vacuously true.

    Returns:
        The text-format message as a ``str``.

    Raises:
        Exception: With message ``'Wrong values types'`` if the elements are
            neither all floats nor all ints, or if any element is a ``bool``.
    """
    # bool is a subclass of int, so it would pass the int check below and be
    # serialized as 'int_val: True', which is not an integer literal in the
    # text format. Reject it here with the same error as other bad types.
    if any(isinstance(v, bool) for v in values):
        raise Exception('Wrong values types')

    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    header = 'tensor { dtype: %s tensor_shape { dim { size: %d } }' % (dtype, len(values))
    # Each entry keeps its trailing space so the output matches the
    # historical format exactly.
    entries = ''.join('%s: %s ' % (field, str(value)) for value in values)
    return header + entries + '}'
|
|
|
def addConstNode(name, values, graph_def):
    """Append a ``Const`` node holding ``values`` to ``graph_def``.

    A new ``NodeDef`` is created with the given ``name`` and op ``'Const'``.
    Its ``"value"`` attribute is filled by merging the text-format message
    produced by ``tensorMsg(values)``, and the node is then added to the end
    of ``graph_def.node``.

    Args:
        name: Name assigned to the new node.
        values: Values passed to ``tensorMsg`` to build the tensor message.
        graph_def: Graph definition whose ``node`` list is extended in place.
    """
    const_node = NodeDef(name=name, op='Const')
    tensor_text = tensorMsg(values)
    text_format.Merge(tensor_text, const_node.attr["value"])
    graph_def.node.extend([const_node])
|
|
|