|
|
|
@ -12,8 +12,8 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot |
|
|
|
|
GhostBottleneck, GhostConv, Segment) |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load |
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible, |
|
|
|
|
model_info, scale_img, time_sync) |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, |
|
|
|
|
intersect_dicts, make_divisible, model_info, scale_img, time_sync) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseModel(nn.Module): |
|
|
|
@ -100,6 +100,10 @@ class BaseModel(nn.Module): |
|
|
|
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv |
|
|
|
|
delattr(m, 'bn') # remove batchnorm |
|
|
|
|
m.forward = m.forward_fuse # update forward |
|
|
|
|
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'): |
|
|
|
|
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) |
|
|
|
|
delattr(m, 'bn') # remove batchnorm |
|
|
|
|
m.forward = m.forward_fuse # update forward |
|
|
|
|
self.info() |
|
|
|
|
|
|
|
|
|
return self |
|
|
|
|