mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
dnn(caffe): fix net.input_dim handling in Caffe importer
This commit is contained in:
parent
a8c257cecb
commit
7ac7aca33b
@ -75,6 +75,17 @@ static cv::String toString(const T &v)
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
static inline
|
||||
MatShape parseBlobShape(const caffe::BlobShape& _input_shape)
|
||||
{
|
||||
MatShape shape;
|
||||
for (int i = 0; i < _input_shape.dim_size(); i++)
|
||||
{
|
||||
shape.push_back((int)_input_shape.dim(i));
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
class CaffeImporter
|
||||
{
|
||||
caffe::NetParameter net;
|
||||
@ -235,10 +246,7 @@ public:
|
||||
}
|
||||
else if (pbBlob.has_shape())
|
||||
{
|
||||
const caffe::BlobShape &_shape = pbBlob.shape();
|
||||
|
||||
for (int i = 0; i < _shape.dim_size(); i++)
|
||||
shape.push_back((int)_shape.dim(i));
|
||||
shape = parseBlobShape(pbBlob.shape());
|
||||
}
|
||||
else
|
||||
shape.resize(1, 1); // Is a scalar.
|
||||
@ -334,12 +342,49 @@ public:
|
||||
|
||||
//setup input layer names
|
||||
std::vector<String> netInputs(net.input_size());
|
||||
std::vector<MatShape> inp_shapes;
|
||||
{
|
||||
for (int inNum = 0; inNum < net.input_size(); inNum++)
|
||||
int net_input_size = net.input_size();
|
||||
for (int inNum = 0; inNum < net_input_size; inNum++)
|
||||
{
|
||||
addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
|
||||
netInputs[inNum] = net.input(inNum);
|
||||
}
|
||||
|
||||
if (net.input_dim_size() > 0) // deprecated in Caffe proto
|
||||
{
|
||||
int net_input_dim_size = net.input_dim_size();
|
||||
CV_Check(net_input_dim_size, net_input_dim_size % 4 == 0, "");
|
||||
CV_CheckEQ(net_input_dim_size, net_input_size * 4, "");
|
||||
for (int inp_id = 0; inp_id < net_input_size; inp_id++)
|
||||
{
|
||||
int dim = inp_id * 4;
|
||||
MatShape shape(4);
|
||||
shape[0] = net.input_dim(dim);
|
||||
shape[1] = net.input_dim(dim+1);
|
||||
shape[2] = net.input_dim(dim+2);
|
||||
shape[3] = net.input_dim(dim+3);
|
||||
inp_shapes.push_back(shape);
|
||||
}
|
||||
}
|
||||
else if (net.input_shape_size() > 0) // deprecated in Caffe proto
|
||||
{
|
||||
int net_input_shape_size = net.input_shape_size();
|
||||
CV_CheckEQ(net_input_shape_size, net_input_size, "");
|
||||
for (int inp_id = 0; inp_id < net_input_shape_size; inp_id++)
|
||||
{
|
||||
MatShape shape = parseBlobShape(net.input_shape(inp_id));
|
||||
inp_shapes.push_back(shape);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int inp_id = 0; inp_id < net_input_size; inp_id++)
|
||||
{
|
||||
MatShape shape; // empty
|
||||
inp_shapes.push_back(shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int li = 0; li < layersSize; li++)
|
||||
@ -364,6 +409,17 @@ public:
|
||||
addedBlobs.back().outNum = netInputs.size();
|
||||
netInputs.push_back(addedBlobs.back().name);
|
||||
}
|
||||
if (layer.has_input_param())
|
||||
{
|
||||
const caffe::InputParameter &inputParameter = layer.input_param();
|
||||
int input_shape_size = inputParameter.shape_size();
|
||||
CV_CheckEQ(input_shape_size, layer.top_size(), "");
|
||||
for (int inp_id = 0; inp_id < input_shape_size; inp_id++)
|
||||
{
|
||||
MatShape shape = parseBlobShape(inputParameter.shape(inp_id));
|
||||
inp_shapes.push_back(shape);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
else if (type == "BatchNorm")
|
||||
@ -424,35 +480,15 @@ public:
|
||||
}
|
||||
dstNet.setInputsNames(netInputs);
|
||||
|
||||
std::vector<MatShape> inp_shapes;
|
||||
if (net.input_shape_size() > 0 || (layersSize > 0 && net.layer(0).has_input_param() &&
|
||||
net.layer(0).input_param().shape_size() > 0)) {
|
||||
|
||||
int size = (net.input_shape_size() > 0) ? net.input_shape_size() :
|
||||
net.layer(0).input_param().shape_size();
|
||||
for (int inp_id = 0; inp_id < size; inp_id++)
|
||||
if (inp_shapes.size() > 0)
|
||||
{
|
||||
CV_CheckEQ(inp_shapes.size(), netInputs.size(), "");
|
||||
for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++)
|
||||
{
|
||||
const caffe::BlobShape &_input_shape = (net.input_shape_size() > 0) ?
|
||||
net.input_shape(inp_id) :
|
||||
net.layer(0).input_param().shape(inp_id);
|
||||
MatShape shape;
|
||||
for (int i = 0; i < _input_shape.dim_size(); i++) {
|
||||
shape.push_back((int)_input_shape.dim(i));
|
||||
}
|
||||
inp_shapes.push_back(shape);
|
||||
if (!inp_shapes[inp_id].empty())
|
||||
dstNet.setInput(Mat(inp_shapes[inp_id], CV_32F), netInputs[inp_id]);
|
||||
}
|
||||
}
|
||||
else if (net.input_dim_size() > 0) {
|
||||
MatShape shape;
|
||||
for (int dim = 0; dim < net.input_dim_size(); dim++) {
|
||||
shape.push_back(net.input_dim(dim));
|
||||
}
|
||||
inp_shapes.push_back(shape);
|
||||
}
|
||||
|
||||
for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++) {
|
||||
dstNet.setInput(Mat(inp_shapes[inp_id], CV_32F), netInputs[inp_id]);
|
||||
}
|
||||
|
||||
addedBlobs.clear();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user