From ceeb01dce5f6358df0c7b784b04fead14603a85d Mon Sep 17 00:00:00 2001 From: Alexander Lyulkov Date: Fri, 8 Sep 2023 12:44:22 +0700 Subject: [PATCH] Replaced torch7 by onnx model in fast-neural-style dnn sample --- samples/dnn/fast_neural_style.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/samples/dnn/fast_neural_style.py b/samples/dnn/fast_neural_style.py index 912c2f0832..22b8217b3a 100644 --- a/samples/dnn/fast_neural_style.py +++ b/samples/dnn/fast_neural_style.py @@ -5,15 +5,15 @@ import argparse parser = argparse.ArgumentParser( description='This script is used to run style transfer models from ' - 'https://github.com/jcjohnson/fast-neural-style using OpenCV') + 'https://github.com/onnx/models/tree/main/vision/style_transfer/fast_neural_style using OpenCV') parser.add_argument('--input', help='Path to image or video. Skip to capture frames from camera') -parser.add_argument('--model', help='Path to .t7 model') +parser.add_argument('--model', help='Path to .onnx model') parser.add_argument('--width', default=-1, type=int, help='Resize input to specific width.') parser.add_argument('--height', default=-1, type=int, help='Resize input to specific height.') parser.add_argument('--median_filter', default=0, type=int, help='Kernel size of postprocessing blurring.') args = parser.parse_args() -net = cv.dnn.readNetFromTorch(cv.samples.findFile(args.model)) +net = cv.dnn.readNetFromONNX(cv.samples.findFile(args.model)) net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV) if args.input: @@ -31,16 +31,12 @@ while cv.waitKey(1) < 0: inWidth = args.width if args.width != -1 else frame.shape[1] inHeight = args.height if args.height != -1 else frame.shape[0] inp = cv.dnn.blobFromImage(frame, 1.0, (inWidth, inHeight), - (103.939, 116.779, 123.68), swapRB=False, crop=False) + swapRB=True, crop=False) net.setInput(inp) out = net.forward() out = out.reshape(3, out.shape[2], out.shape[3]) - out[0] += 103.939 - out[1] += 116.779 - out[2] += 123.68 - out /= 255 out = out.transpose(1, 2, 0) t, _ = net.getPerfProfile() @@ -50,4 +46,7 @@ while cv.waitKey(1) < 0: if args.median_filter: out = cv.medianBlur(out, args.median_filter) + out = np.clip(out, 0, 255) + out = out.astype(np.uint8) + cv.imshow('Styled image', out)