mirror of
https://github.com/opencv/opencv.git
synced 2025-01-19 15:04:01 +08:00
Merge pull request #15956 from lorenzolightsgdwarf:dnn_fix_tf_ssd
This commit is contained in:
commit
a2aa8db5a9
@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
||||
|
||||
# Add layers that generate anchors (bounding boxes proposals).
|
||||
priorBoxes = []
|
||||
boxCoder = config['box_coder'][0]
|
||||
fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
|
||||
boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
|
||||
for i in range(num_layers):
|
||||
priorBox = NodeDef()
|
||||
priorBox.name = 'PriorBox_%d' % i
|
||||
@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
||||
|
||||
priorBox.addAttr('width', widths)
|
||||
priorBox.addAttr('height', heights)
|
||||
priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
|
||||
priorBox.addAttr('variance', boxCoderVariance)
|
||||
|
||||
graph_def.node.extend([priorBox])
|
||||
priorBoxes.append(priorBox.name)
|
||||
@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
||||
detectionOut.addAttr('num_classes', num_classes + 1)
|
||||
detectionOut.addAttr('share_location', True)
|
||||
detectionOut.addAttr('background_label_id', 0)
|
||||
detectionOut.addAttr('nms_threshold', 0.6)
|
||||
detectionOut.addAttr('top_k', 100)
|
||||
|
||||
postProcessing = config['post_processing'][0]
|
||||
batchNMS = postProcessing['batch_non_max_suppression'][0]
|
||||
|
||||
if 'iou_threshold' in batchNMS:
|
||||
detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
|
||||
else:
|
||||
detectionOut.addAttr('nms_threshold', 0.6)
|
||||
|
||||
if 'score_threshold' in batchNMS:
|
||||
detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
|
||||
else:
|
||||
detectionOut.addAttr('confidence_threshold', 0.01)
|
||||
|
||||
if 'max_detections_per_class' in batchNMS:
|
||||
detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
|
||||
else:
|
||||
detectionOut.addAttr('top_k', 100)
|
||||
|
||||
if 'max_total_detections' in batchNMS:
|
||||
detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
|
||||
else:
|
||||
detectionOut.addAttr('keep_top_k', 100)
|
||||
|
||||
detectionOut.addAttr('code_type', "CENTER_SIZE")
|
||||
detectionOut.addAttr('keep_top_k', 100)
|
||||
detectionOut.addAttr('confidence_threshold', 0.01)
|
||||
|
||||
graph_def.node.extend([detectionOut])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user