Predictor support (#65)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 479992093c
commit e6737f1207
22 changed files with 916 additions and 48 deletions
[Two result images added: 476 KiB and 165 KiB]
ultralytics/yolo/engine/predictor.py
@@ -0,0 +1,201 @@
# predictor engine by Ultralytics
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ yolo task=... mode=predict model=s.pt source=0            # webcam
                                            img.jpg             # image
                                            vid.mp4             # video
                                            screen              # screenshot
                                            path/               # directory
                                            list.txt            # list of images
                                            list.streams        # list of streams
                                            'path/*.jpg'        # glob
                                            'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                            'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ yolo task=... mode=predict model=yolov5s.pt               # PyTorch
                                       yolov5s.torchscript      # TorchScript
                                       yolov5s.onnx             # ONNX Runtime or OpenCV DNN with dnn=True
                                       yolov5s_openvino_model   # OpenVINO
                                       yolov5s.engine           # TensorRT
                                       yolov5s.mlmodel          # CoreML (macOS-only)
                                       yolov5s_saved_model      # TensorFlow SavedModel
                                       yolov5s.pb               # TensorFlow GraphDef
                                       yolov5s.tflite           # TensorFlow Lite
                                       yolov5s_edgetpu.tflite   # TensorFlow Edge TPU
                                       yolov5s_paddle_model     # PaddlePaddle
"""
import platform
from pathlib import Path

import cv2
import torch

from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imshow
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.plotting import Annotator
from ultralytics.yolo.utils.torch_utils import check_img_size, select_device, smart_inference_mode

DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"

class BasePredictor:

    def __init__(self, config=DEFAULT_CONFIG, overrides={}):
        self.args = get_config(config, overrides)
        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        self.done_setup = False

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.device = None
        self.dataset = None
        self.vid_path, self.vid_writer = None, None
        self.view_img = None
        self.annotator = None
        self.data_path = None

    def preprocess(self, img):
        pass

    def get_annotator(self, img):
        raise NotImplementedError("get_annotator function needs to be implemented")

    def write_results(self, idx, preds, batch):
        raise NotImplementedError("write_results function needs to be implemented")

    def postprocess(self, preds, img, orig_img):
        return preds

    def setup(self, source=None, model=None):
        # source
        source = str(source or self.args.source)
        self.save_img = not self.args.nosave and not source.endswith('.txt')
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        screenshot = source.lower().startswith('screen')
        if is_url and is_file:
            source = check_file(source)  # download

        # data
        if self.data:
            if self.data.endswith(".yaml"):
                self.data = check_dataset_yaml(self.data)
            else:
                self.data = check_dataset(self.data)

        # model
        device = select_device(self.args.device)
        model = model or self.args.model
        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
        model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)  # NOTE: not passing data
        stride, pt = model.stride, model.pt
        imgsz = check_img_size(self.args.img_size, s=stride)  # check image size

        # Dataloader
        bs = 1  # batch_size
        if webcam:
            self.view_img = check_imshow(warn=True)
            self.dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
            bs = len(self.dataset)
        elif screenshot:
            self.dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
        else:
            self.dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
        self.vid_path, self.vid_writer = [None] * bs, [None] * bs
        model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup

        self.model = model
        self.webcam = webcam
        self.screenshot = screenshot
        self.imgsz = imgsz
        self.done_setup = True
        self.device = device

        return model

    @smart_inference_mode()
    def __call__(self, source=None, model=None):
        if not self.done_setup:
            model = self.setup(source, model)
        else:
            model = self.model

        self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
        for batch in self.dataset:
            path, im, im0s, vid_cap, s = batch
            visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
            with self.dt[0]:
                im = self.preprocess(im)
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

            # Inference
            with self.dt[1]:
                preds = model(im, augment=self.args.augment, visualize=visualize)

            # postprocess
            with self.dt[2]:
                preds = self.postprocess(preds, im, im0s)

            for i in range(len(im)):
                if self.webcam:
                    path, im0s = path[i], im0s[i]
                p = Path(path)
                s += self.write_results(i, preds, (p, im, im0s))

                if self.args.view_img:
                    self.show(p)

                if self.save_img:
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))

            # Print time (inference-only)
            LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")

        # Print results
        t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
        LOGGER.info(
            f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
            % t)
        if self.args.save_txt or self.save_img:
            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

    def show(self, p):
        im0 = self.annotator.result()
        if platform.system() == 'Linux' and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
        cv2.imshow(str(p), im0)
        cv2.waitKey(1)  # 1 millisecond

    def save_preds(self, vid_cap, idx, save_path):
        im0 = self.annotator.result()
        # save imgs
        if self.dataset.mode == 'image':
            cv2.imwrite(save_path, im0)
        else:  # 'video' or 'stream'
            if self.vid_path[idx] != save_path:  # new video
                self.vid_path[idx] = save_path
                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
                    self.vid_writer[idx].release()  # release previous video writer
                if vid_cap:  # video
                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            self.vid_writer[idx].write(im0)
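For orientation, a minimal sketch (not part of this diff) of the contract a task-specific subclass fulfils: preprocess converts the dataloader's numpy batch to a tensor, and write_results must bump self.seen because the final speed summary divides by it. The override keys and the model/source values here are assumed for illustration only; the real implementations follow below in this PR.

import torch

from ultralytics.yolo.engine.predictor import BasePredictor


class MinimalPredictor(BasePredictor):

    def preprocess(self, img):
        img = torch.from_numpy(img).to(self.model.device)  # dataloader yields numpy
        return img.half() if self.model.fp16 else img.float()

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        self.seen += 1  # required: the speed summary divides by self.seen
        return f'{p.name}: {len(preds[idx])} outputs, '


# nosave=True because this sketch defines no annotator for save_preds()/show()
predictor = MinimalPredictor(overrides={'model': 'yolov5s.pt', 'source': 'img.jpg', 'nosave': True})
predictor()  # setup() runs on the first call, then the inference loop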
ultralytics/yolo/utils/configs/__init__.py
@@ -0,0 +1,23 @@
from pathlib import Path
from typing import Dict, Union

from omegaconf import DictConfig, OmegaConf


def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
    """
    Accepts a yaml file name or DictConfig containing the experiment configuration.
    Returns a training args namespace.
    :param config: yaml file name or DictConfig object
    :param overrides: yaml file name or dict of values that take precedence over config
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)
    elif isinstance(config, Dict):
        config = OmegaConf.create(config)
    # override
    if isinstance(overrides, str):
        overrides = OmegaConf.load(overrides)
    elif isinstance(overrides, Dict):
        overrides = OmegaConf.create(overrides)

    return OmegaConf.merge(config, overrides)
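Worth noting the merge direction: OmegaConf.merge gives later arguments priority, so values in overrides win. A quick hedged example (the yaml path and key are assumed):

from ultralytics.yolo.utils.configs import get_config

# 'default.yaml' stands in for a real config file that defines conf_thres
cfg = get_config('default.yaml', overrides={'conf_thres': 0.5})
print(cfg.conf_thres)  # 0.5 -- the override takes precedence over the yaml value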
@@ -0,0 +1,68 @@
import hydra
import torch

from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box


class ClassificationPredictor(BasePredictor):

    def get_annotator(self, img):
        return Annotator(img, example=str(self.model.names), pil=True)

    def preprocess(self, img):
        img = torch.Tensor(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        return img

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # save_path = str(self.save_dir / p.name)  # im.jpg
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        prob = preds[idx]
        # Print results
        top5i = prob.argsort(0, descending=True)[:5].tolist()  # top 5 indices
        log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "

        # write
        text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
        if self.save_img or self.args.view_img:  # Add text to image
            self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
        if self.args.save_txt:  # Write to file
            with open(f'{self.txt_path}.txt', 'a') as f:
                f.write(text + '\n')

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "squeezenet1_0"
    sz = cfg.img_size
    if not isinstance(sz, int):  # received a ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = ClassificationPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
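Since the entrypoint is a @hydra.main function, CLI key=value overrides map straight onto cfg fields; programmatic use can skip hydra and build the config directly. A hedged sketch, with the module path and field values assumed:

from ultralytics.yolo.engine.predictor import DEFAULT_CONFIG
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.v8.classify.predict import ClassificationPredictor  # module path assumed

cfg = get_config(DEFAULT_CONFIG, overrides={'model': 'squeezenet1_0', 'source': 'img.jpg', 'img_size': [224, 224]})
predictor = ClassificationPredictor(cfg)
predictor()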
ultralytics/yolo/v8/detect/__init__.py
@@ -1,2 +1,3 @@
from ultralytics.yolo.v8.detect.predict import DetectionPredictor, predict
from ultralytics.yolo.v8.detect.train import DetectionTrainer, train
from ultralytics.yolo.v8.detect.val import DetectionValidator, val
ultralytics/yolo/v8/detect/predict.py
@@ -0,0 +1,97 @@
import hydra
import torch

from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box


class DetectionPredictor(BasePredictor):

    def get_annotator(self, img):
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_img):
        preds = ops.non_max_suppression(preds,
                                        self.args.conf_thres,
                                        self.args.iou_thres,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det)

        for i, pred in enumerate(preds):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()

        return preds

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # save_path = str(self.save_dir / p.name)  # im.jpg
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        det = preds[idx]
        if len(det) == 0:
            return log_string
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # write
        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
        for *xyxy, conf, cls in reversed(det):
            if self.args.save_txt:  # Write to file
                xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                line = (cls, *xywh, conf) if self.args.save_conf else (cls, *xywh)  # label format
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')

            if self.save_img or self.args.save_crop or self.args.view_img:  # Add bbox to image
                c = int(cls)  # integer class
                label = None if self.args.hide_labels else (
                    self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
                self.annotator.box_label(xyxy, label, color=colors(c, True))
            if self.args.save_crop:
                imc = im0.copy()
                save_one_box(xyxy,
                             imc,
                             file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
                             BGR=True)

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "n.pt"
    sz = cfg.img_size
    if not isinstance(sz, int):  # received a ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = DetectionPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
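The save_txt lines above store each box as class cx cy w h, normalized by the whwh gain. A self-contained sketch of that conversion, with example values assumed:

import torch

xyxy = torch.tensor([[100., 50., 300., 250.]])  # one box in pixels
gn = torch.tensor([640, 480, 640, 480])         # normalization gain whwh (image w, h, w, h)
xywh = torch.cat(((xyxy[:, :2] + xyxy[:, 2:]) / 2, xyxy[:, 2:] - xyxy[:, :2]), 1)  # xyxy -> cx, cy, w, h
print((xywh / gn).view(-1).tolist())  # approximately [0.3125, 0.3125, 0.3125, 0.4167]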
ultralytics/yolo/v8/segment/__init__.py
@@ -1,2 +1,3 @@
from ultralytics.yolo.v8.segment.predict import SegmentationPredictor, predict
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
from ultralytics.yolo.v8.segment.val import SegmentationValidator, val
ultralytics/yolo/v8/segment/predict.py
@@ -0,0 +1,115 @@
from pathlib import Path

import hydra
import torch

from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import ROOT, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box

from ..detect.predict import DetectionPredictor


class SegmentationPredictor(DetectionPredictor):

    def postprocess(self, preds, img, orig_img):
        masks = []
        if len(preds) == 2:  # eval
            p, proto = preds
        else:  # len 3 during training
            p, proto, _ = preds
        # TODO: filter by classes
        p = ops.non_max_suppression(p,
                                    self.args.conf_thres,
                                    self.args.iou_thres,
                                    agnostic=self.args.agnostic_nms,
                                    max_det=self.args.max_det,
                                    nm=32)
        for i, pred in enumerate(p):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            if not len(pred):
                continue
            if self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
                masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2]))  # HWC
            else:
                masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True))  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()

        return (p, masks)

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        preds, masks = preds
        det = preds[idx]
        if len(det) == 0:
            return log_string
        # Segments
        mask = masks[idx]
        if self.args.save_txt:
            segments = [
                ops.scale_segments(im0.shape if self.args.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
                for x in reversed(ops.masks2segments(mask))]

        # Print results
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "  # add to string

        # Mask plotting
        self.annotator.masks(
            mask,
            colors=[colors(x, True) for x in det[:, 5]],
            im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
            255 if self.args.retina_masks else im[idx])

        # Write results
        for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
            if self.args.save_txt:  # Write to file
                seg = segments[j].reshape(-1)  # (n,2) to (n*2)
                line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg)  # label format
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')

            if self.save_img or self.args.save_crop or self.args.view_img:
                c = int(cls)  # integer class
                label = None if self.args.hide_labels else (
                    self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
                self.annotator.box_label(xyxy, label, color=colors(c, True))
                # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
            if self.args.save_crop:
                imc = im0.copy()
                save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True)

        return log_string


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    cfg.model = cfg.model or "n.pt"
    sz = cfg.img_size
    if not isinstance(sz, int):  # received a ListConfig
        cfg.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]]  # expand
    else:
        cfg.img_size = [sz, sz]
    predictor = SegmentationPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
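The segment label format written above is one polygon per line: class index followed by normalized x,y pairs, using the same '%g' formatting as the detection labels. A minimal sketch with assumed vertex values:

seg = [(120, 60), (300, 60), (300, 240), (120, 240)]  # polygon vertices in pixels (assumed)
w, h = 640, 480  # image size (assumed)
flat = [v for x, y in seg for v in (x / w, y / h)]  # normalize to 0-1, flatten (n,2) -> (n*2)
line = (0, *flat)  # class 0 + coords, mirroring the save_txt block above
print(('%g ' * len(line)).rstrip() % line)
# 0 0.1875 0.125 0.46875 0.125 0.46875 0.5 0.1875 0.5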