You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

60 lines
2.0 KiB

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from timm import create_model
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import drop
from models.convnext import ConvNeXt
from models.resnet import ResNet
from models.custom import YourConvNet
# Keep a live reference to ResNet so the `models.resnet` import is not flagged/removed as
# unused. Judging by the name, importing that module is what registers the ResNet variants
# with timm so `create_model('resnet50', ...)` can find them (TODO confirm in models/resnet.py).
_import_resnets_for_timm_registration = (ResNet,)
# log more
def _ex_repr(self):
return ', '.join(
f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
for k, v in vars(self).items()
if not k.startswith('_') and k != 'training'
and not isinstance(v, (torch.nn.Module, torch.Tensor))
)
# Attach the compact repr to each listed class: through ``extra_repr`` (the nn.Module hook
# for text shown inside ``ClassName(...)``) when the class has one, otherwise by replacing
# ``__repr__`` outright with an equivalent ``ClassName(k=v, ...)`` rendering.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if not hasattr(clz, 'extra_repr'):
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
    else:
        clz.extra_repr = _ex_repr
# Per-backbone keyword arguments passed to ``create_model`` for sparse pre-training.
pretrain_default_model_kwargs = {
    'your_convnet': {},
    'resnet50': {'drop_path_rate': 0.05},
    'resnet101': {'drop_path_rate': 0.08},
    'resnet152': {'drop_path_rate': 0.10},
    'resnet200': {'drop_path_rate': 0.15},
    'convnext_small': {'sparse': True, 'drop_path_rate': 0.2},
    'convnext_base': {'sparse': True, 'drop_path_rate': 0.3},
    'convnext_large': {'sparse': True, 'drop_path_rate': 0.4},
}
# Options shared by every backbone: no pretrained weights, and (per timm convention)
# no classifier head and no global pooling, presumably so the encoder yields feature maps.
for kw in pretrain_default_model_kwargs.values():
    kw.update(pretrained=False, num_classes=0, global_pool='')
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Create a backbone with ``create_model`` and wrap it in a ``SparseEncoder``.

    Args:
        name: model name; must be a key of ``pretrain_default_model_kwargs`` and is also
            the name passed to ``create_model``.
        input_size: forwarded to ``SparseEncoder``.
        sbn: forwarded to ``SparseEncoder``.
        drop_path_rate: when non-zero, overrides the backbone's default
            ``drop_path_rate`` for this call only (0 keeps the default).
        verbose: forwarded to ``SparseEncoder``.

    Returns:
        The ``SparseEncoder`` wrapping the newly created backbone.

    Raises:
        KeyError: if ``name`` has no entry in ``pretrain_default_model_kwargs``.
    """
    # Imported lazily, as in the original (possibly to avoid an import cycle — unverified).
    from encoder import SparseEncoder

    try:
        defaults = pretrain_default_model_kwargs[name]
    except KeyError:
        raise KeyError(
            f'unknown model {name!r}; expected one of {sorted(pretrain_default_model_kwargs)}'
        ) from None

    # Copy so a per-call override never leaks into the shared module-level defaults;
    # mutating them in place made later calls silently inherit an earlier override.
    kwargs = dict(defaults)
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)
    return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)