2018-03-07 00:29:23 +08:00
# include <fstream>
# include <sstream>
2024-05-15 14:39:34 +08:00
# include <iostream>
2018-03-07 00:29:23 +08:00
# include <opencv2/dnn.hpp>
# include <opencv2/imgproc.hpp>
# include <opencv2/highgui.hpp>
2018-09-20 22:59:04 +08:00
# include "common.hpp"
2018-03-07 00:29:23 +08:00
using namespace cv ;
2024-05-15 14:39:34 +08:00
using namespace std ;
2018-03-07 00:29:23 +08:00
using namespace dnn ;
2024-11-08 13:55:46 +08:00
// Description printed together with the usage message (installed in main()
// through CommandLineParser::about()).
const std::string about =
    "Use this script to run semantic segmentation deep learning networks using OpenCV.\n\n"
    "Firstly, download required models using `download_models.py` (if not already done). Set environment variable OPENCV_DOWNLOAD_CACHE_DIR to specify where models should be downloaded. Also, point OPENCV_SAMPLES_DATA_PATH to opencv/samples/data.\n"
    "To run:\n"
    "\t./example_dnn_segmentation modelName(e.g. u2netp) --input=$OPENCV_SAMPLES_DATA_PATH/butterfly.jpg (or ignore this argument to use device camera)\n"
    "Model path can also be specified using --model argument.";
2024-05-15 14:39:34 +08:00
// Command-line keys for cv::CommandLineParser, each in "{ name | default | help }"
// form. main() reads further keys that are not declared here ("model", "scale",
// "mean", "labels", ...); presumably they come from genPreprocArguments(), whose
// result main() appends to `keys` -- which is why `keys` is not const.
const std::string param_keys =
    "{ help     h  |                   | Print help message. }"
    "{ @alias      |                   | An alias name of model to extract preprocessing parameters from models.yml file. }"
    "{ zoo         | ../dnn/models.yml | An optional path to file with preprocessing parameters }"
    "{ device      | 0                 | camera device number. }"
    "{ input i     |                   | Path to input image or video file. Skip this argument to capture frames from a camera. }"
    "{ colors      |                   | Optional path to a text file with colors for every class. "
    "Every color is represented with three values from 0 to 255 in BGR channels order. }";

// Plain literals: the help text needs no printf-style formatting.
const std::string backend_keys =
    "{ backend     | default           | Choose one of computation backends: "
    "default: automatically (by default), "
    "openvino: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
    "opencv: OpenCV implementation, "
    "vkcom: VKCOM, "
    "cuda: CUDA, "
    "webnn: WebNN }";

const std::string target_keys =
    "{ target      | cpu               | Choose one of target computation devices: "
    "cpu: CPU target (by default), "
    "opencl: OpenCL, "
    "opencl_fp16: OpenCL fp16 (half-float precision), "
    "vpu: VPU, "
    "vulkan: Vulkan, "
    "cuda: CUDA, "
    "cuda_fp16: CUDA fp16 (half-float precision) }";

std::string keys = param_keys + backend_keys + target_keys;
2024-11-08 13:55:46 +08:00
// Class names, one per line of the --labels file; stays empty when no file is given.
vector<string> labels{};
// One BGR color per class: loaded from the --colors file, or generated on first
// use by colorizeSegmentation() when no file was supplied.
vector<Vec3b> colors{};
2018-03-07 00:29:23 +08:00
2024-11-08 13:55:46 +08:00
static void colorizeSegmentation ( const Mat & score , Mat & segm )
{
const int rows = score . size [ 2 ] ;
const int cols = score . size [ 3 ] ;
const int chns = score . size [ 1 ] ;
if ( colors . empty ( ) )
{
// Generate colors.
colors . push_back ( Vec3b ( ) ) ;
for ( int i = 1 ; i < chns ; + + i )
{
Vec3b color ;
for ( int j = 0 ; j < 3 ; + + j )
color [ j ] = ( colors [ i - 1 ] [ j ] + rand ( ) % 256 ) / 2 ;
colors . push_back ( color ) ;
}
}
else if ( chns ! = ( int ) colors . size ( ) )
{
CV_Error ( Error : : StsError , format ( " Number of output labels does not match "
" number of colors (%d != %zu) " ,
chns , colors . size ( ) ) ) ;
}
Mat maxCl = Mat : : zeros ( rows , cols , CV_8UC1 ) ;
Mat maxVal ( rows , cols , CV_32FC1 , score . data ) ;
for ( int ch = 1 ; ch < chns ; ch + + )
{
for ( int row = 0 ; row < rows ; row + + )
{
const float * ptrScore = score . ptr < float > ( 0 , ch , row ) ;
uint8_t * ptrMaxCl = maxCl . ptr < uint8_t > ( row ) ;
float * ptrMaxVal = maxVal . ptr < float > ( row ) ;
for ( int col = 0 ; col < cols ; col + + )
{
if ( ptrScore [ col ] > ptrMaxVal [ col ] )
{
ptrMaxVal [ col ] = ptrScore [ col ] ;
ptrMaxCl [ col ] = ( uchar ) ch ;
}
}
}
}
segm . create ( rows , cols , CV_8UC3 ) ;
for ( int row = 0 ; row < rows ; row + + )
{
const uchar * ptrMaxCl = maxCl . ptr < uchar > ( row ) ;
Vec3b * ptrSegm = segm . ptr < Vec3b > ( row ) ;
for ( int col = 0 ; col < cols ; col + + )
{
ptrSegm [ col ] = colors [ ptrMaxCl [ col ] ] ;
}
}
}
// Show a "Legend" window with one row per class: the row is filled with the
// class color and captioned with the class label on a white box.
//
// The legend image is rendered only on the first call and cached in a static
// Mat; later calls return immediately. Raises a cv::Exception (StsError) when
// the number of colors differs from the number of labels.
static void showLegend(FontFace fontFace)
{
    static const int kBlockHeight = 30;
    static const int kLegendWidth = 200;
    static const int kFontSize = 15;
    static const int kFontWeight = 400;
    static Mat legend;
    if (!legend.empty())
        return; // already rendered and shown

    const int numClasses = (int)labels.size();
    if (numClasses == 0)
        return; // nothing to draw
    if ((int)colors.size() != numClasses)
    {
        CV_Error(Error::StsError, cv::format("Number of colors does not match "
                                             "number of labels (%zu != %zu)",
                                             colors.size(), labels.size()));
    }

    legend.create(kBlockHeight * numClasses, kLegendWidth, CV_8UC3);
    for (int i = 0; i < numClasses; i++)
    {
        Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight);
        block.setTo(colors[i]);
        Rect r = getTextSize(Size(), labels[i], Point(), fontFace, kFontSize, kFontWeight);
        r.height += kFontSize; // padding
        r.width += 10;         // padding
        rectangle(block, r, Scalar::all(255), FILLED);
        putText(block, labels[i], Point(10, kBlockHeight / 2), Scalar(0, 0, 0), fontFace, kFontSize, kFontWeight);
    }
    namedWindow("Legend", WINDOW_AUTOSIZE);
    imshow("Legend", legend);
}
2018-03-07 00:29:23 +08:00
2024-05-15 14:39:34 +08:00
int main ( int argc , char * * argv )
2018-03-07 00:29:23 +08:00
{
CommandLineParser parser ( argc , argv , keys ) ;
2018-09-20 22:59:04 +08:00
2024-05-15 14:39:34 +08:00
const string modelName = parser . get < String > ( " @alias " ) ;
2024-11-08 13:55:46 +08:00
const string zooFile = findFile ( parser . get < String > ( " zoo " ) ) ;
2018-09-20 22:59:04 +08:00
keys + = genPreprocArguments ( modelName , zooFile ) ;
parser = CommandLineParser ( argc , argv , keys ) ;
2024-11-08 13:55:46 +08:00
parser . about ( about ) ;
if ( ! parser . has ( " @alias " ) | | parser . has ( " help " ) )
2018-03-07 00:29:23 +08:00
{
parser . printMessage ( ) ;
return 0 ;
}
2024-11-08 13:55:46 +08:00
string sha1 = parser . get < String > ( " sha1 " ) ;
2018-03-07 00:29:23 +08:00
float scale = parser . get < float > ( " scale " ) ;
Scalar mean = parser . get < Scalar > ( " mean " ) ;
bool swapRB = parser . get < bool > ( " rgb " ) ;
int inpWidth = parser . get < int > ( " width " ) ;
int inpHeight = parser . get < int > ( " height " ) ;
2024-11-08 13:55:46 +08:00
String model = findModel ( parser . get < String > ( " model " ) , sha1 ) ;
const string backend = parser . get < String > ( " backend " ) ;
const string target = parser . get < String > ( " target " ) ;
int stdSize = 20 ;
int stdWeight = 400 ;
int stdImgSize = 512 ;
int imgWidth = - 1 ; // Initialization
int fontSize = 50 ;
int fontWeight = 500 ;
FontFace fontFace ( " sans " ) ;
2018-03-07 00:29:23 +08:00
2024-11-08 13:55:46 +08:00
// Open file with labels names.
if ( parser . has ( " labels " ) )
2018-03-07 00:29:23 +08:00
{
2024-11-08 13:55:46 +08:00
string file = findFile ( parser . get < String > ( " labels " ) ) ;
2024-05-15 14:39:34 +08:00
ifstream ifs ( file . c_str ( ) ) ;
2018-03-07 00:29:23 +08:00
if ( ! ifs . is_open ( ) )
CV_Error ( Error : : StsError , " File " + file + " not found " ) ;
2024-05-15 14:39:34 +08:00
string line ;
while ( getline ( ifs , line ) )
2018-03-07 00:29:23 +08:00
{
2024-11-08 13:55:46 +08:00
labels . push_back ( line ) ;
2018-03-07 00:29:23 +08:00
}
}
// Open file with colors.
if ( parser . has ( " colors " ) )
{
2024-07-03 19:03:12 +08:00
string file = findFile ( parser . get < String > ( " colors " ) ) ;
2024-05-15 14:39:34 +08:00
ifstream ifs ( file . c_str ( ) ) ;
2018-03-07 00:29:23 +08:00
if ( ! ifs . is_open ( ) )
CV_Error ( Error : : StsError , " File " + file + " not found " ) ;
2024-05-15 14:39:34 +08:00
string line ;
while ( getline ( ifs , line ) )
2018-03-07 00:29:23 +08:00
{
2024-05-15 14:39:34 +08:00
istringstream colorStr ( line . c_str ( ) ) ;
2018-03-07 00:29:23 +08:00
Vec3b color ;
for ( int i = 0 ; i < 3 & & ! colorStr . eof ( ) ; + + i )
colorStr > > color [ i ] ;
colors . push_back ( color ) ;
}
}
2018-08-15 19:55:47 +08:00
if ( ! parser . check ( ) )
{
parser . printErrors ( ) ;
return 1 ;
}
CV_Assert ( ! model . empty ( ) ) ;
2018-03-07 00:29:23 +08:00
//! [Read and initialize network]
2024-11-08 13:55:46 +08:00
EngineType engine = ENGINE_AUTO ;
if ( backend ! = " default " | | target ! = " cpu " ) {
engine = ENGINE_CLASSIC ;
}
Net net = readNetFromONNX ( model , engine ) ;
net . setPreferableBackend ( getBackendID ( backend ) ) ;
net . setPreferableTarget ( getTargetID ( target ) ) ;
2018-03-07 00:29:23 +08:00
//! [Read and initialize network]
// Create a window
2024-05-15 14:39:34 +08:00
static const string kWinName = " Deep learning semantic segmentation in OpenCV " ;
2024-11-08 13:55:46 +08:00
namedWindow ( kWinName , WINDOW_AUTOSIZE ) ;
2018-03-07 00:29:23 +08:00
//! [Open a video file or an image file or a camera stream]
VideoCapture cap ;
if ( parser . has ( " input " ) )
2024-05-15 14:39:34 +08:00
cap . open ( findFile ( parser . get < String > ( " input " ) ) ) ;
2018-03-07 00:29:23 +08:00
else
2018-05-08 12:07:23 +08:00
cap . open ( parser . get < int > ( " device " ) ) ;
2024-11-18 22:17:05 +08:00
if ( ! cap . isOpened ( ) ) {
cerr < < " Error: Video could not be opened. " < < endl ;
return - 1 ;
}
2018-03-07 00:29:23 +08:00
//! [Open a video file or an image file or a camera stream]
// Process frames.
Mat frame , blob ;
while ( waitKey ( 1 ) < 0 )
{
cap > > frame ;
if ( frame . empty ( ) )
{
waitKey ( ) ;
break ;
}
2024-11-08 13:55:46 +08:00
if ( imgWidth = = - 1 ) {
imgWidth = max ( frame . rows , frame . cols ) ;
fontSize = min ( fontSize , ( stdSize * imgWidth ) / stdImgSize ) ;
fontWeight = min ( fontWeight , ( stdWeight * imgWidth ) / stdImgSize ) ;
}
2024-05-15 14:39:34 +08:00
imshow ( " Original Image " , frame ) ;
2018-03-07 00:29:23 +08:00
//! [Create a 4D blob from a frame]
blobFromImage ( frame , blob , scale , Size ( inpWidth , inpHeight ) , mean , swapRB , false ) ;
//! [Set input blob]
net . setInput ( blob ) ;
2024-07-03 19:03:12 +08:00
//! [Set input blob]
2024-05-15 14:39:34 +08:00
if ( modelName = = " u2netp " )
{
2024-07-03 19:03:12 +08:00
vector < Mat > output ;
net . forward ( output , net . getUnconnectedOutLayersNames ( ) ) ;
Mat pred = output [ 0 ] . reshape ( 1 , output [ 0 ] . size [ 2 ] ) ;
pred . convertTo ( pred , CV_8U , 255.0 ) ;
Mat mask ;
resize ( pred , mask , Size ( frame . cols , frame . rows ) , 0 , 0 , INTER_AREA ) ;
2024-05-15 14:39:34 +08:00
// Create overlays for foreground and background
2024-07-03 19:03:12 +08:00
Mat foreground_overlay ;
// Set foreground (object) to red
Mat all_zeros = Mat : : zeros ( frame . size ( ) , CV_8UC1 ) ;
vector < Mat > channels = { all_zeros , all_zeros , mask } ;
merge ( channels , foreground_overlay ) ;
2024-05-15 14:39:34 +08:00
// Blend the overlays with the original frame
2024-07-03 19:03:12 +08:00
addWeighted ( frame , 0.25 , foreground_overlay , 0.75 , 0 , frame ) ;
2024-05-15 14:39:34 +08:00
}
else
{
2024-07-03 19:03:12 +08:00
//! [Make forward pass]
Mat score = net . forward ( ) ;
//! [Make forward pass]
2024-05-15 14:39:34 +08:00
Mat segm ;
colorizeSegmentation ( score , segm ) ;
resize ( segm , segm , frame . size ( ) , 0 , 0 , INTER_NEAREST ) ;
addWeighted ( frame , 0.1 , segm , 0.9 , 0.0 , frame ) ;
}
2018-03-07 00:29:23 +08:00
// Put efficiency information.
2024-05-15 14:39:34 +08:00
vector < double > layersTimes ;
2018-03-07 00:29:23 +08:00
double freq = getTickFrequency ( ) / 1000 ;
double t = net . getPerfProfile ( layersTimes ) / freq ;
2024-05-15 14:39:34 +08:00
string label = format ( " Inference time: %.2f ms " , t ) ;
2024-11-08 13:55:46 +08:00
Rect r = getTextSize ( Size ( ) , label , Point ( ) , fontFace , fontSize , fontWeight ) ;
r . height + = fontSize ; // padding
r . width + = 10 ; // padding
rectangle ( frame , r , Scalar : : all ( 255 ) , FILLED ) ;
putText ( frame , label , Point ( 10 , fontSize ) , Scalar ( 0 , 0 , 0 ) , fontFace , fontSize , fontWeight ) ;
2018-03-07 00:29:23 +08:00
imshow ( kWinName , frame ) ;
2024-11-08 13:55:46 +08:00
if ( ! labels . empty ( ) )
showLegend ( fontFace ) ;
2018-03-07 00:29:23 +08:00
}
return 0 ;
}