|
|
|
@ -98,7 +98,7 @@ class BasePredictor: |
|
|
|
|
self.imgsz = None |
|
|
|
|
self.device = None |
|
|
|
|
self.dataset = None |
|
|
|
|
self.vid_path, self.vid_writer = None, None |
|
|
|
|
self.vid_path, self.vid_writer, self.vid_frame = None, None, None |
|
|
|
|
self.plotted_img = None |
|
|
|
|
self.data_path = None |
|
|
|
|
self.source_type = None |
|
|
|
@ -221,7 +221,9 @@ class BasePredictor: |
|
|
|
|
len(self.dataset) > 1000 or # images |
|
|
|
|
any(getattr(self.dataset, 'video_flag', [False]))): # videos |
|
|
|
|
LOGGER.warning(STREAM_WARNING) |
|
|
|
|
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs |
|
|
|
|
self.vid_path = [None] * self.dataset.bs |
|
|
|
|
self.vid_writer = [None] * self.dataset.bs |
|
|
|
|
self.vid_frame = [None] * self.dataset.bs |
|
|
|
|
|
|
|
|
|
@smart_inference_mode() |
|
|
|
|
def stream_inference(self, source=None, model=None, *args, **kwargs): |
|
|
|
@ -341,8 +343,11 @@ class BasePredictor: |
|
|
|
|
if self.dataset.mode == 'image': |
|
|
|
|
cv2.imwrite(save_path, im0) |
|
|
|
|
else: # 'video' or 'stream' |
|
|
|
|
frames_path = f'{save_path.split(".", 1)[0]}_frames/' |
|
|
|
|
if self.vid_path[idx] != save_path: # new video |
|
|
|
|
Path(frames_path).mkdir(parents=True, exist_ok=True) |
|
|
|
|
self.vid_path[idx] = save_path |
|
|
|
|
self.vid_frame[idx] = 0 |
|
|
|
|
if isinstance(self.vid_writer[idx], cv2.VideoWriter): |
|
|
|
|
self.vid_writer[idx].release() # release previous video writer |
|
|
|
|
if vid_cap: # video |
|
|
|
@ -352,10 +357,15 @@ class BasePredictor: |
|
|
|
|
else: # stream |
|
|
|
|
fps, w, h = 30, im0.shape[1], im0.shape[0] |
|
|
|
|
suffix, fourcc = ('.mp4', 'avc1') if MACOS else ('.avi', 'WMV2') if WINDOWS else ('.avi', 'MJPG') |
|
|
|
|
save_path = str(Path(save_path).with_suffix(suffix)) |
|
|
|
|
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) |
|
|
|
|
self.vid_writer[idx] = cv2.VideoWriter(str(Path(save_path).with_suffix(suffix)), |
|
|
|
|
cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) |
|
|
|
|
# Write video |
|
|
|
|
self.vid_writer[idx].write(im0) |
|
|
|
|
|
|
|
|
|
# Write frame |
|
|
|
|
cv2.imwrite(f'{frames_path}{self.vid_frame[idx]}.jpg', im0) |
|
|
|
|
self.vid_frame[idx] += 1 |
|
|
|
|
|
|
|
|
|
def run_callbacks(self, event: str): |
|
|
|
|
"""Runs all registered callbacks for a specific event.""" |
|
|
|
|
for callback in self.callbacks.get(event, []): |
|
|
|
|