diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py new file mode 100644 index 0000000000..50fdc2ca3a --- /dev/null +++ b/samples/dnn/tf_text_graph_ssd.py @@ -0,0 +1,255 @@ +# This file is a part of OpenCV project. +# It is a subject to the license terms in the LICENSE file found in the top-level directory +# of this distribution and at http://opencv.org/license.html. +# +# Copyright (C) 2018, Intel Corporation, all rights reserved. +# Third party copyrights are property of their respective owners. +# +# Use this script to get the text graph representation (.pbtxt) of SSD-based +# deep learning network trained in TensorFlow Object Detection API. +# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function. +# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API +import tensorflow as tf +import argparse +from math import sqrt +from tensorflow.core.framework.node_def_pb2 import NodeDef +from google.protobuf import text_format + +parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' + 'SSD model from TensorFlow Object Detection API. ' + 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') +parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.') +parser.add_argument('--output', required=True, help='Path to output text graph.') +parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.') +parser.add_argument('--min_scale', default=0.2, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.') +parser.add_argument('--max_scale', default=0.95, type=float, help='Hyper-parameter of ssd_anchor_generator from config file.') +parser.add_argument('--num_layers', default=6, type=int, help='Hyper-parameter of ssd_anchor_generator from config file.') +parser.add_argument('--aspect_ratios', default=[1.0, 2.0, 0.5, 3.0, 0.333], type=float, nargs='+', + help='Hyper-parameter of ssd_anchor_generator from config file.') +args = parser.parse_args() + +# Nodes that should be kept. +keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm', + 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool'] + +# Nodes attributes that could be removed because they are not used during import. +unusedAttrs = ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu', + 'Index', 'Tperm', 'is_training', 'Tpaddings'] + +# Node with which prefixes should be removed +prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/') + +# Read the graph. +with tf.gfile.FastGFile(args.input, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + +def getUnconnectedNodes(): + unconnected = [] + for node in graph_def.node: + unconnected.append(node.name) + for inp in node.input: + if inp in unconnected: + unconnected.remove(inp) + return unconnected + +removedNodes = [] + +# Detect unfused batch normalization nodes and fuse them. +def fuse_batch_normalization(): + pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add'] + candidates = [] + + for node in graph_def.node: + if node.op == pattern[len(candidates)]: + candidates.append(node) + else: + candidates = [] + + if len(candidates) == len(pattern): + inp = candidates[3].input[0] + gamma = candidates[2].input[1] + beta = candidates[5].input[0] + moving_mean = candidates[4].input[0] + moving_variance = candidates[0].input[0] + + name = node.name + node.Clear() + node.name = name + node.op = 'FusedBatchNorm' + node.input.append(inp) + node.input.append(gamma) + node.input.append(beta) + node.input.append(moving_mean) + node.input.append(moving_variance) + text_format.Merge('f: 0.001', node.attr["epsilon"]) + + for candidate in candidates[:-1]: + graph_def.node.remove(candidate) + candidates = [] + +fuse_batch_normalization() + +# Removes Identity nodes +def removeIdentity(): + identities = {} + for node in graph_def.node: + if node.op == 'Identity': + identities[node.name] = node.input[0] + + for node in graph_def.node: + for i in range(len(node.input)): + if node.input[i] in identities: + node.input[i] = identities[node.input[i]] + +removeIdentity() + +# Remove extra nodes and attributes. +for i in reversed(range(len(graph_def.node))): + op = graph_def.node[i].op + name = graph_def.node[i].name + + if (not op in keepOps) or name.startswith(prefixesToRemove): + if op != 'Const': + removedNodes.append(name) + + del graph_def.node[i] + else: + for attr in unusedAttrs: + if attr in graph_def.node[i].attr: + del graph_def.node[i].attr[attr] + +# Remove references to removed nodes except Const nodes. +for node in graph_def.node: + for i in reversed(range(len(node.input))): + if node.input[i] in removedNodes: + del node.input[i] + +# Connect input node to the first layer +assert(graph_def.node[0].op == 'Placeholder') +# assert(graph_def.node[1].op == 'Conv2D') +weights = graph_def.node[1].input[0] +for i in range(len(graph_def.node[1].input)): + graph_def.node[1].input.pop() +graph_def.node[1].input.append(graph_def.node[0].name) +graph_def.node[1].input.append(weights) + +# Create SSD postprocessing head ############################################### + +# Concatenate predictions of classes, predictions of bounding boxes and proposals. + +concatAxis = NodeDef() +concatAxis.name = 'concat/axis_flatten' +concatAxis.op = 'Const' +text_format.Merge( +'tensor {' +' dtype: DT_INT32' +' tensor_shape { }' +' int_val: -1' +'}', concatAxis.attr["value"]) +graph_def.node.extend([concatAxis]) + +def addConcatNode(name, inputs): + concat = NodeDef() + concat.name = name + concat.op = 'ConcatV2' + for inp in inputs: + concat.input.append(inp) + concat.input.append(concatAxis.name) + graph_def.node.extend([concat]) + +for label in ['ClassPredictor', 'BoxEncodingPredictor']: + concatInputs = [] + for i in range(args.num_layers): + # Flatten predictions + flatten = NodeDef() + inpName = 'BoxPredictor_%d/%s/BiasAdd' % (i, label) + flatten.input.append(inpName) + flatten.name = inpName + '/Flatten' + flatten.op = 'Flatten' + + concatInputs.append(flatten.name) + graph_def.node.extend([flatten]) + addConcatNode('%s/concat' % label, concatInputs) + +# Add layers that generate anchors (bounding boxes proposals). +scales = [args.min_scale + (args.max_scale - args.min_scale) * i / (args.num_layers - 1) + for i in range(args.num_layers)] + [1.0] + +def tensorMsg(values): + msg = 'tensor { dtype: DT_FLOAT tensor_shape { dim { size: %d } }' % len(values) + for value in values: + msg += 'float_val: %f ' % value + return msg + '}' + +priorBoxes = [] +for i in range(args.num_layers): + priorBox = NodeDef() + priorBox.name = 'PriorBox_%d' % i + priorBox.op = 'PriorBox' + priorBox.input.append('BoxPredictor_%d/BoxEncodingPredictor/BiasAdd' % i) + priorBox.input.append(graph_def.node[0].name) # image_tensor + + text_format.Merge('b: false', priorBox.attr["flip"]) + text_format.Merge('b: false', priorBox.attr["clip"]) + text_format.Merge('b: true', priorBox.attr["normalized_bbox"]) + + if i == 0: + widths = [args.min_scale * 0.5, args.min_scale * sqrt(2.0), args.min_scale * sqrt(0.5)] + heights = [args.min_scale * 0.5, args.min_scale / sqrt(2.0), args.min_scale / sqrt(0.5)] + else: + widths = [scales[i] * sqrt(ar) for ar in args.aspect_ratios] + heights = [scales[i] / sqrt(ar) for ar in args.aspect_ratios] + + widths += [sqrt(scales[i] * scales[i + 1])] + heights += [sqrt(scales[i] * scales[i + 1])] + text_format.Merge(tensorMsg(widths), priorBox.attr["width"]) + text_format.Merge(tensorMsg(heights), priorBox.attr["height"]) + text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), priorBox.attr["variance"]) + + graph_def.node.extend([priorBox]) + priorBoxes.append(priorBox.name) + +addConcatNode('PriorBox/concat', priorBoxes) + +# Sigmoid for classes predictions and DetectionOutput layer +sigmoid = NodeDef() +sigmoid.name = 'ClassPredictor/concat/sigmoid' +sigmoid.op = 'Sigmoid' +sigmoid.input.append('ClassPredictor/concat') +graph_def.node.extend([sigmoid]) + +detectionOut = NodeDef() +detectionOut.name = 'detection_out' +detectionOut.op = 'DetectionOutput' + +detectionOut.input.append('BoxEncodingPredictor/concat') +detectionOut.input.append(sigmoid.name) +detectionOut.input.append('PriorBox/concat') + +text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['num_classes']) +text_format.Merge('b: true', detectionOut.attr['share_location']) +text_format.Merge('i: 0', detectionOut.attr['background_label_id']) +text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold']) +text_format.Merge('i: 100', detectionOut.attr['top_k']) +text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type']) +text_format.Merge('i: 100', detectionOut.attr['keep_top_k']) +text_format.Merge('f: 0.01', detectionOut.attr['confidence_threshold']) +text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed']) + +graph_def.node.extend([detectionOut]) + +while True: + unconnectedNodes = getUnconnectedNodes() + unconnectedNodes.remove(detectionOut.name) + if not unconnectedNodes: + break + + for name in unconnectedNodes: + for i in range(len(graph_def.node)): + if graph_def.node[i].name == name: + del graph_def.node[i] + break + +# Save as text. +tf.train.write_graph(graph_def, "", args.output, as_text=True)