mirror of
https://github.com/opencv/opencv.git
synced 2024-12-01 14:59:54 +08:00
26 lines
802 B
Python
26 lines
802 B
Python
import tensorflow as tf
|
|
from tensorflow.core.framework.node_def_pb2 import NodeDef
|
|
from google.protobuf import text_format
|
|
|
|
def tensorMsg(values):
    """Build the text-format body of a 1-D tensor attribute.

    The returned string has the form
    ``tensor { dtype: <DT> tensor_shape { dim { size: <N> } }<field>: v ... }``
    where ``<N>`` is ``len(values)`` and every element is emitted as its own
    ``float_val`` / ``int_val`` entry.

    Args:
        values: A sized sequence whose elements are either all ``float``
            (emitted as ``DT_FLOAT``) or all ``int`` (emitted as
            ``DT_INT32``). An empty sequence is treated as ``DT_FLOAT``.

    Returns:
        The text-format message as a ``str``.

    Raises:
        Exception: If the elements are of mixed or unsupported types.
    """
    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    # bool is a subclass of int, but str(True) is 'True', which is not an
    # integer literal; reject it here instead of emitting 'int_val: True'.
    elif all(isinstance(v, int) and not isinstance(v, bool) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    header = 'tensor { dtype: %s tensor_shape { dim { size: %d } }' % (dtype, len(values))
    # Each entry keeps its trailing space so the output matches the
    # historical format exactly.
    entries = ''.join('%s: %s ' % (field, str(value)) for value in values)
    return header + entries + '}'
|
|
|
|
def addConstNode(name, values, graph_def):
    """Append a ``Const`` node holding ``values`` to ``graph_def``.

    Args:
        name: Value assigned to the new node's ``name`` field.
        values: Sequence passed to ``tensorMsg`` to produce the text that is
            merged into the node's ``"value"`` attribute.
        graph_def: Object whose ``node`` field is extended with the new node.
    """
    const_node = NodeDef()
    const_node.name, const_node.op = name, 'Const'

    # Parse the text-format tensor description into the "value" attribute.
    value_attr = const_node.attr["value"]
    text_format.Merge(tensorMsg(values), value_attr)

    graph_def.node.extend([const_node])
|