mirror of https://github.com/opencv/opencv.git (synced 2024-11-25 11:40:44 +08:00)
Merge pull request #25433 from gursimarsingh:colorization_onnx_sample
Replaced caffe model with onnx for colorization sample #25433 (reference: #25006)

Improved the colorization sample to use an ONNX model in both the C++ and Python versions. Added a demo image to the data folder for testing.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [ ] There is an accuracy test, a performance test and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in: parent b009a63e6b, commit 448375d1e7
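At its core, the change swaps the Caffe-era setup (the 313-entry `hull_pts` cluster table injected into the `class8_ab` and `conv8_313_rh` layers) for a plain ONNX forward pass: read a grayscale image, scale it to the Lab L range [0, 100], run the 256x256 L channel through the network, then merge the predicted ab channels back with L and convert Lab to BGR. A minimal Python sketch of that flow, assuming colorizer.onnx has been downloaded from the link in the sample headers; `input.jpg` and the model path are placeholders:

```python
import cv2 as cv
import numpy as np

# Placeholder paths: any local image and the downloaded colorizer.onnx
img_gray = cv.imread("input.jpg", cv.IMREAD_GRAYSCALE)

# The network expects the Lab L channel at 256x256, scaled from [0, 255] to [0, 100]
img_l = img_gray.astype(np.float32) * (100.0 / 255.0)
blob = cv.dnn.blobFromImage(cv.resize(img_l, (256, 256), interpolation=cv.INTER_CUBIC))

net = cv.dnn.readNetFromONNX("colorizer.onnx")
net.setInput(blob)
ab = net.forward()[0].transpose((1, 2, 0))  # predicted ab channels, shape (256, 256, 2)

# Upsample ab to the original resolution and merge with the original L channel
h, w = img_gray.shape
ab = cv.resize(ab, (w, h), interpolation=cv.INTER_LINEAR)
lab = np.concatenate((img_l[:, :, np.newaxis], ab), axis=-1)
bgr = cv.cvtColor(lab, cv.COLOR_Lab2BGR)  # float32 BGR in [0, 1]

cv.imshow("output image", bgr)
cv.waitKey(0)
```

The two hunks below show how this pipeline lands in the C++ and Python samples.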
samples/dnn/colorization.cpp:

```diff
@@ -1,128 +1,117 @@
 // This file is part of OpenCV project.
 // It is subject to the license terms in the LICENSE file found in the top-level directory
 // of this distribution and at http://opencv.org/license.html
+// To download the onnx model, see: https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx

 #include <opencv2/dnn.hpp>
 #include <opencv2/imgproc.hpp>
 #include <opencv2/imgcodecs.hpp>
+#include "common.hpp"
 #include <opencv2/highgui.hpp>
 #include <iostream>

 using namespace cv;
-using namespace cv::dnn;
 using namespace std;
+using namespace cv::dnn;

-// the 313 ab cluster centers from pts_in_hull.npy (already transposed)
-static float hull_pts[] = {
-    -90., -90., -90., -90., -90., -80., -80., -80., -80., -80., -80., -80., -80., -70., -70., -70., -70., -70., -70., -70., -70.,
-    -70., -70., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -60., -50., -50., -50., -50., -50., -50., -50., -50.,
-    -50., -50., -50., -50., -50., -50., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -40., -30.,
-    -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -30., -20., -20., -20., -20., -20., -20., -20.,
-    -20., -20., -20., -20., -20., -20., -20., -20., -20., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
-    -10., -10., -10., -10., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,
-    10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20., 20.,
-    20., 20., 20., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 40., 40., 40., 40.,
-    40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
-    50., 50., 50., 50., 50., 50., 50., 50., 50., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60., 60.,
-    60., 60., 60., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 70., 80., 80., 80.,
-    80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 80., 90., 90., 90., 90., 90., 90., 90., 90., 90., 90.,
-    90., 90., 90., 90., 90., 90., 90., 90., 90., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 50., 60., 70., 80., 90.,
-    20., 30., 40., 50., 60., 70., 80., 90., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -20., -10., 0., 10., 20., 30., 40., 50.,
-    60., 70., 80., 90., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -40., -30., -20., -10., 0., 10., 20.,
-    30., 40., 50., 60., 70., 80., 90., 100., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -50.,
-    -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., -60., -50., -40., -30., -20., -10., 0., 10., 20.,
-    30., 40., 50., 60., 70., 80., 90., 100., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.,
-    100., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -80., -70., -60., -50.,
-    -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -90., -80., -70., -60., -50., -40., -30., -20., -10.,
-    0., 10., 20., 30., 40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30.,
-    40., 50., 60., 70., 80., 90., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70.,
-    80., -110., -100., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100.,
-    -90., -80., -70., -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., 80., -110., -100., -90., -80., -70.,
-    -60., -50., -40., -30., -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -110., -100., -90., -80., -70., -60., -50., -40., -30.,
-    -20., -10., 0., 10., 20., 30., 40., 50., 60., 70., -90., -80., -70., -60., -50., -40., -30., -20., -10., 0.
-};
-
-int main(int argc, char **argv)
-{
+int main(int argc, char** argv) {
     const string about =
         "This sample demonstrates recoloring grayscale images with dnn.\n"
         "This program is based on:\n"
         " http://richzhang.github.io/colorization\n"
         " https://github.com/richzhang/colorization\n"
-        "Download caffemodel and prototxt files:\n"
-        " http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel\n"
-        " https://raw.githubusercontent.com/richzhang/colorization/caffe/models/colorization_deploy_v2.prototxt\n";
-    const string keys =
-        "{ h help | | print this help message }"
-        "{ proto | colorization_deploy_v2.prototxt | model configuration }"
-        "{ model | colorization_release_v2.caffemodel | model weights }"
-        "{ image | space_shuttle.jpg | path to image file }"
-        "{ opencl | | enable OpenCL }";
+        "To download the onnx model:\n"
+        " https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx\n";

+    const string param_keys =
+        "{ help h | | Print help message. }"
+        "{ input i | baboon.jpg | Path to the input image }"
+        "{ onnx_model_path | | Path to the ONNX model. Required. }";

+    const string backend_keys = format(
+        "{ backend | 0 | Choose one of computation backends: "
+        "%d: automatically (by default), "
+        "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
+        "%d: OpenCV implementation, "
+        "%d: VKCOM, "
+        "%d: CUDA, "
+        "%d: WebNN }",
+        cv::dnn::DNN_BACKEND_DEFAULT, cv::dnn::DNN_BACKEND_INFERENCE_ENGINE, cv::dnn::DNN_BACKEND_OPENCV,
+        cv::dnn::DNN_BACKEND_VKCOM, cv::dnn::DNN_BACKEND_CUDA, cv::dnn::DNN_BACKEND_WEBNN);
+    const string target_keys = format(
+        "{ target | 0 | Choose one of target computation devices: "
+        "%d: CPU target (by default), "
+        "%d: OpenCL, "
+        "%d: OpenCL fp16 (half-float precision), "
+        "%d: VPU, "
+        "%d: Vulkan, "
+        "%d: CUDA, "
+        "%d: CUDA fp16 (half-float preprocess) }",
+        cv::dnn::DNN_TARGET_CPU, cv::dnn::DNN_TARGET_OPENCL, cv::dnn::DNN_TARGET_OPENCL_FP16,
+        cv::dnn::DNN_TARGET_MYRIAD, cv::dnn::DNN_TARGET_VULKAN, cv::dnn::DNN_TARGET_CUDA,
+        cv::dnn::DNN_TARGET_CUDA_FP16);

+    const string keys = param_keys + backend_keys + target_keys;
     CommandLineParser parser(argc, argv, keys);
     parser.about(about);
-    if (parser.has("help"))
-    {
+
+    if (parser.has("help")) {
         parser.printMessage();
         return 0;
     }
-    string modelTxt = samples::findFile(parser.get<string>("proto"));
-    string modelBin = samples::findFile(parser.get<string>("model"));
-    string imageFile = samples::findFile(parser.get<string>("image"));
-    bool useOpenCL = parser.has("opencl");
-    if (!parser.check())
-    {
-        parser.printErrors();
-        return 1;
+
+    string inputImagePath = parser.get<string>("input");
+    string onnxModelPath = parser.get<string>("onnx_model_path");
+    int backendId = parser.get<int>("backend");
+    int targetId = parser.get<int>("target");
+
+    if (onnxModelPath.empty()) {
+        cerr << "The path to the ONNX model is required!" << endl;
+        return -1;
     }

-    Mat img = imread(imageFile);
-    if (img.empty())
-    {
-        cout << "Can't read image from file: " << imageFile << endl;
-        return 2;
+    Mat imgGray = imread(samples::findFile(inputImagePath), IMREAD_GRAYSCALE);
+    if (imgGray.empty()) {
+        cerr << "Could not read the image: " << inputImagePath << endl;
+        return -1;
     }

-    // fixed input size for the pretrained network
-    const int W_in = 224;
-    const int H_in = 224;
-    Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
-    if (useOpenCL)
-        net.setPreferableTarget(DNN_TARGET_OPENCL);
+    Mat imgL = imgGray;
+    imgL.convertTo(imgL, CV_32F, 100.0/255.0);
+    Mat imgLResized;
+    resize(imgL, imgLResized, Size(256, 256), 0, 0, INTER_CUBIC);

-    // setup additional layers:
-    int sz[] = {2, 313, 1, 1};
-    const Mat pts_in_hull(4, sz, CV_32F, hull_pts);
-    Ptr<dnn::Layer> class8_ab = net.getLayer("class8_ab");
-    class8_ab->blobs.push_back(pts_in_hull);
-    Ptr<dnn::Layer> conv8_313_rh = net.getLayer("conv8_313_rh");
-    conv8_313_rh->blobs.push_back(Mat(1, 313, CV_32F, Scalar(2.606)));
+    // Prepare the model
+    dnn::Net net = dnn::readNetFromONNX(onnxModelPath);
+    net.setPreferableBackend(backendId);
+    net.setPreferableTarget(targetId);
+    //! [Read and initialize network]

-    // extract L channel and subtract mean
-    Mat lab, L, input;
-    img.convertTo(img, CV_32F, 1.0/255);
-    cvtColor(img, lab, COLOR_BGR2Lab);
-    extractChannel(lab, L, 0);
-    resize(L, input, Size(W_in, H_in));
-    input -= 50;
+    // Create blob from the image
+    Mat blob = dnn::blobFromImage(imgLResized, 1.0, Size(256, 256), Scalar(), false, false);

-    // run the L channel through the network
-    Mat inputBlob = blobFromImage(input);
-    net.setInput(inputBlob);
+    net.setInput(blob);
+
+    // Run inference
     Mat result = net.forward();

-    // retrieve the calculated a,b channels from the network output
     Size siz(result.size[2], result.size[3]);
-    Mat a = Mat(siz, CV_32F, result.ptr(0,0));
-    Mat b = Mat(siz, CV_32F, result.ptr(0,1));
-    resize(a, a, img.size());
-    resize(b, b, img.size());
+    Mat a(siz, CV_32F, result.ptr(0,0));
+    Mat b(siz, CV_32F, result.ptr(0,1));
+    resize(a, a, imgGray.size());
+    resize(b, b, imgGray.size());

     // merge, and convert back to BGR
-    Mat color, chn[] = {L, a, b};
+    Mat color, chn[] = {imgL, a, b};
+
+    // Proc
+    Mat lab;
     merge(chn, 3, lab);
     cvtColor(lab, color, COLOR_Lab2BGR);

-    imshow("color", color);
-    imshow("original", img);
+    imshow("input image", imgGray);
+    imshow("output image", color);
     waitKey();

     return 0;
 }
```
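If OpenCV is built with samples enabled, this C++ sample is typically produced as an `example_dnn_colorization` binary (OpenCV's usual `example_<folder>_<name>` naming; the exact name may vary with build configuration), so a run would look like `example_dnn_colorization --onnx_model_path=colorizer.onnx --input=baboon.jpg`, with `--backend` and `--target` left at their defaults of 0 (automatic backend, CPU target).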
samples/dnn/colorization.py:

```diff
@@ -1,69 +1,88 @@
 # Script is based on https://github.com/richzhang/colorization/blob/master/colorization/colorize.py
-# To download the caffemodel and the prototxt, see: https://github.com/richzhang/colorization/tree/caffe/colorization/models
-# To download pts_in_hull.npy, see: https://github.com/richzhang/colorization/tree/caffe/colorization/resources/pts_in_hull.npy
+# To download the onnx model, see: https://storage.googleapis.com/ailia-models/colorization/colorizer.onnx
+# python colorization.py --onnx_model_path colorizer.onnx --input ansel_adams3.jpg
-import numpy as np
 import argparse
 import cv2 as cv
+import numpy as np

 def parse_args():
-    parser = argparse.ArgumentParser(description='iColor: deep interactive colorization')
-    parser.add_argument('--input', help='Path to image or video. Skip to capture frames from camera')
-    parser.add_argument('--prototxt', help='Path to colorization_deploy_v2.prototxt', required=True)
-    parser.add_argument('--caffemodel', help='Path to colorization_release_v2.caffemodel', required=True)
-    parser.add_argument('--kernel', help='Path to pts_in_hull.npy', required=True)
+    backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
+                cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_VKCOM, cv.dnn.DNN_BACKEND_CUDA)
+    targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD,
+               cv.dnn.DNN_TARGET_HDDL, cv.dnn.DNN_TARGET_VULKAN, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16)
+
+    parser = argparse.ArgumentParser(description='iColor: deep interactive colorization')
+    parser.add_argument('--input', default='baboon.jpg', help='Path to image.')
+    parser.add_argument('--onnx_model_path', help='Path to onnx model', required=True)
+    parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
+                        help="Choose one of computation backends: "
+                             "%d: automatically (by default), "
+                             "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
+                             "%d: OpenCV implementation, "
+                             "%d: VKCOM, "
+                             "%d: CUDA" % backends)
+    parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
+                        help='Choose one of target computation devices: '
+                             '%d: CPU target (by default), '
+                             '%d: OpenCL, '
+                             '%d: OpenCL fp16 (half-float precision), '
+                             '%d: NCS2 VPU, '
+                             '%d: HDDL VPU, '
+                             '%d: Vulkan, '
+                             '%d: CUDA, '
+                             '%d: CUDA fp16 (half-float preprocess)' % targets)
     args = parser.parse_args()
     return args

 if __name__ == '__main__':
-    W_in = 224
-    H_in = 224
-    imshowSize = (640, 480)
-
     args = parse_args()
+    img_gray = cv.imread(cv.samples.findFile(args.input), cv.IMREAD_GRAYSCALE)

-    # Select desired model
-    net = cv.dnn.readNetFromCaffe(args.prototxt, args.caffemodel)
+    img_gray_rs = cv.resize(img_gray, (256, 256), interpolation=cv.INTER_CUBIC)
+    img_gray_rs = img_gray_rs.astype(np.float32)  # Convert to float to avoid data overflow
+    img_gray_rs *= (100.0 / 255.0)  # Scale L channel to 0-100 range

-    pts_in_hull = np.load(args.kernel) # load cluster centers
+    onnx_model_path = args.onnx_model_path  # Update this path to your ONNX model's path
+    session = cv.dnn.readNetFromONNX(onnx_model_path)
+    session.setPreferableBackend(args.backend)
+    session.setPreferableTarget(args.target)

-    # populate cluster centers as 1x1 convolution kernel
-    pts_in_hull = pts_in_hull.transpose().reshape(2, 313, 1, 1)
-    net.getLayer(net.getLayerId('class8_ab')).blobs = [pts_in_hull.astype(np.float32)]
-    net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, np.float32)]
+    # Process each image in the batch (assuming batch processing is needed)
+    blob = cv.dnn.blobFromImage(img_gray_rs, swapRB=False)  # Adjust swapRB according to your model's training
+    session.setInput(blob)
+    result_numpy = np.array(session.forward()[0])

-    if args.input:
-        cap = cv.VideoCapture(args.input)
+    if result_numpy.shape[0] == 2:
+        # Transpose result_numpy to shape (H, W, 2)
+        ab = result_numpy.transpose((1, 2, 0))
     else:
-        cap = cv.VideoCapture(0)
+        # If it's already (H, W, 2), assign it directly
+        ab = result_numpy

-    while cv.waitKey(1) < 0:
-        hasFrame, frame = cap.read()
-        if not hasFrame:
-            cv.waitKey()
-            break
-
-        img_rgb = (frame[:,:,[2, 1, 0]] * 1.0 / 255).astype(np.float32)
+    # Resize ab to match img_gray's dimensions if they are not the same
+    h, w = img_gray.shape
+    if ab.shape[:2] != (h, w):
+        ab_resized = cv.resize(ab, (w, h), interpolation=cv.INTER_LINEAR)
+    else:
+        ab_resized = ab

-        img_lab = cv.cvtColor(img_rgb, cv.COLOR_RGB2Lab)
-        img_l = img_lab[:,:,0] # pull out L channel
-        (H_orig,W_orig) = img_rgb.shape[:2] # original image size
+    # Expand dimensions of L to match ab's dimensions
+    img_l_expanded = np.expand_dims(img_gray, axis=-1)

-        # resize image to network input size
-        img_rs = cv.resize(img_rgb, (W_in, H_in)) # resize image to network input size
-        img_lab_rs = cv.cvtColor(img_rs, cv.COLOR_RGB2Lab)
-        img_l_rs = img_lab_rs[:,:,0]
-        img_l_rs -= 50 # subtract 50 for mean-centering
+    # Concatenate L with AB to get the LAB image
+    lab_image = np.concatenate((img_l_expanded, ab_resized), axis=-1)

-        net.setInput(cv.dnn.blobFromImage(img_l_rs))
-        ab_dec = net.forward()[0,:,:,:].transpose((1,2,0)) # this is our result
+    # Convert the Lab image to a 32-bit float format
+    lab_image = lab_image.astype(np.float32)

-        (H_out,W_out) = ab_dec.shape[:2]
-        ab_dec_us = cv.resize(ab_dec, (W_orig, H_orig))
-        img_lab_out = np.concatenate((img_l[:,:,np.newaxis],ab_dec_us),axis=2) # concatenate with original image L
-        img_bgr_out = np.clip(cv.cvtColor(img_lab_out, cv.COLOR_Lab2BGR), 0, 1)
+    # Normalize L channel to the range [0, 100] and AB channels to the range [-127, 127]
+    lab_image[:, :, 0] *= (100.0 / 255.0)  # Rescale L channel
+    #lab_image[:, :, 1:] -= 128  # Shift AB channels

-        frame = cv.resize(frame, imshowSize)
-        cv.imshow('origin', frame)
-        cv.imshow('gray', cv.cvtColor(frame, cv.COLOR_RGB2GRAY))
-        cv.imshow('colorized', cv.resize(img_bgr_out, imshowSize))
+    # Convert the LAB image to BGR
+    image_bgr_out = cv.cvtColor(lab_image, cv.COLOR_Lab2BGR)
+    cv.imshow("input image", img_gray)
+    cv.imshow("output image", image_bgr_out)
+    cv.waitKey(0)
```
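One practical note on the Python version: `cv.cvtColor` with `COLOR_Lab2BGR` on a float32 Lab image returns BGR values in [0, 1], which `cv.imshow` displays correctly, but writing that float image straight to PNG or JPEG would not store it correctly, since those writers expect 8-bit data. A small sketch for saving the result, assuming it runs right after the sample's final lines (`colorized.png` is a placeholder name):

```python
import cv2 as cv
import numpy as np

# image_bgr_out is float32 BGR in [0, 1]; imwrite expects 8-bit pixels,
# so clip and rescale before saving.
out_8u = np.clip(image_bgr_out * 255.0, 0, 255).astype(np.uint8)
cv.imwrite("colorized.png", out_8u)  # placeholder output path
```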