Merge pull request #9963 from dkurt:fix_caffe_shrinker

This commit is contained in:
Vadim Pisarevsky 2017-10-31 12:27:18 +00:00
commit bc348eb8ab
2 changed files with 18 additions and 3 deletions

View File

@ -726,13 +726,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* @param src Path to origin model from Caffe framework contains single
* precision floating point weights (usually has `.caffemodel` extension).
* @param dst Path to destination model with updated weights.
* @param layersTypes Set of layers types which parameters will be converted.
* By default, converts only Convolutional and Fully-Connected layers'
* weights.
*
* @note Shrinked model has no origin float32 weights so it can't be used
* in origin Caffe framework anymore. However the structure of data
* is taken from NVidia's Caffe fork: https://github.com/NVIDIA/caffe.
* So the resulting model may be used there.
*/
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst);
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
const std::vector<String>& layersTypes = std::vector<String>());
/** @brief Performs non maximum suppression given boxes and corresponding scores.

View File

@ -17,16 +17,27 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
#ifdef HAVE_PROTOBUF
void shrinkCaffeModel(const String& src, const String& dst)
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& layersTypes)
{
CV_TRACE_FUNCTION();
std::vector<String> types(layersTypes);
if (types.empty())
{
types.push_back("Convolution");
types.push_back("InnerProduct");
}
caffe::NetParameter net;
ReadNetParamsFromBinaryFileOrDie(src.c_str(), &net);
for (int i = 0; i < net.layer_size(); ++i)
{
caffe::LayerParameter* lp = net.mutable_layer(i);
if (std::find(types.begin(), types.end(), lp->type()) == types.end())
{
continue;
}
for (int j = 0; j < lp->blobs_size(); ++j)
{
caffe::BlobProto* blob = lp->mutable_blobs(j);
@ -54,7 +65,7 @@ void shrinkCaffeModel(const String& src, const String& dst)
#else
void shrinkCaffeModel(const String& src, const String& dst)
void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& types)
{
CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
}