mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Merge pull request #24245 from alexlyulkov/al/update-fast-neural-style-dnn-sample
Replaced torch7 model by ONNX model in fast-neural-style dnn sample
This commit is contained in:
commit
1a8d37d19e
@ -5,15 +5,15 @@ import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='This script is used to run style transfer models from '
|
||||
'https://github.com/jcjohnson/fast-neural-style using OpenCV')
|
||||
'https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style using OpenCV')
|
||||
parser.add_argument('--input', help='Path to image or video. Skip to capture frames from camera')
|
||||
parser.add_argument('--model', help='Path to .t7 model')
|
||||
parser.add_argument('--model', help='Path to .onnx model')
|
||||
parser.add_argument('--width', default=-1, type=int, help='Resize input to specific width.')
|
||||
parser.add_argument('--height', default=-1, type=int, help='Resize input to specific height.')
|
||||
parser.add_argument('--median_filter', default=0, type=int, help='Kernel size of postprocessing blurring.')
|
||||
args = parser.parse_args()
|
||||
|
||||
net = cv.dnn.readNetFromTorch(cv.samples.findFile(args.model))
|
||||
net = cv.dnn.readNet(cv.samples.findFile(args.model))
|
||||
net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
|
||||
|
||||
if args.input:
|
||||
@ -31,16 +31,12 @@ while cv.waitKey(1) < 0:
|
||||
inWidth = args.width if args.width != -1 else frame.shape[1]
|
||||
inHeight = args.height if args.height != -1 else frame.shape[0]
|
||||
inp = cv.dnn.blobFromImage(frame, 1.0, (inWidth, inHeight),
|
||||
(103.939, 116.779, 123.68), swapRB=False, crop=False)
|
||||
swapRB=True, crop=False)
|
||||
|
||||
net.setInput(inp)
|
||||
out = net.forward()
|
||||
|
||||
out = out.reshape(3, out.shape[2], out.shape[3])
|
||||
out[0] += 103.939
|
||||
out[1] += 116.779
|
||||
out[2] += 123.68
|
||||
out /= 255
|
||||
out = out.transpose(1, 2, 0)
|
||||
|
||||
t, _ = net.getPerfProfile()
|
||||
@ -50,4 +46,7 @@ while cv.waitKey(1) < 0:
|
||||
if args.median_filter:
|
||||
out = cv.medianBlur(out, args.median_filter)
|
||||
|
||||
out = np.clip(out, 0, 255)
|
||||
out = out.astype(np.uint8)
|
||||
|
||||
cv.imshow('Styled image', out)
|
||||
|
Loading…
Reference in New Issue
Block a user