You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.0 KiB
60 lines
2.0 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import torch |
|
from timm import create_model |
|
from timm.loss import SoftTargetCrossEntropy |
|
from timm.models.layers import drop |
|
|
|
|
|
from models.convnext import ConvNeXt |
|
from models.resnet import ResNet |
|
from models.custom import YourConvNet |
|
_import_resnets_for_timm_registration = (ResNet,) |
|
|
|
|
|
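# log more: make the repr of the loss / drop-path modules below list their plain (non-tensor, non-module) attributes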
|
def _ex_repr(self): |
|
return ', '.join( |
|
f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) |
|
for k, v in vars(self).items() |
|
if not k.startswith('_') and k != 'training' |
|
and not isinstance(v, (torch.nn.Module, torch.Tensor)) |
|
) |
|
# Patch the listed classes so that printing them (or a model / criterion that
# contains them) also shows the attribute summary produced by `_ex_repr`.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if hasattr(clz, 'extra_repr'):
        # torch.nn.Module.__repr__ embeds the result of extra_repr(), so
        # overriding that hook is enough.
        clz.extra_repr = _ex_repr
    else:
        # No extra_repr hook available: replace __repr__ wholesale with
        # "ClassName(<summary>)".
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
|
|
|
|
|
# Default keyword arguments per supported model name. Each entry is expanded
# into the model-factory call made by `build_sparse_encoder`.
pretrain_default_model_kwargs = {
    'your_convnet': dict(),
    'resnet50': dict(drop_path_rate=0.05),
    'resnet101': dict(drop_path_rate=0.08),
    'resnet152': dict(drop_path_rate=0.10),
    'resnet200': dict(drop_path_rate=0.15),
    'convnext_small': dict(sparse=True, drop_path_rate=0.2),
    'convnext_base': dict(sparse=True, drop_path_rate=0.3),
    'convnext_large': dict(sparse=True, drop_path_rate=0.4),
}

# Settings shared by every entry: no pretrained weights, `num_classes=0` and
# an empty `global_pool` (presumably so the network is built as a headless
# feature extractor -- confirm against the model definitions).
for kw in pretrain_default_model_kwargs.values():
    kw.update(pretrained=False, num_classes=0, global_pool='')
|
|
|
|
|
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Create the CNN registered under ``name`` and wrap it in a ``SparseEncoder``.

    Args:
        name: Key into ``pretrain_default_model_kwargs``; also the model name
            handed to ``create_model``.
        input_size: Forwarded unchanged to ``SparseEncoder``.
        sbn: Forwarded unchanged to ``SparseEncoder`` (presumably toggles
            synchronized batch norm -- confirm in ``encoder.SparseEncoder``).
        drop_path_rate: When non-zero, overrides the default
            ``drop_path_rate`` of the chosen model for this call only.
        verbose: Forwarded unchanged to ``SparseEncoder``.

    Returns:
        The ``SparseEncoder`` wrapping the freshly created CNN.

    Raises:
        KeyError: If ``name`` has no entry in ``pretrain_default_model_kwargs``.
    """
    # Imported at call time, as in the original (presumably to avoid a
    # circular import with `encoder` -- confirm).
    from encoder import SparseEncoder

    # Work on a copy: writing the override into the shared defaults dict would
    # make it stick for every later call that uses the same model name.
    kwargs = dict(pretrain_default_model_kwargs[name])
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)

    return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
|
|
|
|