You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
89 lines
3.0 KiB
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
|
|
from paddlers.models.ppdet.core.workspace import load_config, merge_config, create |
|
from paddlers.models.ppdet.utils.checkpoint import load_weight, load_pretrain_weight |
|
from paddlers.models.ppdet.utils.logger import setup_logger |
|
from paddlers.models.ppdet.core.workspace import register, serializable |
|
|
|
from paddle.utils import try_import |
|
|
|
# Module-level logger created via ppdet's setup_logger, named after this module.
logger = setup_logger(__name__)
|
|
|
|
|
@register
@serializable
class OFA(object):
    """Wrap a model into a PaddleSlim Once-For-All (OFA) supernet.

    Registered with the ppdet workspace via ``@register`` / ``@serializable``.
    ``ofa_config`` must be a mapping providing the keys read in ``__call__``:

    * ``task`` -- passed to the OFA wrapper's ``set_task``.
    * ``expand_ratio`` -- passed to ``supernet(expand_ratio=...)``.
    * ``skip_neck`` / ``skip_head`` -- when truthy, every sublayer whose name
      contains ``'neck.'`` / ``'head.'`` is added to ``skip_layers``.
    * ``RunConfig`` -- mapping of keyword arguments for paddleslim's
      ``RunConfig``; may contain an optional ``skip_layers`` list of names.
    """

    def __init__(self, ofa_config):
        super(OFA, self).__init__()
        self.ofa_config = ofa_config

    def __call__(self, model, param_state_dict):
        """Convert ``model`` into an OFA supernet and load weights into it.

        Args:
            model: the paddle layer to convert. It must expose
                ``named_sublayers()``. The dummy input below uses
                detector-style keys (``image``/``im_shape``/``scale_factor``),
                so a ppdet detector is presumably expected.
            param_state_dict: state dict loaded into the converted model via
                ``paddleslim.nas.ofa.utils.set_state_dict``.

        Returns:
            The paddleslim ``OFA`` wrapper around the converted model.

        Raises:
            KeyError: if ``task``, ``expand_ratio``, ``skip_neck``,
                ``skip_head`` or ``RunConfig`` is missing from ``ofa_config``.

        ``self.ofa_config`` is not modified, so repeated calls behave the same.
        """
        # Ensure paddleslim is importable before the imports below.
        try_import('paddleslim')
        # Alias paddleslim's OFA so it does not shadow this class's own name.
        from paddleslim.nas.ofa import OFA as SlimOFA, RunConfig, utils
        from paddleslim.nas.ofa.convert_super import Convert, supernet

        task = self.ofa_config['task']
        expand_ratio = self.ofa_config['expand_ratio']
        skip_neck = self.ofa_config['skip_neck']
        skip_head = self.ofa_config['skip_head']

        # Copy the RunConfig mapping and its skip list. Appending to the
        # configured list and writing it back would mutate self.ofa_config and
        # make skip_layers grow with duplicates on every call.
        run_config = dict(self.ofa_config['RunConfig'])
        skip_layers = list(run_config.get('skip_layers') or [])

        # supernet config
        sp_config = supernet(expand_ratio=expand_ratio)
        # convert to supernet
        model = Convert(sp_config).convert(model)

        skip_names = []
        if skip_neck:
            skip_names.append('neck.')
        if skip_head:
            skip_names.append('head.')

        # Substring match on the sublayer name. Each matching sublayer is added
        # once, even if it matches several patterns or is already configured.
        if skip_names:
            already_skipped = set(skip_layers)
            for name, _ in model.named_sublayers():
                if name in already_skipped:
                    continue
                if any(pattern in name for pattern in skip_names):
                    skip_layers.append(name)
                    already_skipped.add(name)

        run_config['skip_layers'] = skip_layers
        run_config = RunConfig(**run_config)

        # build ofa model
        ofa_model = SlimOFA(model, run_config=run_config)
        ofa_model.set_epoch(0)
        ofa_model.set_task(task)

        # Fixed dummy batch of one 640x640 image, passed to
        # _clear_search_space (presumably to run the model once).
        input_spec = [{
            "image": paddle.ones(shape=[1, 3, 640, 640], dtype='float32'),
            "im_shape": paddle.full([1, 2], 640, dtype='float32'),
            "scale_factor": paddle.ones(shape=[1, 2], dtype='float32'),
        }]

        # NOTE(review): _clear_search_space, _build_ss, _sample_config and
        # _ofa_layers are private paddleslim OFA internals; re-verify them when
        # upgrading paddleslim.
        ofa_model._clear_search_space(input_spec=input_spec)
        ofa_model._build_ss = True
        # Called for its side effects; the return value is not needed here.
        ofa_model._sample_config('expand_ratio', phase=None)
        # tokenize the search space
        ofa_model.tokenize()

        # check token map, search cands and search space
        logger.info('Token map is {}'.format(ofa_model.token_map))
        logger.info('Search candidates is {}'.format(ofa_model.search_cands))
        logger.info('The length of search_space is {}, search_space is {}'.
                    format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))

        # set model state dict into ofa model
        utils.set_state_dict(ofa_model.model, param_state_dict)
        return ofa_model
|
|
|