mirror of
https://github.com/opencv/opencv.git
synced 2025-06-18 08:05:23 +08:00
Determine SSD input shape
This commit is contained in:
parent
b2464e3379
commit
c5a2d28367
@ -234,6 +234,12 @@ def createSSDGraph(modelPath, configPath, outputPath):
|
|||||||
|
|
||||||
# Connect input node to the first layer
|
# Connect input node to the first layer
|
||||||
assert(graph_def.node[0].op == 'Placeholder')
|
assert(graph_def.node[0].op == 'Placeholder')
|
||||||
|
try:
|
||||||
|
input_shape = graph_def.node[0].attr['shape']['shape'][0]['dim']
|
||||||
|
input_shape[1]['size'] = image_height
|
||||||
|
input_shape[2]['size'] = image_width
|
||||||
|
except:
|
||||||
|
print("Input shapes are undefined")
|
||||||
# assert(graph_def.node[1].op == 'Conv2D')
|
# assert(graph_def.node[1].op == 'Conv2D')
|
||||||
weights = graph_def.node[1].input[-1]
|
weights = graph_def.node[1].input[-1]
|
||||||
for i in range(len(graph_def.node[1].input)):
|
for i in range(len(graph_def.node[1].input)):
|
||||||
|
Loading…
Reference in New Issue
Block a user