Merge pull request #25433 from gursimarsingh:colorization_onnx_sample

Replaced caffe model with onnx for colorization sample #25433

#25006

Improved the colorization sample to use an ONNX model in both the C++ and Python versions. Added a demo image to the data folder for testing.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Gursimar Singh 2024-04-18 20:45:05 +05:30 committed by GitHub
parent b009a63e6b
commit 448375d1e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 141 additions and 133 deletions

View File

@ -1,128 +1,117 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html
// To download the onnx model, see: https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include "common.hpp"
#include <opencv2/highgui.hpp>
#include <iostream>
using namespace cv;
using namespace cv::dnn;
using namespace std;
using namespace cv::dnn;
// the 313 ab cluster centers from pts_in_hull.npy (already transposed)
static float hull_pts[] = {
-90., -90., -90., -90., -90., -80., -80., -80., -80., -80., -80., -80., -80., -70., -70., -70., -70., -70., -70., -70., -70.,
-70., -70., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -50., -50., -50., -50., -50., -50., -50., -50.,
-50., -50., -50., -50., -50., -50., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -30.,
-30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -20., -20., -20., -20., -20., -20., -20.,
-20., -20., -20., -20., -20., -20., -20., -20., -20., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
-10., -10., -10., -10., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,
10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20.,
20., 20., 20., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 40., 40., 40., 40.,
40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
50., 50., 50., 50., 50., 50., 50., 50., 50., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60.,
60., 60., 60., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 80., 80., 80.,
80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 90., 90., 90., 90., 90., 90., 90., 90., 90., 90.,
90., 90., 90., 90., 90., 90., 90., 90., 90., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 50., 60., 70., 80., 90.,
20., 30., 40., 50., 60., 70., 80., 90., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -20., -10., 0., 10., 20., 30., 40., 50.,
60., 70., 80., 90., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -40., -30., -20., -10., 0., 10., 20.,
30., 40., 50., 60., 70., 80., 90., 100., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -50.,
-40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -60., -50., -40., -30., -20., -10., 0., 10., 20.,
30., 40., 50., 60., 70., 80., 90., 100., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.,
100., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -80., -70., -60., -50.,
-40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -90., -80., -70., -60., -50., -40., -30., -20., -10.,
0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30.,
40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70.,
80., -110., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100.,
-90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100., -90., -80., -70.,
-60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -110., -100., -90., -80., -70., -60., -50., -40., -30.,
-20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0.
};
int main(int argc, char **argv)
{
int main(int argc, char** argv) {
const string about =
"This sample demonstrates recoloring grayscale images with dnn.\n"
"This program is based on:\n"
" http://richzhang.github.io/colorization\n"
" https://github.com/richzhang/colorization\n"
"Download caffemodel and prototxt files:\n"
" http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel\n"
" https://raw.githubusercontent.com/richzhang/colorization/caffe/models/colorization_deploy_v2.prototxt\n";
const string keys =
"{ h help | | print this help message }"
"{ proto | colorization_deploy_v2.prototxt | model configuration }"
"{ model | colorization_release_v2.caffemodel | model weights }"
"{ image | space_shuttle.jpg | path to image file }"
"{ opencl | | enable OpenCL }";
"To download the onnx model:\n"
" https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx\n";
const string param_keys =
"{ help h | | Print help message. }"
"{ input i | baboon.jpg | Path to the input image }"
"{ onnx_model_path | | Path to the ONNX model. Required. }";
const string backend_keys = format(
"{ backend | 0 | Choose one of computation backends: "
"%d: automatically (by default), "
"%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
"%d: OpenCV implementation, "
"%d: VKCOM, "
"%d: CUDA, "
"%d: WebNN }",
cv::dnn::DNN_BACKEND_DEFAULT, cv::dnn::DNN_BACKEND_INFERENCE_ENGINE, cv::dnn::DNN_BACKEND_OPENCV,
cv::dnn::DNN_BACKEND_VKCOM, cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_BACKEND_WEBNN);
const string target_keys = format(
"{ target | 0 | Choose one of target computation devices: "
"%d: CPU target (by default), "
"%d: OpenCL, "
"%d: OpenCL fp16 (half-float precision), "
"%d: VPU, "
"%d: Vulkan, "
"%d: CUDA, "
"%d: CUDA fp16 (half-float preprocess) }",
cv::dnn::DNN_TARGET_CPU, cv::dnn::DNN_TARGET_OPENCL, cv::dnn::DNN_TARGET_OPENCL_FP16,
cv::dnn::DNN_TARGET_MYRIAD, cv::dnn::DNN_TARGET_VULKAN, cv::dnn::DNN_TARGET_CUDA,
cv::dnn::DNN_TARGET_CUDA_FP16);
const string keys = param_keys + backend_keys + target_keys;
CommandLineParser parser(argc, argv, keys);
parser.about(about);
if (parser.has("help"))
{
if (parser.has("help")) {
parser.printMessage();
return 0;
}
string modelTxt = samples::findFile(parser.get<string>("proto"));
string modelBin = samples::findFile(parser.get<string>("model"));
string imageFile = samples::findFile(parser.get<string>("image"));
bool useOpenCL = parser.has("opencl");
if (!parser.check())
{
parser.printErrors();
return 1;
string inputImagePath = parser.get<string>("input");
string onnxModelPath = parser.get<string>("onnx_model_path");
int backendId = parser.get<int>("backend");
int targetId = parser.get<int>("target");
if (onnxModelPath.empty()) {
cerr << "The path to the ONNX model is required!" << endl;
return -1;
}
Mat img = imread(imageFile);
if (img.empty())
{
cout << "Can't read image from file: " << imageFile << endl;
return 2;
Mat imgGray = imread(samples::findFile(inputImagePath), IMREAD_GRAYSCALE);
if (imgGray.empty()) {
cerr << "Could not read the image: " << inputImagePath << endl;
return -1;
}
// fixed input size for the pretrained network
const int W_in = 224;
const int H_in = 224;
Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
if (useOpenCL)
net.setPreferableTarget(DNN_TARGET_OPENCL);
Mat imgL = imgGray;
imgL.convertTo(imgL, CV_32F, 100.0/255.0);
Mat imgLResized;
resize(imgL, imgLResized, Size(256, 256), 0, 0, INTER_CUBIC);
// setup additional layers:
int sz[] = {2, 313, 1, 1};
const Mat pts_in_hull(4, sz, CV_32F, hull_pts);
Ptr<dnn::Layer> class8_ab = net.getLayer("class8_ab");
class8_ab->blobs.push_back(pts_in_hull);
Ptr<dnn::Layer> conv8_313_rh = net.getLayer("conv8_313_rh");
conv8_313_rh->blobs.push_back(Mat(1, 313, CV_32F, Scalar(2.606)));
// Prepare the model
dnn::Net net = dnn::readNetFromONNX(onnxModelPath);
net.setPreferableBackend(backendId);
net.setPreferableTarget(targetId);
//! [Read and initialize network]
// extract L channel and subtract mean
Mat lab, L, input;
img.convertTo(img, CV_32F, 1.0/255);
cvtColor(img, lab, COLOR_BGR2Lab);
extractChannel(lab, L, 0);
resize(L, input, Size(W_in, H_in));
input -= 50;
// Create blob from the image
Mat blob = dnn::blobFromImage(imgLResized, 1.0, Size(256, 256), Scalar(), false, false);
// run the L channel through the network
Mat inputBlob = blobFromImage(input);
net.setInput(inputBlob);
net.setInput(blob);
// Run inference
Mat result = net.forward();
// retrieve the calculated a,b channels from the network output
Size siz(result.size[2], result.size[3]);
Mat a = Mat(siz, CV_32F, result.ptr(0,0));
Mat b = Mat(siz, CV_32F, result.ptr(0,1));
resize(a, a, img.size());
resize(b, b, img.size());
Mat a(siz, CV_32F, result.ptr(0,0));
Mat b(siz, CV_32F, result.ptr(0,1));
resize(a, a, imgGray.size());
resize(b, b, imgGray.size());
// merge, and convert back to BGR
Mat color, chn[] = {L, a, b};
Mat color, chn[] = {imgL, a, b};
// Merge the L, a and b planes into a single Lab image
Mat lab;
merge(chn, 3, lab);
cvtColor(lab, color, COLOR_Lab2BGR);
imshow("color", color);
imshow("original", img);
imshow("input image", imgGray);
imshow("output image", color);
waitKey();
return 0;
}

View File

@ -1,69 +1,88 @@
# Script is based on https://github.com/richzhang/colorization/blob/master/colorization/colorize.py
# To download the caffemodel and the prototxt, see: https://github.com/richzhang/colorization/tree/caffe/colorization/models
# To download pts_in_hull.npy, see: https://github.com/richzhang/colorization/tree/caffe/colorization/resources/pts_in_hull.npy
# To download the onnx model, see: https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx
# Usage: python colorization.py --onnx_model_path colorizer.onnx --input ansel_adams3.jpg
import numpy as np
import argparse
import cv2 as cv
import numpy as np
def parse_args():
parser = argparse.ArgumentParser(description='iColor: deep interactive colorization')
parser.add_argument('--input', help='Path to image or video. Skip to capture frames from camera')
parser.add_argument('--prototxt', help='Path to colorization_deploy_v2.prototxt', required=True)
parser.add_argument('--caffemodel', help='Path to colorization_release_v2.caffemodel', required=True)
parser.add_argument('--kernel', help='Path to pts_in_hull.npy', required=True)
backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_VKCOM, cv.dnn.DNN_BACKEND_CUDA)
targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD,
cv.dnn.DNN_TARGET_HDDL, cv.dnn.DNN_TARGET_VULKAN, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16)
parser = argparse.ArgumentParser(description='iColor: deep interactive colorization')
parser.add_argument('--input', default='baboon.jpg',help='Path to image.')
parser.add_argument('--onnx_model_path', help='Path to onnx model', required=True)
parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
help="Choose one of computation backends: "
"%d: automatically (by default), "
"%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
"%d: OpenCV implementation, "
"%d: VKCOM, "
"%d: CUDA" % backends)
parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
help='Choose one of target computation devices: '
'%d: CPU target (by default), '
'%d: OpenCL, '
'%d: OpenCL fp16 (half-float precision), '
'%d: NCS2 VPU, '
'%d: HDDL VPU, '
'%d: Vulkan, '
'%d: CUDA, '
'%d: CUDA fp16 (half-float preprocess)'% targets)
args = parser.parse_args()
return args
if __name__ == '__main__':
W_in = 224
H_in = 224
imshowSize = (640, 480)
args = parse_args()
img_gray=cv.imread(cv.samples.findFile(args.input),cv.IMREAD_GRAYSCALE)
# Select desired model
net = cv.dnn.readNetFromCaffe(args.prototxt, args.caffemodel)
img_gray_rs = cv.resize(img_gray, (256, 256), interpolation=cv.INTER_CUBIC)
img_gray_rs = img_gray_rs.astype(np.float32) # Convert to float to avoid data overflow
img_gray_rs *= (100.0 / 255.0) # Scale L channel to 0-100 range
pts_in_hull = np.load(args.kernel) # load cluster centers
onnx_model_path = args.onnx_model_path # Update this path to your ONNX model's path
session = cv.dnn.readNetFromONNX(onnx_model_path)
session.setPreferableBackend(args.backend)
session.setPreferableTarget(args.target)
# populate cluster centers as 1x1 convolution kernel
pts_in_hull = pts_in_hull.transpose().reshape(2, 313, 1, 1)
net.getLayer(net.getLayerId('class8_ab')).blobs = [pts_in_hull.astype(np.float32)]
net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, np.float32)]
# Pack the single scaled grayscale image into a blob and run it through the network (no batching is involved)
blob = cv.dnn.blobFromImage(img_gray_rs, swapRB=False) # Adjust swapRB according to your model's training
session.setInput(blob)
result_numpy = np.array(session.forward()[0])
if args.input:
cap = cv.VideoCapture(args.input)
if result_numpy.shape[0] == 2:
# Transpose result_numpy to shape (H, W, 2)
ab = result_numpy.transpose((1, 2, 0))
else:
cap = cv.VideoCapture(0)
# If it's already (H, W, 2), assign it directly
ab = result_numpy
while cv.waitKey(1) < 0:
hasFrame, frame = cap.read()
if not hasFrame:
cv.waitKey()
break
img_rgb = (frame[:,:,[2, 1, 0]] * 1.0 / 255).astype(np.float32)
# Resize ab to match img_gray's dimensions if they are not the same
h, w = img_gray.shape
if ab.shape[:2] != (h, w):
ab_resized = cv.resize(ab, (w, h), interpolation=cv.INTER_LINEAR)
else:
ab_resized = ab
img_lab = cv.cvtColor(img_rgb, cv.COLOR_RGB2Lab)
img_l = img_lab[:,:,0] # pull out L channel
(H_orig,W_orig) = img_rgb.shape[:2] # original image size
# Expand dimensions of L to match ab's dimensions
img_l_expanded = np.expand_dims(img_gray, axis=-1)
# resize image to network input size
img_rs = cv.resize(img_rgb, (W_in, H_in)) # resize image to network input size
img_lab_rs = cv.cvtColor(img_rs, cv.COLOR_RGB2Lab)
img_l_rs = img_lab_rs[:,:,0]
img_l_rs -= 50 # subtract 50 for mean-centering
# Concatenate L with AB to get the LAB image
lab_image = np.concatenate((img_l_expanded, ab_resized), axis=-1)
net.setInput(cv.dnn.blobFromImage(img_l_rs))
ab_dec = net.forward()[0,:,:,:].transpose((1,2,0)) # this is our result
# Convert the Lab image to a 32-bit float format
lab_image = lab_image.astype(np.float32)
(H_out,W_out) = ab_dec.shape[:2]
ab_dec_us = cv.resize(ab_dec, (W_orig, H_orig))
img_lab_out = np.concatenate((img_l[:,:,np.newaxis],ab_dec_us),axis=2) # concatenate with original image L
img_bgr_out = np.clip(cv.cvtColor(img_lab_out, cv.COLOR_Lab2BGR), 0, 1)
# Rescale only the L channel by 100/255 so it lies in the Lab range [0, 100]; the ab channels are used as returned by the network
lab_image[:, :, 0] *= (100.0 / 255.0) # Rescale L channel
#lab_image[:, :, 1:] -= 128 # Shift AB channels
frame = cv.resize(frame, imshowSize)
cv.imshow('origin', frame)
cv.imshow('gray', cv.cvtColor(frame, cv.COLOR_RGB2GRAY))
cv.imshow('colorized', cv.resize(img_bgr_out, imshowSize))
# Convert the LAB image to BGR
image_bgr_out = cv.cvtColor(lab_image, cv.COLOR_Lab2BGR)
cv.imshow("input image",img_gray)
cv.imshow("output image",image_bgr_out)
cv.waitKey(0)