Merge pull request #15956 from lorenzolightsgdwarf:dnn_fix_tf_ssd

This commit is contained in:
Alexander Alekhin 2019-11-20 15:43:32 +00:00
commit a2aa8db5a9

View File

@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
# Add layers that generate anchors (bounding boxes proposals).
priorBoxes = []
boxCoder = config['box_coder'][0]
fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
for i in range(num_layers):
priorBox = NodeDef()
priorBox.name = 'PriorBox_%d' % i
@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
priorBox.addAttr('width', widths)
priorBox.addAttr('height', heights)
priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
priorBox.addAttr('variance', boxCoderVariance)
graph_def.node.extend([priorBox])
priorBoxes.append(priorBox.name)
@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
detectionOut.addAttr('num_classes', num_classes + 1)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('top_k', 100)
postProcessing = config['post_processing'][0]
batchNMS = postProcessing['batch_non_max_suppression'][0]
if 'iou_threshold' in batchNMS:
detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
else:
detectionOut.addAttr('nms_threshold', 0.6)
if 'score_threshold' in batchNMS:
detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
else:
detectionOut.addAttr('confidence_threshold', 0.01)
if 'max_detections_per_class' in batchNMS:
detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
else:
detectionOut.addAttr('top_k', 100)
if 'max_total_detections' in batchNMS:
detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
else:
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('confidence_threshold', 0.01)
graph_def.node.extend([detectionOut])