/* Sample of using OpenCV dnn module with Torch ENet model. */ #include #include #include using namespace cv; using namespace cv::dnn; #include #include #include #include using namespace std; const String keys = "{help h || Sample app for loading ENet Torch model. " "The model and class names list can be downloaded here: " "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }" "{model m || path to Torch .net model file (model_best.net) }" "{image i || path to image file }" "{result r || path to save output blob (optional, binary format, NCHW order) }" "{show s || whether to show all output channels or not}" "{o_blob || output blob's name. If empty, last blob's name in net is used}"; static const int kNumClasses = 20; static const String classes[] = { "Background", "Road", "Sidewalk", "Building", "Wall", "Fence", "Pole", "TrafficLight", "TrafficSign", "Vegetation", "Terrain", "Sky", "Person", "Rider", "Car", "Truck", "Bus", "Train", "Motorcycle", "Bicycle" }; static const Vec3b colors[] = { Vec3b(0, 0, 0), Vec3b(244, 126, 205), Vec3b(254, 83, 132), Vec3b(192, 200, 189), Vec3b(50, 56, 251), Vec3b(65, 199, 228), Vec3b(240, 178, 193), Vec3b(201, 67, 188), Vec3b(85, 32, 33), Vec3b(116, 25, 18), Vec3b(162, 33, 72), Vec3b(101, 150, 210), Vec3b(237, 19, 16), Vec3b(149, 197, 72), Vec3b(80, 182, 21), Vec3b(141, 5, 207), Vec3b(189, 156, 39), Vec3b(235, 170, 186), Vec3b(133, 109, 144), Vec3b(231, 160, 96) }; static void showLegend(); static void colorizeSegmentation(const Mat &score, Mat &segm); int main(int argc, char **argv) { CommandLineParser parser(argc, argv, keys); if (parser.has("help") || argc == 1) { parser.printMessage(); return 0; } String modelFile = parser.get("model"); String imageFile = parser.get("image"); if (!parser.check()) { parser.printErrors(); return 0; } String resultFile = parser.get("result"); //! [Read model and initialize network] dnn::Net net = dnn::readNetFromTorch(modelFile); //! [Prepare blob] Mat img = imread(imageFile), input; if (img.empty()) { std::cerr << "Can't read image from the file: " << imageFile << std::endl; exit(-1); } Mat inputBlob = blobFromImage(img, 1./255, Size(1024, 512), Scalar(), true, false); //Convert Mat to image batch //! [Prepare blob] //! [Set input blob] net.setInput(inputBlob); //set the network input //! [Set input blob] TickMeter tm; String oBlob = net.getLayerNames().back(); if (!parser.get("o_blob").empty()) { oBlob = parser.get("o_blob"); } //! [Make forward pass] tm.start(); Mat result = net.forward(oBlob); tm.stop(); if (!resultFile.empty()) { CV_Assert(result.isContinuous()); ofstream fout(resultFile.c_str(), ios::out | ios::binary); fout.write((char*)result.data, result.total() * sizeof(float)); fout.close(); } std::cout << "Output blob: " << result.size[0] << " x " << result.size[1] << " x " << result.size[2] << " x " << result.size[3] << "\n"; std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl; if (parser.has("show")) { Mat segm, show; colorizeSegmentation(result, segm); showLegend(); cv::resize(segm, segm, img.size(), 0, 0, cv::INTER_NEAREST); addWeighted(img, 0.1, segm, 0.9, 0.0, show); imshow("Result", show); waitKey(); } return 0; } //main static void showLegend() { static const int kBlockHeight = 30; cv::Mat legend(kBlockHeight * kNumClasses, 200, CV_8UC3); for(int i = 0; i < kNumClasses; i++) { cv::Mat block = legend.rowRange(i * kBlockHeight, (i + 1) * kBlockHeight); block.setTo(colors[i]); putText(block, classes[i], Point(0, kBlockHeight / 2), FONT_HERSHEY_SIMPLEX, 0.5, Vec3b(255, 255, 255)); } imshow("Legend", legend); } static void colorizeSegmentation(const Mat &score, Mat &segm) { const int rows = score.size[2]; const int cols = score.size[3]; const int chns = score.size[1]; Mat maxCl = Mat::zeros(rows, cols, CV_8UC1); Mat maxVal(rows, cols, CV_32FC1, score.data); for (int ch = 1; ch < chns; ch++) { for (int row = 0; row < rows; row++) { const float *ptrScore = score.ptr(0, ch, row); uint8_t *ptrMaxCl = maxCl.ptr(row); float *ptrMaxVal = maxVal.ptr(row); for (int col = 0; col < cols; col++) { if (ptrScore[col] > ptrMaxVal[col]) { ptrMaxVal[col] = ptrScore[col]; ptrMaxCl[col] = (uchar)ch; } } } } segm.create(rows, cols, CV_8UC3); for (int row = 0; row < rows; row++) { const uchar *ptrMaxCl = maxCl.ptr(row); Vec3b *ptrSegm = segm.ptr(row); for (int col = 0; col < cols; col++) { ptrSegm[col] = colors[ptrMaxCl[col]]; } } }