|
|
|
@ -30,6 +30,7 @@ Usage - formats: |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
import platform |
|
|
|
|
import re |
|
|
|
|
import threading |
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
@ -236,7 +237,7 @@ class BasePredictor: |
|
|
|
|
self.run_callbacks("on_predict_start") |
|
|
|
|
for self.batch in self.dataset: |
|
|
|
|
self.run_callbacks("on_predict_batch_start") |
|
|
|
|
paths, im0s, is_video, s = self.batch |
|
|
|
|
paths, im0s, s = self.batch |
|
|
|
|
|
|
|
|
|
# Preprocess |
|
|
|
|
with profilers[0]: |
|
|
|
@ -264,7 +265,7 @@ class BasePredictor: |
|
|
|
|
"postprocess": profilers[2].dt * 1e3 / n, |
|
|
|
|
} |
|
|
|
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: |
|
|
|
|
s[i] += self.write_results(i, Path(paths[i]), im, is_video) |
|
|
|
|
s[i] += self.write_results(i, Path(paths[i]), im, s) |
|
|
|
|
|
|
|
|
|
# Print batch results |
|
|
|
|
if self.args.verbose: |
|
|
|
@ -308,7 +309,7 @@ class BasePredictor: |
|
|
|
|
self.args.half = self.model.fp16 # update half |
|
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
def write_results(self, i, p, im, is_video): |
|
|
|
|
def write_results(self, i, p, im, s): |
|
|
|
|
"""Write inference results to a file or directory.""" |
|
|
|
|
string = "" # print string |
|
|
|
|
if len(im.shape) == 3: |
|
|
|
@ -317,9 +318,10 @@ class BasePredictor: |
|
|
|
|
string += f"{i}: " |
|
|
|
|
frame = self.dataset.count |
|
|
|
|
else: |
|
|
|
|
frame = getattr(self.dataset, "frame", 0) - len(self.results) + i |
|
|
|
|
match = re.search(r"frame (\d+)/", s[i]) |
|
|
|
|
frame = int(match.group(1)) if match else None # 0 if frame undetermined |
|
|
|
|
|
|
|
|
|
self.txt_path = self.save_dir / "labels" / (p.stem + (f"_{frame}" if is_video[i] else "")) |
|
|
|
|
self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}")) |
|
|
|
|
string += "%gx%g " % im.shape[2:] |
|
|
|
|
result = self.results[i] |
|
|
|
|
result.save_dir = self.save_dir.__str__() # used in other locations |
|
|
|
@ -341,18 +343,19 @@ class BasePredictor: |
|
|
|
|
if self.args.save_crop: |
|
|
|
|
result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem) |
|
|
|
|
if self.args.show: |
|
|
|
|
self.show(str(p), is_video[i]) |
|
|
|
|
self.show(str(p)) |
|
|
|
|
if self.args.save: |
|
|
|
|
self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame) |
|
|
|
|
self.save_predicted_images(str(self.save_dir / p.name), frame) |
|
|
|
|
|
|
|
|
|
return string |
|
|
|
|
|
|
|
|
|
def save_predicted_images(self, save_path="", is_video=False, frame=0): |
|
|
|
|
def save_predicted_images(self, save_path="", frame=0): |
|
|
|
|
"""Save video predictions as mp4 at specified path.""" |
|
|
|
|
im = self.plotted_img |
|
|
|
|
|
|
|
|
|
# Save videos and streams |
|
|
|
|
if is_video: |
|
|
|
|
if self.dataset.mode in {"stream", "video"}: |
|
|
|
|
fps = self.dataset.fps if self.dataset.mode == "video" else 30 |
|
|
|
|
frames_path = f'{save_path.split(".", 1)[0]}_frames/' |
|
|
|
|
if save_path not in self.vid_writer: # new video |
|
|
|
|
if self.args.save_frames: |
|
|
|
@ -361,7 +364,7 @@ class BasePredictor: |
|
|
|
|
self.vid_writer[save_path] = cv2.VideoWriter( |
|
|
|
|
filename=str(Path(save_path).with_suffix(suffix)), |
|
|
|
|
fourcc=cv2.VideoWriter_fourcc(*fourcc), |
|
|
|
|
fps=30, # integer required, floats produce error in MP4 codec |
|
|
|
|
fps=fps, # integer required, floats produce error in MP4 codec |
|
|
|
|
frameSize=(im.shape[1], im.shape[0]), # (width, height) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -374,7 +377,7 @@ class BasePredictor: |
|
|
|
|
else: |
|
|
|
|
cv2.imwrite(save_path, im) |
|
|
|
|
|
|
|
|
|
def show(self, p="", is_video=False): |
|
|
|
|
def show(self, p=""): |
|
|
|
|
"""Display an image in a window using OpenCV imshow().""" |
|
|
|
|
im = self.plotted_img |
|
|
|
|
if platform.system() == "Linux" and p not in self.windows: |
|
|
|
@ -382,7 +385,7 @@ class BasePredictor: |
|
|
|
|
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) |
|
|
|
|
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height) |
|
|
|
|
cv2.imshow(p, im) |
|
|
|
|
cv2.waitKey(1 if is_video else 500) # 1 millisecond |
|
|
|
|
cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond |
|
|
|
|
|
|
|
|
|
def run_callbacks(self, event: str): |
|
|
|
|
"""Runs all registered callbacks for a specific event.""" |
|
|
|
|