|
|
|
@ -10,11 +10,13 @@ import torchvision |
|
|
|
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, |
|
|
|
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, |
|
|
|
|
GhostBottleneck, GhostConv, Segment) |
|
|
|
|
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load |
|
|
|
|
from ultralytics.yolo.utils.checks import check_yaml |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible, |
|
|
|
|
model_info, scale_img, time_sync) |
|
|
|
|
|
|
|
|
|
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseModel(nn.Module): |
|
|
|
|
''' |
|
|
|
@ -211,7 +213,7 @@ class DetectionModel(BaseModel): |
|
|
|
|
return y |
|
|
|
|
|
|
|
|
|
def load(self, weights, verbose=True): |
|
|
|
|
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32 |
|
|
|
|
csd = weights.float().state_dict() # checkpoint state_dict as FP32 |
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect |
|
|
|
|
self.load_state_dict(csd, strict=False) # load |
|
|
|
|
if verbose: |
|
|
|
@ -281,21 +283,21 @@ class ClassificationModel(BaseModel): |
|
|
|
|
# Functions ------------------------------------------------------------------------------------------------------------ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=True): |
|
|
|
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): |
|
|
|
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a |
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download |
|
|
|
|
default_keys = DEFAULT_CONFIG_DICT.keys() |
|
|
|
|
|
|
|
|
|
model = Ensemble() |
|
|
|
|
for w in weights if isinstance(weights, list) else [weights]: |
|
|
|
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load |
|
|
|
|
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} |
|
|
|
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model |
|
|
|
|
|
|
|
|
|
# Model compatibility updates |
|
|
|
|
if not hasattr(ckpt, 'stride'): |
|
|
|
|
ckpt.stride = torch.tensor([32.]) |
|
|
|
|
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)): |
|
|
|
|
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict |
|
|
|
|
ckpt.args = {k: v for k, v in args.items() if k in default_keys} |
|
|
|
|
|
|
|
|
|
# Append |
|
|
|
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode |
|
|
|
|
|
|
|
|
|
# Module compatibility updates |
|
|
|
@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True): |
|
|
|
|
if len(model) == 1: |
|
|
|
|
return model[-1] |
|
|
|
|
|
|
|
|
|
# Return detection ensemble |
|
|
|
|
# Return ensemble |
|
|
|
|
print(f'Ensemble created with {weights}\n') |
|
|
|
|
for k in 'names', 'nc', 'yaml': |
|
|
|
|
setattr(model, k, getattr(model[0], k)) |
|
|
|
|