Speedup the Video Inference by Accelerating data-loading Stage (#7832)
* add a faster inference for video * Fix typos * modify typo * modify the numpy array to torch gpu * fix lint * add description * add documents * fix typro * fix lint * fix lint * fix lint again * fix a mistakepull/7492/head^2
parent
280cc7d74f
commit
b1f40efb09
3 changed files with 166 additions and 0 deletions
@ -0,0 +1,113 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import argparse |
||||
|
||||
import cv2 |
||||
import mmcv |
||||
import numpy as np |
||||
import torch |
||||
from torchvision.transforms import functional as F |
||||
|
||||
from mmdet.apis import init_detector |
||||
from mmdet.datasets.pipelines import Compose |
||||
|
||||
try: |
||||
import ffmpegcv |
||||
except ImportError: |
||||
raise ImportError( |
||||
'Please install ffmpegcv with:\n\n pip install ffmpegcv') |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser( |
||||
description='MMDetection video demo with GPU acceleration') |
||||
parser.add_argument('video', help='Video file') |
||||
parser.add_argument('config', help='Config file') |
||||
parser.add_argument('checkpoint', help='Checkpoint file') |
||||
parser.add_argument( |
||||
'--device', default='cuda:0', help='Device used for inference') |
||||
parser.add_argument( |
||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold') |
||||
parser.add_argument('--out', type=str, help='Output video file') |
||||
parser.add_argument('--show', action='store_true', help='Show video') |
||||
parser.add_argument( |
||||
'--nvdecode', action='store_true', help='Use NVIDIA decoder') |
||||
parser.add_argument( |
||||
'--wait-time', |
||||
type=float, |
||||
default=1, |
||||
help='The interval of show (s), 0 is block') |
||||
args = parser.parse_args() |
||||
return args |
||||
|
||||
|
||||
def prefetch_img_metas(cfg, ori_wh): |
||||
w, h = ori_wh |
||||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' |
||||
test_pipeline = Compose(cfg.data.test.pipeline) |
||||
data = {'img': np.zeros((h, w, 3), dtype=np.uint8)} |
||||
data = test_pipeline(data) |
||||
img_metas = data['img_metas'][0].data |
||||
return img_metas |
||||
|
||||
|
||||
def process_img(frame_resize, img_metas, device): |
||||
assert frame_resize.shape == img_metas['pad_shape'] |
||||
frame_cuda = torch.from_numpy(frame_resize).to(device).float() |
||||
frame_cuda = frame_cuda.permute(2, 0, 1) # HWC to CHW |
||||
mean = torch.from_numpy(img_metas['img_norm_cfg']['mean']).to(device) |
||||
std = torch.from_numpy(img_metas['img_norm_cfg']['std']).to(device) |
||||
frame_cuda = F.normalize(frame_cuda, mean=mean, std=std, inplace=True) |
||||
frame_cuda = frame_cuda[None, :, :, :] # NCHW |
||||
data = {'img': [frame_cuda], 'img_metas': [[img_metas]]} |
||||
return data |
||||
|
||||
|
||||
def main(): |
||||
args = parse_args() |
||||
assert args.out or args.show, \ |
||||
('Please specify at least one operation (save/show the ' |
||||
'video) with the argument "--out" or "--show"') |
||||
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device) |
||||
|
||||
if args.nvdecode: |
||||
VideoCapture = ffmpegcv.VideoCaptureNV |
||||
else: |
||||
VideoCapture = ffmpegcv.VideoCapture |
||||
video_origin = VideoCapture(args.video) |
||||
img_metas = prefetch_img_metas(model.cfg, |
||||
(video_origin.width, video_origin.height)) |
||||
resize_wh = img_metas['pad_shape'][1::-1] |
||||
video_resize = VideoCapture( |
||||
args.video, |
||||
resize=resize_wh, |
||||
resize_keepratio=True, |
||||
resize_keepratioalign='topleft', |
||||
pix_fmt='rgb24') |
||||
video_writer = None |
||||
if args.out: |
||||
video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps) |
||||
|
||||
with torch.no_grad(): |
||||
for frame_resize, frame_origin in zip( |
||||
mmcv.track_iter_progress(video_resize), video_origin): |
||||
data = process_img(frame_resize, img_metas, args.device) |
||||
result = model(return_loss=False, rescale=True, **data)[0] |
||||
frame_mask = model.show_result( |
||||
frame_origin, result, score_thr=args.score_thr) |
||||
if args.show: |
||||
cv2.namedWindow('video', 0) |
||||
mmcv.imshow(frame_mask, 'video', args.wait_time) |
||||
if args.out: |
||||
video_writer.write(frame_mask) |
||||
|
||||
if video_writer: |
||||
video_writer.release() |
||||
video_origin.release() |
||||
video_resize.release() |
||||
|
||||
cv2.destroyAllWindows() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
main() |
Loading…
Reference in new issue