|
|
|
@ -73,9 +73,7 @@ class BasePredictor: |
|
|
|
|
data (dict): Data configuration. |
|
|
|
|
device (torch.device): Device used for prediction. |
|
|
|
|
dataset (Dataset): Dataset used for prediction. |
|
|
|
|
vid_path (str): Path to video file. |
|
|
|
|
vid_writer (cv2.VideoWriter): Video writer for saving video output. |
|
|
|
|
data_path (str): Path to data. |
|
|
|
|
vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
|
|
|
@ -100,10 +98,11 @@ class BasePredictor: |
|
|
|
|
self.imgsz = None |
|
|
|
|
self.device = None |
|
|
|
|
self.dataset = None |
|
|
|
|
self.vid_path, self.vid_writer, self.vid_frame = None, None, None |
|
|
|
|
self.vid_writer = {} # dict of {save_path: video_writer, ...} |
|
|
|
|
self.plotted_img = None |
|
|
|
|
self.data_path = None |
|
|
|
|
self.source_type = None |
|
|
|
|
self.seen = 0 |
|
|
|
|
self.windows = [] |
|
|
|
|
self.batch = None |
|
|
|
|
self.results = None |
|
|
|
|
self.transforms = None |
|
|
|
@ -155,44 +154,6 @@ class BasePredictor: |
|
|
|
|
letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride) |
|
|
|
|
return [letterbox(image=x) for x in im] |
|
|
|
|
|
|
|
|
|
def write_results(self, idx, results, batch): |
|
|
|
|
"""Write inference results to a file or directory.""" |
|
|
|
|
p, im, _ = batch |
|
|
|
|
log_string = "" |
|
|
|
|
if len(im.shape) == 3: |
|
|
|
|
im = im[None] # expand for batch dim |
|
|
|
|
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 |
|
|
|
|
log_string += f"{idx}: " |
|
|
|
|
frame = self.dataset.count |
|
|
|
|
else: |
|
|
|
|
frame = getattr(self.dataset, "frame", 0) |
|
|
|
|
self.data_path = p |
|
|
|
|
self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}") |
|
|
|
|
log_string += "%gx%g " % im.shape[2:] # print string |
|
|
|
|
result = results[idx] |
|
|
|
|
log_string += result.verbose() |
|
|
|
|
|
|
|
|
|
if self.args.save or self.args.show: # Add bbox to image |
|
|
|
|
plot_args = { |
|
|
|
|
"line_width": self.args.line_width, |
|
|
|
|
"boxes": self.args.show_boxes, |
|
|
|
|
"conf": self.args.show_conf, |
|
|
|
|
"labels": self.args.show_labels, |
|
|
|
|
} |
|
|
|
|
if not self.args.retina_masks: |
|
|
|
|
plot_args["im_gpu"] = im[idx] |
|
|
|
|
self.plotted_img = result.plot(**plot_args) |
|
|
|
|
# Write |
|
|
|
|
if self.args.save_txt: |
|
|
|
|
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) |
|
|
|
|
if self.args.save_crop: |
|
|
|
|
result.save_crop( |
|
|
|
|
save_dir=self.save_dir / "crops", |
|
|
|
|
file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return log_string |
|
|
|
|
|
|
|
|
|
def postprocess(self, preds, img, orig_imgs): |
|
|
|
|
"""Post-processes predictions for an image and returns them.""" |
|
|
|
|
return preds |
|
|
|
@ -228,18 +189,20 @@ class BasePredictor: |
|
|
|
|
else None |
|
|
|
|
) |
|
|
|
|
self.dataset = load_inference_source( |
|
|
|
|
source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer |
|
|
|
|
source=source, |
|
|
|
|
batch=self.args.batch, |
|
|
|
|
vid_stride=self.args.vid_stride, |
|
|
|
|
buffer=self.args.stream_buffer, |
|
|
|
|
) |
|
|
|
|
self.source_type = self.dataset.source_type |
|
|
|
|
if not getattr(self, "stream", True) and ( |
|
|
|
|
self.dataset.mode == "stream" # streams |
|
|
|
|
or len(self.dataset) > 1000 # images |
|
|
|
|
self.source_type.stream |
|
|
|
|
or self.source_type.screenshot |
|
|
|
|
or len(self.dataset) > 1000 # many images |
|
|
|
|
or any(getattr(self.dataset, "video_flag", [False])) |
|
|
|
|
): # videos |
|
|
|
|
LOGGER.warning(STREAM_WARNING) |
|
|
|
|
self.vid_path = [None] * self.dataset.bs |
|
|
|
|
self.vid_writer = [None] * self.dataset.bs |
|
|
|
|
self.vid_frame = [None] * self.dataset.bs |
|
|
|
|
self.vid_writer = {} |
|
|
|
|
|
|
|
|
|
@smart_inference_mode() |
|
|
|
|
def stream_inference(self, source=None, model=None, *args, **kwargs): |
|
|
|
@ -271,10 +234,9 @@ class BasePredictor: |
|
|
|
|
ops.Profile(device=self.device), |
|
|
|
|
) |
|
|
|
|
self.run_callbacks("on_predict_start") |
|
|
|
|
for batch in self.dataset: |
|
|
|
|
for self.batch in self.dataset: |
|
|
|
|
self.run_callbacks("on_predict_batch_start") |
|
|
|
|
self.batch = batch |
|
|
|
|
path, im0s, vid_cap, s = batch |
|
|
|
|
paths, im0s, is_video, s = self.batch |
|
|
|
|
|
|
|
|
|
# Preprocess |
|
|
|
|
with profilers[0]: |
|
|
|
@ -290,8 +252,8 @@ class BasePredictor: |
|
|
|
|
# Postprocess |
|
|
|
|
with profilers[2]: |
|
|
|
|
self.results = self.postprocess(preds, im, im0s) |
|
|
|
|
|
|
|
|
|
self.run_callbacks("on_predict_postprocess_end") |
|
|
|
|
|
|
|
|
|
# Visualize, save, write results |
|
|
|
|
n = len(im0s) |
|
|
|
|
for i in range(n): |
|
|
|
@ -301,41 +263,32 @@ class BasePredictor: |
|
|
|
|
"inference": profilers[1].dt * 1e3 / n, |
|
|
|
|
"postprocess": profilers[2].dt * 1e3 / n, |
|
|
|
|
} |
|
|
|
|
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy() |
|
|
|
|
p = Path(p) |
|
|
|
|
|
|
|
|
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: |
|
|
|
|
s += self.write_results(i, self.results, (p, im, im0)) |
|
|
|
|
if self.args.save or self.args.save_txt: |
|
|
|
|
self.results[i].save_dir = self.save_dir.__str__() |
|
|
|
|
if self.args.show and self.plotted_img is not None: |
|
|
|
|
self.show(p) |
|
|
|
|
if self.args.save and self.plotted_img is not None: |
|
|
|
|
self.save_preds(vid_cap, i, str(self.save_dir / p.name)) |
|
|
|
|
s[i] += self.write_results(i, Path(paths[i]), im, is_video) |
|
|
|
|
|
|
|
|
|
# Print batch results |
|
|
|
|
if self.args.verbose: |
|
|
|
|
LOGGER.info("\n".join(s)) |
|
|
|
|
|
|
|
|
|
self.run_callbacks("on_predict_batch_end") |
|
|
|
|
yield from self.results |
|
|
|
|
|
|
|
|
|
# Print time (inference-only) |
|
|
|
|
if self.args.verbose: |
|
|
|
|
LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms") |
|
|
|
|
|
|
|
|
|
# Release assets |
|
|
|
|
if isinstance(self.vid_writer[-1], cv2.VideoWriter): |
|
|
|
|
self.vid_writer[-1].release() # release final video writer |
|
|
|
|
for v in self.vid_writer.values(): |
|
|
|
|
if isinstance(v, cv2.VideoWriter): |
|
|
|
|
v.release() |
|
|
|
|
|
|
|
|
|
# Print results |
|
|
|
|
# Print final results |
|
|
|
|
if self.args.verbose and self.seen: |
|
|
|
|
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image |
|
|
|
|
LOGGER.info( |
|
|
|
|
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " |
|
|
|
|
f"{(1, 3, *im.shape[2:])}" % t |
|
|
|
|
f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t |
|
|
|
|
) |
|
|
|
|
if self.args.save or self.args.save_txt or self.args.save_crop: |
|
|
|
|
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels |
|
|
|
|
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" |
|
|
|
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") |
|
|
|
|
|
|
|
|
|
self.run_callbacks("on_predict_end") |
|
|
|
|
|
|
|
|
|
def setup_model(self, model, verbose=True): |
|
|
|
@ -354,48 +307,81 @@ class BasePredictor: |
|
|
|
|
self.args.half = self.model.fp16 # update half |
|
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
def show(self, p): |
|
|
|
|
"""Display an image in a window using OpenCV imshow().""" |
|
|
|
|
im0 = self.plotted_img |
|
|
|
|
if platform.system() == "Linux" and p not in self.windows: |
|
|
|
|
self.windows.append(p) |
|
|
|
|
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) |
|
|
|
|
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) |
|
|
|
|
cv2.imshow(str(p), im0) |
|
|
|
|
cv2.waitKey(500 if self.batch[3].startswith("image") else 1) # 1 millisecond |
|
|
|
|
def write_results(self, i, p, im, is_video): |
|
|
|
|
"""Write inference results to a file or directory.""" |
|
|
|
|
string = "" # print string |
|
|
|
|
if len(im.shape) == 3: |
|
|
|
|
im = im[None] # expand for batch dim |
|
|
|
|
if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 |
|
|
|
|
string += f"{i}: " |
|
|
|
|
frame = self.dataset.count |
|
|
|
|
else: |
|
|
|
|
frame = getattr(self.dataset, "frame", 0) - len(self.results) + i |
|
|
|
|
|
|
|
|
|
self.txt_path = self.save_dir / "labels" / (p.stem + f"_{frame}" if is_video[i] else "") |
|
|
|
|
string += "%gx%g " % im.shape[2:] |
|
|
|
|
result = self.results[i] |
|
|
|
|
result.save_dir = self.save_dir.__str__() # used in other locations |
|
|
|
|
string += result.verbose() + f"{result.speed['inference']:.1f}ms" |
|
|
|
|
|
|
|
|
|
# Add predictions to image |
|
|
|
|
if self.args.save or self.args.show: |
|
|
|
|
self.plotted_img = result.plot( |
|
|
|
|
line_width=self.args.line_width, |
|
|
|
|
boxes=self.args.show_boxes, |
|
|
|
|
conf=self.args.show_conf, |
|
|
|
|
labels=self.args.show_labels, |
|
|
|
|
im_gpu=None if self.args.retina_masks else im[i], |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Save results |
|
|
|
|
if self.args.save_txt: |
|
|
|
|
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) |
|
|
|
|
if self.args.save_crop: |
|
|
|
|
result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem) |
|
|
|
|
if self.args.show: |
|
|
|
|
self.show(str(p), is_video[i]) |
|
|
|
|
if self.args.save: |
|
|
|
|
self.save_predicted_images(str(self.save_dir / p.name), is_video[i], frame) |
|
|
|
|
|
|
|
|
|
return string |
|
|
|
|
|
|
|
|
|
def save_preds(self, vid_cap, idx, save_path): |
|
|
|
|
def save_predicted_images(self, save_path="", is_video=False, frame=0): |
|
|
|
|
"""Save video predictions as mp4 at specified path.""" |
|
|
|
|
im0 = self.plotted_img |
|
|
|
|
# Save imgs |
|
|
|
|
if self.dataset.mode == "image": |
|
|
|
|
cv2.imwrite(save_path, im0) |
|
|
|
|
else: # 'video' or 'stream' |
|
|
|
|
im = self.plotted_img |
|
|
|
|
|
|
|
|
|
# Save videos and streams |
|
|
|
|
if is_video: |
|
|
|
|
frames_path = f'{save_path.split(".", 1)[0]}_frames/' |
|
|
|
|
if self.vid_path[idx] != save_path: # new video |
|
|
|
|
self.vid_path[idx] = save_path |
|
|
|
|
if save_path not in self.vid_writer: # new video |
|
|
|
|
if self.args.save_frames: |
|
|
|
|
Path(frames_path).mkdir(parents=True, exist_ok=True) |
|
|
|
|
self.vid_frame[idx] = 0 |
|
|
|
|
if isinstance(self.vid_writer[idx], cv2.VideoWriter): |
|
|
|
|
self.vid_writer[idx].release() # release previous video writer |
|
|
|
|
if vid_cap: # video |
|
|
|
|
fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec |
|
|
|
|
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
|
|
|
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
|
|
else: # stream |
|
|
|
|
fps, w, h = 30, im0.shape[1], im0.shape[0] |
|
|
|
|
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") |
|
|
|
|
self.vid_writer[idx] = cv2.VideoWriter( |
|
|
|
|
str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h) |
|
|
|
|
self.vid_writer[save_path] = cv2.VideoWriter( |
|
|
|
|
filename=str(Path(save_path).with_suffix(suffix)), |
|
|
|
|
fourcc=cv2.VideoWriter_fourcc(*fourcc), |
|
|
|
|
fps=30, # integer required, floats produce error in MP4 codec |
|
|
|
|
frameSize=(im.shape[1], im.shape[0]), # (width, height) |
|
|
|
|
) |
|
|
|
|
# Write video |
|
|
|
|
self.vid_writer[idx].write(im0) |
|
|
|
|
|
|
|
|
|
# Write frame |
|
|
|
|
# Save video |
|
|
|
|
self.vid_writer[save_path].write(im) |
|
|
|
|
if self.args.save_frames: |
|
|
|
|
cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0) |
|
|
|
|
self.vid_frame[idx] += 1 |
|
|
|
|
cv2.imwrite(f"{frames_path}{frame}.jpg", im) |
|
|
|
|
|
|
|
|
|
# Save images |
|
|
|
|
else: |
|
|
|
|
cv2.imwrite(save_path, im) |
|
|
|
|
|
|
|
|
|
def show(self, p="", is_video=False): |
|
|
|
|
"""Display an image in a window using OpenCV imshow().""" |
|
|
|
|
im = self.plotted_img |
|
|
|
|
if platform.system() == "Linux" and p not in self.windows: |
|
|
|
|
self.windows.append(p) |
|
|
|
|
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) |
|
|
|
|
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height) |
|
|
|
|
cv2.imshow(p, im) |
|
|
|
|
cv2.waitKey(1 if is_video else 500) # 1 millisecond |
|
|
|
|
|
|
|
|
|
def run_callbacks(self, event: str): |
|
|
|
|
"""Runs all registered callbacks for a specific event.""" |
|
|
|
|