def tokenize(s):
    """Split protobuf text-format content into a flat list of tokens.

    Whitespace, ':', ';' and ',' separate tokens and are dropped. Each of
    '{', '}', '[' and ']' becomes its own token. Text between quotes is kept
    verbatim (spaces included) as one token, so an empty literal yields "".
    Outside a quoted string, '#' starts a comment running to end of line.

    NOTE(review): escape sequences are not interpreted, a quote of either
    kind ('"' or "'") ends the current string, and brackets are not
    protected inside quotes (they still split the token).
    """
    separators = frozenset(' \t\r\n:;,\'"')
    brackets = frozenset('{}[]')
    quotes = frozenset('\'"')

    tokens = []
    token = ""
    isString = False
    isComment = False
    for symbol in s:
        # A comment lasts until the next newline; '#' inside a string is literal.
        isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
        if isComment:
            continue

        if symbol in separators:
            isQuote = symbol in quotes
            if isQuote and isString:
                # Closing quote: emit the string even when it is empty.
                tokens.append(token)
                token = ""
            elif isString:
                token += symbol
            elif token:
                tokens.append(token)
                token = ""
            # An opening quote enters string mode, a closing one leaves it.
            isString = isQuote ^ isString
        elif symbol in brackets:
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol

    if token:
        tokens.append(token)
    return tokens


def parseMessage(tokens, idx):
    """Parse one '{ ... }' message from tokens, starting at index idx.

    Every value in the returned dict is a list, because fields may repeat:
    scalars stay as token strings and nested messages become dicts. A
    bracketed list ``field: [a, b]`` is flattened into the same list.

    Returns:
        (msg, idx) where idx is the index of the closing '}', or None when
        the tokens run out where a field name is expected.

    NOTE(review): if a nested message is truncated, the recursive call
    returns None and unpacking it raises TypeError.
    """
    msg = {}
    assert tokens[idx] == '{'
    isArray = False
    while True:
        # Inside a '[...]' list every value belongs to the same field name.
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        fieldValue = tokens[idx]
        if fieldValue == '{':
            embeddedMsg, idx = parseMessage(tokens, idx)
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx


def readTextMessage(filePath):
    """Read a protobuf text-format file into nested dicts of lists.

    The content is wrapped in braces so the top level parses as a single
    message. Returns {} if parsing runs out of tokens.
    """
    with open(filePath, 'rt') as f:
        content = f.read()
    tokens = tokenize('{' + content + '}')
    msg = parseMessage(tokens, 0)
    return msg[0] if msg else {}


def listToTensor(values):
    """Wrap a list of numbers in a dict laid out like a text-format 'tensor' value.

    All-float lists become DT_FLOAT / float_val; all-int lists become
    DT_INT32 / int_val. An empty list counts as all-float. bool values pass
    the int check, since bool is a subclass of int. The list is stored by
    reference, not copied.

    Raises:
        Exception: if values mixes types or contains non-numbers.
    """
    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    return {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {'dim': {'size': len(values)}},
            field: values,
        }
    }


def addConstNode(name, values, graph_def):
    """Append a 'Const' node called name whose 'value' attr is built from values."""
    node = NodeDef()
    node.name = name
    node.op = 'Const'
    node.addAttr('value', values)
    graph_def.node.extend([node])


def addSlice(inp, out, begins, sizes, graph_def):
    """Append a 'Slice' node named out, preceded by Const nodes for begins and sizes.

    The Const nodes are named out + '/begins' and out + '/sizes'.
    """
    beginsName = out + '/begins'
    sizesName = out + '/sizes'
    addConstNode(beginsName, begins, graph_def)
    addConstNode(sizesName, sizes, graph_def)

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.extend([inp, beginsName, sizesName])
    graph_def.node.extend([sliced])


def addReshape(inp, out, shape, graph_def):
    """Append a 'Reshape' node named out, preceded by a Const node out + '/shape'."""
    shapeName = out + '/shape'
    addConstNode(shapeName, shape, graph_def)

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.extend([inp, shapeName])
    graph_def.node.extend([reshape])


def addSoftMax(inp, out, graph_def):
    """Append a 'Softmax' node named out over inp, with attr axis = -1."""
    softmax = NodeDef()
    softmax.name = out
    softmax.op = 'Softmax'
    softmax.addAttr('axis', -1)
    softmax.input.append(inp)
    graph_def.node.extend([softmax])


def addFlatten(inp, out, graph_def):
    """Append a 'Flatten' node named out over inp."""
    flatten = NodeDef()
    flatten.name = out
    flatten.op = 'Flatten'
    flatten.input.append(inp)
    graph_def.node.extend([flatten])


class NodeDef:
    """Plain-Python node record with NodeDef-style fields: name, op, input, attr."""

    def __init__(self):
        self.Clear()

    def addAttr(self, key, value):
        """Store value under attr key, wrapped by type for text-format output.

        bool -> {'b': v}, int -> {'i': v}, float -> {'f': v}, str -> {'s': v},
        list -> tensor dict from listToTensor.

        Raises:
            AssertionError: if key is already set.
            Exception: for any other value type.
        """
        assert key not in self.attr
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            self.attr[key] = {'b': value}
        elif isinstance(value, int):
            self.attr[key] = {'i': value}
        elif isinstance(value, float):
            self.attr[key] = {'f': value}
        elif isinstance(value, str):
            self.attr[key] = {'s': value}
        elif isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset to an empty name/op with no inputs and no attrs."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}


def _formatAttrScalar(v):
    """Render a single non-message attr value as GraphDef.save writes it."""
    if isinstance(v, bool):
        return 'true' if v else 'false'
    if v == 'true' or v == 'false':
        return v
    # Strings are quoted unless they look like a dtype enum or a number.
    if isinstance(v, str) and not v.startswith('DT_'):
        try:
            float(v)
        except ValueError:
            return '\"%s\"' % v
    return str(v)


def _writeAttr(f, d, indent):
    """Write dict d to f as text-format fields, indented by indent spaces.

    Keys are sorted case-insensitively; list values repeat the key; dict
    values become nested '{ ... }' blocks indented two more spaces.
    """
    pad = ' ' * indent
    for key, value in sorted(d.items(), key=lambda item: item[0].lower()):
        for v in (value if isinstance(value, list) else [value]):
            if isinstance(v, dict):
                f.write(pad + key + ' {\n')
                _writeAttr(f, v, indent + 2)
                f.write(pad + '}\n')
            else:
                f.write(pad + key + ': ' + _formatAttrScalar(v) + '\n')


class GraphDef:
    """Ordered list of NodeDef objects that can be saved as a text graph."""

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write every node to filePath in protobuf text format.

        Attrs are emitted sorted by key (case-insensitive).
        """
        with open(filePath, 'wt') as f:
            for node in self.node:
                f.write('node {\n')
                f.write(' name: \"%s\"\n' % node.name)
                f.write(' op: \"%s\"\n' % node.op)
                for inp in node.input:
                    f.write(' input: \"%s\"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write(' attr {\n')
                    f.write(' key: \"%s\"\n' % key)
                    f.write(' value {\n')
                    _writeAttr(f, value, 6)
                    f.write(' }\n')
                    f.write(' }\n')
                f.write('}\n')


def parseTextGraph(filePath):
    """Load a text-format graph file into a GraphDef of NodeDef objects.

    Attr values are kept exactly as readTextMessage produced them (dicts
    whose values are lists of token strings); they are not re-typed.
    """
    msg = readTextMessage(filePath)
    graph = GraphDef()
    for node in msg['node']:
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])
        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]
        graph.node.append(graphNode)
    return graph


def removeIdentity(graph_def):
    """Remove Identity nodes and rewire their consumers to the Identity's first input.

    Chains of Identity nodes are followed, so every consumer ends up pointing
    at the first non-Identity producer.
    """
    identities = {}
    # Iterate over a snapshot: removing from the list being iterated would
    # skip the element right after each removed node.
    for node in list(graph_def.node):
        if node.op == 'Identity':
            identities[node.name] = node.input[0]
            graph_def.node.remove(node)

    def resolve(name):
        # Follow Identity -> Identity links; `seen` guards against cycles.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)


def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Drop Const nodes and nodes selected by to_remove, and strip bookkeeping attrs.

    to_remove(name, op) is called for each non-Const node, last to first.
    Surviving nodes lose the attrs listed in unusedAttrs. Inputs that refer
    to removed non-Const nodes are deleted; inputs naming Const nodes are kept.
    """
    unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu', 'Index',
                   'Tperm', 'is_training', 'Tpaddings']

    removedNodes = set()
    # Walk backwards so deleting by index does not shift unvisited nodes.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        op = node.op
        name = node.name
        if op == 'Const' or to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)
            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in node.attr:
                    del node.attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]


def writeTextGraph(modelPath, outputPath, outNodes):
    """Write a text graph for the model at modelPath to outputPath.

    Tries OpenCV's cv.dnn.writeTextGraph first. If cv2 cannot be imported or
    that call fails, falls back to TensorFlow: the binary GraphDef is read,
    passed through TransformGraph with inputs ['image_tensor'], outputs
    outNodes and the 'sort_by_execution_order' transform, every Const node's
    'value' attr is deleted, and the result is written as text.
    """
    try:
        import cv2 as cv
        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # NOTE(review): tf.gfile, tf.GraphDef and tensorflow.tools.graph_transforms
        # are TensorFlow 1.x APIs; this fallback needs a TF1-compatible install.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes,
                                   ['sort_by_execution_order'])

        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr:
                    del node.attr['value']

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)