|
|
|
@ -8,8 +8,7 @@ from typing import Union |
|
|
|
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir |
|
|
|
|
from ultralytics.hub.utils import HUB_WEB_ROOT |
|
|
|
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load |
|
|
|
|
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load |
|
|
|
|
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml |
|
|
|
|
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load |
|
|
|
|
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -139,7 +138,7 @@ class Model(nn.Module): |
|
|
|
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) |
|
|
|
|
self.ckpt_path = self.model.pt_path |
|
|
|
|
else: |
|
|
|
|
weights = check_file(weights) |
|
|
|
|
weights = checks.check_file(weights) |
|
|
|
|
self.model, self.ckpt = weights, None |
|
|
|
|
self.task = task or guess_model_task(weights) |
|
|
|
|
self.ckpt_path = weights |
|
|
|
@ -204,11 +203,11 @@ class Model(nn.Module): |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
source (str | int | PIL | np.ndarray): The source of the image to make predictions on. |
|
|
|
|
Accepts all source types accepted by the YOLO model. |
|
|
|
|
Accepts all source types accepted by the YOLO model. |
|
|
|
|
stream (bool): Whether to stream the predictions or not. Defaults to False. |
|
|
|
|
predictor (BasePredictor): Customized predictor. |
|
|
|
|
**kwargs : Additional keyword arguments passed to the predictor. |
|
|
|
|
Check the 'configuration' section in the documentation for all available options. |
|
|
|
|
Check the 'configuration' section in the documentation for all available options. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(List[ultralytics.engine.results.Results]): The prediction results. |
|
|
|
@ -251,8 +250,7 @@ class Model(nn.Module): |
|
|
|
|
if not hasattr(self.predictor, 'trackers'): |
|
|
|
|
from ultralytics.trackers import register_tracker |
|
|
|
|
register_tracker(self, persist) |
|
|
|
|
# ByteTrack-based method needs low confidence predictions as input |
|
|
|
|
kwargs['conf'] = kwargs.get('conf') or 0.1 |
|
|
|
|
kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input |
|
|
|
|
kwargs['mode'] = 'track' |
|
|
|
|
return self.predict(source=source, stream=stream, **kwargs) |
|
|
|
|
|
|
|
|
@ -266,7 +264,6 @@ class Model(nn.Module): |
|
|
|
|
""" |
|
|
|
|
custom = {'rect': True} # method defaults |
|
|
|
|
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right |
|
|
|
|
args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1) |
|
|
|
|
|
|
|
|
|
validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks) |
|
|
|
|
validator(model=self.model) |
|
|
|
@ -321,9 +318,9 @@ class Model(nn.Module): |
|
|
|
|
if any(kwargs): |
|
|
|
|
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') |
|
|
|
|
kwargs = self.session.train_args |
|
|
|
|
check_pip_update_available() |
|
|
|
|
checks.check_pip_update_available() |
|
|
|
|
|
|
|
|
|
overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides |
|
|
|
|
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides |
|
|
|
|
custom = {'data': TASK2DATA[self.task]} # method defaults |
|
|
|
|
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right |
|
|
|
|
if args.get('resume'): |
|
|
|
@ -366,7 +363,7 @@ class Model(nn.Module): |
|
|
|
|
self._check_is_pytorch_model() |
|
|
|
|
self = super()._apply(fn) # noqa |
|
|
|
|
self.predictor = None # reset predictor as device may have changed |
|
|
|
|
self.overrides['device'] = str(self.device) # i.e. device(type='cuda', index=0) -> 'cuda:0' |
|
|
|
|
self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' |
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|