2018-03-03 21:43:21 +08:00
# include <fstream>
# include <sstream>
2021-11-24 05:15:31 +08:00
# include <iostream>
2018-03-03 21:43:21 +08:00
# include <opencv2/dnn.hpp>
2018-03-04 00:29:37 +08:00
# include <opencv2/imgproc.hpp>
# include <opencv2/highgui.hpp>
2018-03-03 21:43:21 +08:00
2018-09-20 22:59:04 +08:00
# include "common.hpp"
2018-03-03 21:43:21 +08:00
using namespace cv ;
2024-08-06 14:16:11 +08:00
using namespace std ;
2018-03-03 21:43:21 +08:00
using namespace dnn ;
2024-08-06 14:16:11 +08:00
// Help text printed when the sample is run without a model alias or with --help.
const std::string about =
        "Use this script to run a classification model on a camera stream, video, image or image list (i.e. .xml or .yaml containing image lists)\n\n"
        "Firstly, download required models using `download_models.py` (if not already done). Set environment variable OPENCV_DOWNLOAD_CACHE_DIR to specify where models should be downloaded. Also, point OPENCV_SAMPLES_DATA_PATH to opencv/samples/data.\n"
        "To run:\n"
        "\t./example_dnn_classification model_name --input=path/to/your/input/image/or/video (omit --input to use the device camera)\n"
        "Sample command:\n"
        "\t./example_dnn_classification resnet --input=$OPENCV_SAMPLES_DATA_PATH/baboon.jpg\n"
        "\t./example_dnn_classification squeezenet\n"
        // Trailing "\n" added so the next sentence starts on its own line.
        "Model path can also be specified using --model argument.\n"
        "Use imagelist_creator to create the xml or yaml list\n";

// CommandLineParser keys that are independent of the chosen model.
const std::string param_keys =
    "{ help  h     |                   | Print help message. }"
    "{ @alias      |                   | An alias name of model to extract preprocessing parameters from models.yml file. }"
    "{ zoo         | ../dnn/models.yml | An optional path to file with preprocessing parameters }"
    "{ input i     |                   | Path to input image or video file. Skip this argument to capture frames from a camera.}"
    "{ imglist     |                   | Pass this flag if image list (i.e. .xml or .yaml) file is passed}"
    "{ crop        |       false       | Preprocess input image by center cropping.}"
    //"{ labels      |                   | Path to the text file with labels for detected objects.}"
    "{ model       |                   | Path to the model file.}";

// Plain literals: these contain no format specifiers, so wrapping them in
// cv::format() added nothing and would break if a '%' ever appeared here.
const std::string backend_keys =
    "{ backend | default | Choose one of computation backends: "
    "default: automatically (by default), "
    "openvino: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
    "opencv: OpenCV implementation, "
    "vkcom: VKCOM, "
    "cuda: CUDA, "
    "webnn: WebNN }";

const std::string target_keys =
    "{ target | cpu | Choose one of target computation devices: "
    "cpu: CPU target (by default), "
    "opencl: OpenCL, "
    "opencl_fp16: OpenCL fp16 (half-float precision), "
    "vpu: VPU, "
    "vulkan: Vulkan, "
    "cuda: CUDA, "
    "cuda_fp16: CUDA fp16 (half-float precision) }";

// Full key set; main() appends model-specific preprocessing keys to it at runtime.
std::string keys = param_keys + backend_keys + target_keys;

// Class names loaded from the labels file (one per line); empty if none were loaded.
std::vector<std::string> classes;
static bool readStringList ( const string & filename , vector < string > & l )
{
l . resize ( 0 ) ;
FileStorage fs ( filename , FileStorage : : READ ) ;
if ( ! fs . isOpened ( ) )
return false ;
size_t dir_pos = filename . rfind ( ' / ' ) ;
if ( dir_pos = = string : : npos )
dir_pos = filename . rfind ( ' \\ ' ) ;
FileNode n = fs . getFirstTopLevelNode ( ) ;
if ( n . type ( ) ! = FileNode : : SEQ )
return false ;
FileNodeIterator it = n . begin ( ) , it_end = n . end ( ) ;
for ( ; it ! = it_end ; + + it )
{
string fname = ( string ) * it ;
if ( dir_pos ! = string : : npos )
{
string fpath = samples : : findFile ( filename . substr ( 0 , dir_pos + 1 ) + fname , false ) ;
if ( fpath . empty ( ) )
{
fpath = samples : : findFile ( fname ) ;
}
fname = fpath ;
}
else
{
fname = samples : : findFile ( fname ) ;
}
l . push_back ( fname ) ;
}
return true ;
}
2018-03-03 21:43:21 +08:00
int main ( int argc , char * * argv )
{
CommandLineParser parser ( argc , argv , keys ) ;
2018-09-20 22:59:04 +08:00
2024-08-06 14:16:11 +08:00
if ( ! parser . has ( " @alias " ) | | parser . has ( " help " ) )
{
cout < < about < < endl ;
parser . printMessage ( ) ;
return - 1 ;
}
const string modelName = parser . get < String > ( " @alias " ) ;
const string zooFile = findFile ( parser . get < String > ( " zoo " ) ) ;
2018-09-20 22:59:04 +08:00
keys + = genPreprocArguments ( modelName , zooFile ) ;
parser = CommandLineParser ( argc , argv , keys ) ;
2024-08-06 14:16:11 +08:00
parser . about ( about ) ;
2018-03-03 21:43:21 +08:00
if ( argc = = 1 | | parser . has ( " help " ) )
{
parser . printMessage ( ) ;
return 0 ;
}
2024-08-06 14:16:11 +08:00
String sha1 = parser . get < String > ( " sha1 " ) ;
2018-03-03 21:43:21 +08:00
float scale = parser . get < float > ( " scale " ) ;
2018-03-07 00:29:23 +08:00
Scalar mean = parser . get < Scalar > ( " mean " ) ;
2021-01-26 19:06:15 +08:00
Scalar std = parser . get < Scalar > ( " std " ) ;
2018-03-03 21:43:21 +08:00
bool swapRB = parser . get < bool > ( " rgb " ) ;
2021-01-26 19:06:15 +08:00
bool crop = parser . get < bool > ( " crop " ) ;
2018-03-03 21:43:21 +08:00
int inpWidth = parser . get < int > ( " width " ) ;
int inpHeight = parser . get < int > ( " height " ) ;
2024-08-06 14:16:11 +08:00
String model = findModel ( parser . get < String > ( " model " ) , sha1 ) ;
String backend = parser . get < String > ( " backend " ) ;
String target = parser . get < String > ( " target " ) ;
bool isImgList = parser . has ( " imglist " ) ;
// Open file with labels.
string labels_filename = parser . get < String > ( " labels " ) ;
string file = findFile ( labels_filename ) ;
ifstream ifs ( file . c_str ( ) ) ;
if ( ! ifs . is_open ( ) ) {
cout < < " File " < < file < < " not found " ;
exit ( 1 ) ;
}
string line ;
while ( getline ( ifs , line ) )
2018-03-03 21:43:21 +08:00
{
2024-08-06 14:16:11 +08:00
classes . push_back ( line ) ;
2018-03-03 21:43:21 +08:00
}
2018-08-15 19:55:47 +08:00
if ( ! parser . check ( ) )
{
parser . printErrors ( ) ;
return 1 ;
}
CV_Assert ( ! model . empty ( ) ) ;
2018-03-04 00:29:37 +08:00
//! [Read and initialize network]
2024-08-06 14:16:11 +08:00
Net net = readNetFromONNX ( model ) ;
net . setPreferableBackend ( getBackendID ( backend ) ) ;
net . setPreferableTarget ( getTargetID ( target ) ) ;
2018-03-04 00:29:37 +08:00
//! [Read and initialize network]
2018-03-03 21:43:21 +08:00
// Create a window
static const std : : string kWinName = " Deep learning image classification in OpenCV " ;
namedWindow ( kWinName , WINDOW_NORMAL ) ;
2024-08-06 14:16:11 +08:00
//Create FontFace for putText
FontFace sans ( " sans " ) ;
2018-03-04 00:29:37 +08:00
//! [Open a video file or an image file or a camera stream]
2018-03-03 21:43:21 +08:00
VideoCapture cap ;
2024-08-06 14:16:11 +08:00
vector < string > imageList ;
size_t currentImageIndex = 0 ;
if ( parser . has ( " input " ) ) {
string input = findFile ( parser . get < String > ( " input " ) ) ;
if ( isImgList ) {
bool check = readStringList ( samples : : findFile ( input ) , imageList ) ;
if ( imageList . empty ( ) | | ! check ) {
cout < < " Error: No images found or the provided file is not a valid .yaml or .xml file. " < < endl ;
return - 1 ;
}
} else {
// Input is not a directory, try to open as video or image
cap . open ( input ) ;
if ( ! cap . isOpened ( ) ) {
cout < < " Failed to open the input. " < < endl ;
return - 1 ;
}
}
} else {
cap . open ( 0 ) ; // Open default camera
}
2018-03-04 00:29:37 +08:00
//! [Open a video file or an image file or a camera stream]
2018-03-03 21:43:21 +08:00
Mat frame , blob ;
2024-08-06 14:16:11 +08:00
for ( ; ; )
2018-03-03 21:43:21 +08:00
{
2024-08-06 14:16:11 +08:00
if ( ! imageList . empty ( ) ) {
// Handling directory of images
if ( currentImageIndex > = imageList . size ( ) ) {
waitKey ( ) ;
break ; // Exit if all images are processed
}
frame = imread ( imageList [ currentImageIndex + + ] ) ;
if ( frame . empty ( ) ) {
cout < < " Cannot open file " < < endl ;
continue ;
}
} else {
// Handling video or single image
cap > > frame ;
}
2018-03-03 21:43:21 +08:00
if ( frame . empty ( ) )
{
break ;
}
2018-03-04 00:29:37 +08:00
//! [Create a 4D blob from a frame]
2021-01-26 19:06:15 +08:00
blobFromImage ( frame , blob , scale , Size ( inpWidth , inpHeight ) , mean , swapRB , crop ) ;
// Check std values.
if ( std . val [ 0 ] ! = 0.0 & & std . val [ 1 ] ! = 0.0 & & std . val [ 2 ] ! = 0.0 )
{
// Divide blob by std.
divide ( blob , std , blob ) ;
}
2018-03-04 00:29:37 +08:00
//! [Create a 4D blob from a frame]
//! [Set input blob]
2018-03-03 21:43:21 +08:00
net . setInput ( blob ) ;
2018-03-04 00:29:37 +08:00
//! [Set input blob]
2024-08-06 14:16:11 +08:00
TickMeter timeRecorder ;
2021-11-24 05:15:31 +08:00
timeRecorder . reset ( ) ;
Mat prob = net . forward ( ) ;
double t1 ;
2024-08-06 14:16:11 +08:00
//! [Make forward pass]
2021-11-24 05:15:31 +08:00
timeRecorder . start ( ) ;
prob = net . forward ( ) ;
timeRecorder . stop ( ) ;
2024-08-06 14:16:11 +08:00
//! [Make forward pass]
2021-11-24 05:15:31 +08:00
2024-08-06 14:16:11 +08:00
//! [Get a class with a highest score]
int N = ( int ) prob . total ( ) , K = std : : min ( 5 , N ) ;
std : : vector < std : : pair < float , int > > prob_vec ;
for ( int i = 0 ; i < N ; i + + ) {
prob_vec . push_back ( std : : make_pair ( - prob . at < float > ( i ) , i ) ) ;
2021-11-24 05:15:31 +08:00
}
2024-08-06 14:16:11 +08:00
std : : sort ( prob_vec . begin ( ) , prob_vec . end ( ) ) ;
2018-03-03 21:43:21 +08:00
2024-08-06 14:16:11 +08:00
//! [Get a class with a highest score]
t1 = timeRecorder . getTimeMilli ( ) ;
timeRecorder . reset ( ) ;
string label = format ( " Inference time: %.1f ms " , t1 ) ;
Mat subframe = frame ( Rect ( 0 , 0 , std : : min ( 1000 , frame . cols ) , std : : min ( 300 , frame . rows ) ) ) ;
subframe * = 0.3f ;
putText ( frame , label , Point ( 20 , 50 ) , Scalar ( 0 , 255 , 0 ) , sans , 25 , 800 ) ;
2018-03-03 21:43:21 +08:00
2024-08-06 14:16:11 +08:00
// Print predicted class.
for ( int i = 0 ; i < K ; i + + ) {
int classId = prob_vec [ i ] . second ;
float confidence = - prob_vec [ i ] . first ;
label = format ( " %d. %s: %.2f " , i + 1 , ( classes . empty ( ) ? format ( " Class #%d " , classId ) . c_str ( ) :
classes [ classId ] . c_str ( ) ) , confidence ) ;
putText ( frame , label , Point ( 20 , 110 + i * 35 ) , Scalar ( 0 , 255 , 0 ) , sans , 25 , 500 ) ;
}
2018-03-03 21:43:21 +08:00
imshow ( kWinName , frame ) ;
2024-08-06 14:16:11 +08:00
int key = waitKey ( isImgList ? 1000 : 100 ) ;
if ( key = = ' ' )
key = waitKey ( ) ;
if ( key = = ' q ' | | key = = 27 ) // Check if 'q' or 'ESC' is pressed
return 0 ;
2018-03-03 21:43:21 +08:00
}
2024-08-06 14:16:11 +08:00
waitKey ( ) ;
2018-03-03 21:43:21 +08:00
return 0 ;
}