|
|
|
@ -11,8 +11,8 @@ import torch.nn as nn |
|
|
|
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, |
|
|
|
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, |
|
|
|
|
GhostBottleneck, GhostConv, Segment) |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load |
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load |
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, |
|
|
|
|
intersect_dicts, make_divisible, model_info, scale_img, time_sync) |
|
|
|
|
|
|
|
|
@ -151,15 +151,19 @@ class BaseModel(nn.Module): |
|
|
|
|
m.strides = fn(m.strides) |
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
def load(self, weights): |
|
|
|
|
""" |
|
|
|
|
This function loads the weights of the model from a file |
|
|
|
|
def load(self, weights, verbose=True): |
|
|
|
|
"""Load the weights into the model. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
weights (str): The weights to load into the model. |
|
|
|
|
weights (dict) or (torch.nn.Module): The pre-trained weights to be loaded. |
|
|
|
|
verbose (bool, optional): Whether to log the transfer progress. Defaults to True. |
|
|
|
|
""" |
|
|
|
|
# Force all tasks to implement this function |
|
|
|
|
raise NotImplementedError('This function needs to be implemented by derived classes!') |
|
|
|
|
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts |
|
|
|
|
csd = model.float().state_dict() # checkpoint state_dict as FP32 |
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect |
|
|
|
|
self.load_state_dict(csd, strict=False) # load |
|
|
|
|
if verbose: |
|
|
|
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DetectionModel(BaseModel): |
|
|
|
@ -234,13 +238,6 @@ class DetectionModel(BaseModel): |
|
|
|
|
y[-1] = y[-1][..., i:] # small |
|
|
|
|
return y |
|
|
|
|
|
|
|
|
|
def load(self, weights, verbose=True): |
|
|
|
|
csd = weights.float().state_dict() # checkpoint state_dict as FP32 |
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect |
|
|
|
|
self.load_state_dict(csd, strict=False) # load |
|
|
|
|
if verbose and RANK == -1: |
|
|
|
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SegmentationModel(DetectionModel): |
|
|
|
|
# YOLOv8 segmentation model |
|
|
|
@ -293,12 +290,6 @@ class ClassificationModel(BaseModel): |
|
|
|
|
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict |
|
|
|
|
self.info() |
|
|
|
|
|
|
|
|
|
def load(self, weights): |
|
|
|
|
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts |
|
|
|
|
csd = model.float().state_dict() |
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect |
|
|
|
|
self.load_state_dict(csd, strict=False) # load |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def reshape_outputs(model, nc): |
|
|
|
|
# Update a TorchVision classification model to class count 'n' if required |
|
|
|
@ -338,6 +329,7 @@ def torch_safe_load(weight): |
|
|
|
|
""" |
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download_asset |
|
|
|
|
|
|
|
|
|
check_suffix(file=weight, suffix='.pt') |
|
|
|
|
file = attempt_download_asset(weight) # search online if missing locally |
|
|
|
|
try: |
|
|
|
|
return torch.load(file, map_location='cpu'), file # load |
|
|
|
|