mirror of
https://github.com/opencv/opencv.git
synced 2024-12-11 22:59:16 +08:00
004a1cd64a
* Add a parameter labels to command line * default value * samples: caffe_googlenet.cpp minor refactoring
182 lines
6.4 KiB
C++
182 lines
6.4 KiB
C++
/**M///////////////////////////////////////////////////////////////////////////////////////
|
|
//
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
//
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
// If you do not agree to this license, do not download, install,
|
|
// copy or use the software.
|
|
//
|
|
//
|
|
// License Agreement
|
|
// For Open Source Computer Vision Library
|
|
//
|
|
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
|
|
// Third party copyrights are property of their respective owners.
|
|
//
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
// are permitted provided that the following conditions are met:
|
|
//
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
// this list of conditions and the following disclaimer.
|
|
//
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
// and/or other materials provided with the distribution.
|
|
//
|
|
// * The name of the copyright holders may not be used to endorse or promote products
|
|
// derived from this software without specific prior written permission.
|
|
//
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
// (including, but not limited to, procurement of substitute goods or services;
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
// or tort (including negligence or otherwise) arising in any way out of
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
//
|
|
//M*/
|
|
#include <opencv2/dnn.hpp>
|
|
#include <opencv2/imgproc.hpp>
|
|
#include <opencv2/highgui.hpp>
|
|
#include <opencv2/core/utils/trace.hpp>
|
|
using namespace cv;
|
|
using namespace cv::dnn;
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <cstdlib>
|
|
using namespace std;
|
|
|
|
/* Find best class for the blob (i. e. class with maximal probability) */
|
|
static void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
|
|
{
|
|
Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix
|
|
Point classNumber;
|
|
|
|
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
|
|
*classId = classNumber.x;
|
|
}
|
|
|
|
static std::vector<String> readClassNames(const char *filename )
|
|
{
|
|
std::vector<String> classNames;
|
|
|
|
std::ifstream fp(filename);
|
|
if (!fp.is_open())
|
|
{
|
|
std::cerr << "File with classes labels not found: " << filename << std::endl;
|
|
exit(-1);
|
|
}
|
|
|
|
std::string name;
|
|
while (!fp.eof())
|
|
{
|
|
std::getline(fp, name);
|
|
if (name.length())
|
|
classNames.push_back( name.substr(name.find(' ')+1) );
|
|
}
|
|
|
|
fp.close();
|
|
return classNames;
|
|
}
|
|
|
|
// Key table handed to cv::CommandLineParser in main().
// Each entry has the form "{ name | default value | help text }".
const char* params =
    // show the usage message and quit
    "{ help | false | Sample app for loading googlenet model }"
    // network description (.prototxt)
    "{ proto | bvlc_googlenet.prototxt | model configuration }"
    // trained weights (.caffemodel)
    "{ model | bvlc_googlenet.caffemodel | model weights }"
    // text file with one class label per line
    "{ label | synset_words.txt | names of ILSVRC2012 classes }"
    // picture to classify
    "{ image | space_shuttle.jpg | path to image file }"
    // request the OpenCL target for inference
    "{ opencl | false | enable OpenCL }";
|
|
|
|
int main(int argc, char **argv)
|
|
{
|
|
CV_TRACE_FUNCTION();
|
|
|
|
CommandLineParser parser(argc, argv, params);
|
|
|
|
if (parser.get<bool>("help"))
|
|
{
|
|
parser.printMessage();
|
|
return 0;
|
|
}
|
|
|
|
String modelTxt = parser.get<string>("proto");
|
|
String modelBin = parser.get<string>("model");
|
|
String imageFile = parser.get<String>("image");
|
|
String classNameFile = parser.get<String>("label");
|
|
|
|
Net net;
|
|
try {
|
|
//! [Read and initialize network]
|
|
net = dnn::readNetFromCaffe(modelTxt, modelBin);
|
|
//! [Read and initialize network]
|
|
}
|
|
catch (const cv::Exception& e) {
|
|
std::cerr << "Exception: " << e.what() << std::endl;
|
|
//! [Check that network was read successfully]
|
|
if (net.empty())
|
|
{
|
|
std::cerr << "Can't load network by using the following files: " << std::endl;
|
|
std::cerr << "prototxt: " << modelTxt << std::endl;
|
|
std::cerr << "caffemodel: " << modelBin << std::endl;
|
|
std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
|
|
std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
|
|
exit(-1);
|
|
}
|
|
//! [Check that network was read successfully]
|
|
}
|
|
|
|
if (parser.get<bool>("opencl"))
|
|
{
|
|
net.setPreferableTarget(DNN_TARGET_OPENCL);
|
|
}
|
|
|
|
//! [Prepare blob]
|
|
Mat img = imread(imageFile);
|
|
if (img.empty())
|
|
{
|
|
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
|
|
exit(-1);
|
|
}
|
|
|
|
//GoogLeNet accepts only 224x224 BGR-images
|
|
Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224),
|
|
Scalar(104, 117, 123), false); //Convert Mat to batch of images
|
|
//! [Prepare blob]
|
|
net.setInput(inputBlob, "data"); //set the network input
|
|
Mat prob = net.forward("prob"); //compute output
|
|
|
|
cv::TickMeter t;
|
|
for (int i = 0; i < 10; i++)
|
|
{
|
|
CV_TRACE_REGION("forward");
|
|
//! [Set input blob]
|
|
net.setInput(inputBlob, "data"); //set the network input
|
|
//! [Set input blob]
|
|
t.start();
|
|
//! [Make forward pass]
|
|
prob = net.forward("prob"); //compute output
|
|
//! [Make forward pass]
|
|
t.stop();
|
|
}
|
|
|
|
//! [Gather output]
|
|
int classId;
|
|
double classProb;
|
|
getMaxClass(prob, &classId, &classProb);//find the best class
|
|
//! [Gather output]
|
|
|
|
//! [Print results]
|
|
std::vector<String> classNames = readClassNames(classNameFile.c_str());
|
|
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
|
|
std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
|
|
//! [Print results]
|
|
std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl;
|
|
|
|
return 0;
|
|
} //main
|