diff --git a/samples/dnn/torch_enet.cpp b/samples/dnn/torch_enet.cpp index 4f9ad21378..6101d17f06 100644 --- a/samples/dnn/torch_enet.cpp +++ b/samples/dnn/torch_enet.cpp @@ -20,21 +20,35 @@ const String keys = "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }" "{model m || path to Torch .net model file (model_best.net) }" "{image i || path to image file }" - "{c_names c || path to file with classnames for channels (optional, categories.txt) }" "{result r || path to save output blob (optional, binary format, NCHW order) }" "{show s || whether to show all output channels or not}" - "{o_blob || output blob's name. If empty, last blob's name in net is used}" - ; + "{o_blob || output blob's name. If empty, last blob's name in net is used}"; -static void colorizeSegmentation(const Mat &score, Mat &segm, - Mat &legend, vector &classNames, vector &colors); -static vector readColors(const String &filename, vector& classNames); +static const int kNumClasses = 20; + +static const String classes[] = { + "Background", "Road", "Sidewalk", "Building", "Wall", "Fence", "Pole", + "TrafficLight", "TrafficSign", "Vegetation", "Terrain", "Sky", "Person", + "Rider", "Car", "Truck", "Bus", "Train", "Motorcycle", "Bicycle" +}; + +static const Vec3b colors[] = { + Vec3b(0, 0, 0), Vec3b(244, 126, 205), Vec3b(254, 83, 132), Vec3b(192, 200, 189), + Vec3b(50, 56, 251), Vec3b(65, 199, 228), Vec3b(240, 178, 193), Vec3b(201, 67, 188), + Vec3b(85, 32, 33), Vec3b(116, 25, 18), Vec3b(162, 33, 72), Vec3b(101, 150, 210), + Vec3b(237, 19, 16), Vec3b(149, 197, 72), Vec3b(80, 182, 21), Vec3b(141, 5, 207), + Vec3b(189, 156, 39), Vec3b(235, 170, 186), Vec3b(133, 109, 144), Vec3b(231, 160, 96) +}; + +static void showLegend(); + +static void colorizeSegmentation(const Mat &score, Mat &segm); int main(int argc, char **argv) { CommandLineParser parser(argc, argv, keys); - if (parser.has("help")) + if (parser.has("help") || argc == 1) { parser.printMessage(); return 0; @@ -49,7 +63,6 @@ int main(int argc, char **argv) return 0; } - String classNamesFile = parser.get("c_names"); String resultFile = parser.get("result"); //! [Read model and initialize network] @@ -63,17 +76,11 @@ int main(int argc, char **argv) exit(-1); } - Size origSize = img.size(); - Size inputImgSize = cv::Size(1024, 512); - - if (inputImgSize != origSize) - resize(img, img, inputImgSize); //Resize image to input size - - Mat inputBlob = blobFromImage(img, 1./255); //Convert Mat to image batch + Mat inputBlob = blobFromImage(img, 1./255, Size(1024, 512), Scalar(), true, false); //Convert Mat to image batch //! [Prepare blob] //! [Set input blob] - net.setInput(inputBlob, ""); //set the network input + net.setInput(inputBlob); //set the network input //! [Set input blob] TickMeter tm; @@ -102,41 +109,47 @@ int main(int argc, char **argv) if (parser.has("show")) { - std::vector classNames; - vector colors; - if(!classNamesFile.empty()) { - colors = readColors(classNamesFile, classNames); - } - Mat segm, legend; - colorizeSegmentation(result, segm, legend, classNames, colors); + Mat segm, show; + colorizeSegmentation(result, segm); + showLegend(); - Mat show; + cv::resize(segm, segm, img.size(), 0, 0, cv::INTER_NEAREST); addWeighted(img, 0.1, segm, 0.9, 0.0, show); - cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST); imshow("Result", show); - if(classNames.size()) - imshow("Legend", legend); waitKey(); } - return 0; } //main -static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector &classNames, vector &colors) +static void showLegend() +{ + static const int kBlockHeight = 30; + + cv::Mat legend(kBlockHeight * kNumClasses, 200, CV_8UC3); + for(int i = 0; i < kNumClasses; i++) + { + cv::Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight); + block.setTo(colors[i]); + putText(block, classes[i], Point(0, kBlockHeight / 2), FONT_HERSHEY_SIMPLEX, 0.5, Vec3b(255, 255, 255)); + } + imshow("Legend", legend); +} + +static void colorizeSegmentation(const Mat &score, Mat &segm) { const int rows = score.size[2]; const int cols = score.size[3]; const int chns = score.size[1]; - cv::Mat maxCl(rows, cols, CV_8UC1); - cv::Mat maxVal(rows, cols, CV_32FC1); - for (int ch = 0; ch < chns; ch++) + Mat maxCl = Mat::zeros(rows, cols, CV_8UC1); + Mat maxVal(rows, cols, CV_32FC1, score.data); + for (int ch = 1; ch < chns; ch++) { for (int row = 0; row < rows; row++) { const float *ptrScore = score.ptr(0, ch, row); - uchar *ptrMaxCl = maxCl.ptr(row); + uint8_t *ptrMaxCl = maxCl.ptr(row); float *ptrMaxVal = maxVal.ptr(row); for (int col = 0; col < cols; col++) { @@ -153,57 +166,10 @@ static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vecto for (int row = 0; row < rows; row++) { const uchar *ptrMaxCl = maxCl.ptr(row); - cv::Vec3b *ptrSegm = segm.ptr(row); + Vec3b *ptrSegm = segm.ptr(row); for (int col = 0; col < cols; col++) { ptrSegm[col] = colors[ptrMaxCl[col]]; } } - - if (classNames.size() == colors.size()) - { - int blockHeight = 30; - legend.create(blockHeight*(int)classNames.size(), 200, CV_8UC3); - for(int i = 0; i < (int)classNames.size(); i++) - { - cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight); - block = colors[i]; - putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar()); - } - } -} - -static vector readColors(const String &filename, vector& classNames) -{ - vector colors; - classNames.clear(); - - ifstream fp(filename.c_str()); - if (!fp.is_open()) - { - cerr << "File with colors not found: " << filename << endl; - exit(-1); - } - - string line; - while (!fp.eof()) - { - getline(fp, line); - if (line.length()) - { - stringstream ss(line); - - string name; ss >> name; - int temp; - cv::Vec3b color; - ss >> temp; color[0] = (uchar)temp; - ss >> temp; color[1] = (uchar)temp; - ss >> temp; color[2] = (uchar)temp; - classNames.push_back(name); - colors.push_back(color); - } - } - - fp.close(); - return colors; }