# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import time
import copy
import math
import json
from functools import partial, wraps
from inspect import signature
import yaml
import paddle
from paddle.io import DataLoader, DistributedBatchSampler
import paddlers_slim as paddlers
import paddlers_slim.utils.logging as logging
from paddlers_slim.utils import (
seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats,
_get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step)
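# ModelMeta wraps each model class's __init__ so that the raw constructor
# arguments are captured into self._raw_params on the first call; they are
# later serialized by get_model_info() so a model can be rebuilt from model.yml.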
class ModelMeta(type):
def __new__(cls, name, bases, attrs):
def _deco(init_func):
@wraps(init_func)
def _wrapper(self, *args, **kwargs):
if hasattr(self, '_raw_params'):
ret = init_func(self, *args, **kwargs)
else:
sig = signature(init_func)
bnd_args = sig.bind(self, *args, **kwargs)
raw_params = bnd_args.arguments
raw_params.pop('self')
self._raw_params = raw_params
ret = init_func(self, *args, **kwargs)
return ret
return _wrapper
old_init_func = attrs['__init__']
attrs['__init__'] = _deco(old_init_func)
return type.__new__(cls, name, bases, attrs)
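# BaseModel implements the training, evaluation, checkpointing, and export
# logic shared by all task-specific trainers; subclasses provide the network
# and implement the run()/train()/evaluate()/preprocess()/postprocess() hooks.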
class BaseModel(metaclass=ModelMeta):
def __init__(self, model_type):
self.model_type = model_type
self.in_channels = None
self.num_classes = None
self.labels = None
self.version = paddlers.__version__
self.net = None
self.optimizer = None
self.test_inputs = None
self.train_data_loader = None
self.eval_data_loader = None
self.eval_metrics = None
self.best_accuracy = -1.
self.best_model_epoch = -1
# Whether to use synchronized BN
self.sync_bn = False
self.status = 'Normal'
# The initial epoch when training is resumed
self.completed_epochs = 0
self.pruner = None
self.pruning_ratios = None
self.quantizer = None
self.quant_config = None
self.fixed_input_shape = None
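# Load pretrained weights (resolving and downloading them to save_dir when a
# known alias rather than an existing local path is given) and/or restore a
# training checkpoint, including the completed-epoch counter and best-accuracy
# bookkeeping read from the checkpoint's model.yml.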
def initialize_net(self,
pretrain_weights=None,
save_dir='.',
resume_checkpoint=None,
is_backbone_weights=False,
load_optim_state=True):
if pretrain_weights is not None and \
not osp.exists(pretrain_weights):
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if self.model_type == 'classifier':
pretrain_weights = get_pretrain_weights(
pretrain_weights, self.model_name, save_dir)
else:
backbone_name = getattr(self, 'backbone_name', None)
pretrain_weights = get_pretrain_weights(
pretrain_weights,
self.__class__.__name__,
save_dir,
backbone_name=backbone_name)
if pretrain_weights is not None:
if is_backbone_weights:
load_pretrain_weights(
self.net.backbone,
pretrain_weights,
model_name='backbone of ' + self.model_name)
else:
load_pretrain_weights(
self.net, pretrain_weights, model_name=self.model_name)
if resume_checkpoint is not None:
if not osp.exists(resume_checkpoint):
logging.error(
"The checkpoint path {} to resume training from does not exist."
.format(resume_checkpoint),
exit=True)
if not osp.exists(osp.join(resume_checkpoint, 'model.pdparams')):
logging.error(
"Model parameter state dictionary file 'model.pdparams' "
"was not found in given checkpoint path {}!".format(
resume_checkpoint),
exit=True)
if not osp.exists(osp.join(resume_checkpoint, 'model.pdopt')):
logging.error(
"Optimizer state dictionary file 'model.pdparams' "
"was not found in given checkpoint path {}!".format(
resume_checkpoint),
exit=True)
if not osp.exists(osp.join(resume_checkpoint, 'model.yml')):
logging.error(
"'model.yml' was not found in given checkpoint path {}!".
format(resume_checkpoint),
exit=True)
with open(osp.join(resume_checkpoint, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader)
self.completed_epochs = info['completed_epochs']
self.best_accuracy = info['_Attributes']['best_accuracy']
self.best_model_epoch = info['_Attributes']['best_model_epoch']
load_checkpoint(
self.net,
self.optimizer,
model_name=self.model_name,
checkpoint=resume_checkpoint,
load_optim_state=load_optim_state)
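# Collect the metadata written to model.yml: version, class name, attributes
# such as num_classes and labels, the constructor arguments, the primary
# evaluation metric (if available), and the test-time transforms.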
def get_model_info(self, get_raw_params=False, inplace=True):
if inplace:
init_params = self.init_params
else:
init_params = copy.deepcopy(self.init_params)
info = dict()
info['version'] = paddlers.__version__
info['Model'] = self.__class__.__name__
info['_Attributes'] = dict(
[('model_type', self.model_type), ('in_channels', self.in_channels),
('num_classes', self.num_classes), ('labels', self.labels),
('fixed_input_shape', self.fixed_input_shape),
('best_accuracy', self.best_accuracy),
('best_model_epoch', self.best_model_epoch)])
if 'self' in init_params:
del init_params['self']
if '__class__' in init_params:
del init_params['__class__']
if 'model_name' in init_params:
del init_params['model_name']
if 'params' in init_params:
del init_params['params']
info['_init_params'] = init_params
if get_raw_params:
info['raw_params'] = self._raw_params
try:
primary_metric_key = list(self.eval_metrics.keys())[0]
primary_metric_value = float(self.eval_metrics[primary_metric_key])
info['_Attributes']['eval_metrics'] = {
primary_metric_key: primary_metric_value
}
except:
pass
if hasattr(self, 'test_transforms'):
if self.test_transforms is not None:
info['Transforms'] = list()
for op in self.test_transforms.transforms:
name = op.__class__.__name__
attr = op.__dict__
info['Transforms'].append({name: attr})
arrange = self.test_transforms.arrange
if arrange is not None:
info['Transforms'].append({
arrange.__class__.__name__: {
'mode': 'test'
}
})
info['completed_epochs'] = self.completed_epochs
return info
def get_pruning_info(self):
info = dict()
info['pruner'] = self.pruner.__class__.__name__
info['pruning_ratios'] = self.pruning_ratios
pruner_inputs = self.pruner.inputs
if self.model_type == 'detector':
pruner_inputs = {k: v.tolist() for k, v in pruner_inputs[0].items()}
info['pruner_inputs'] = pruner_inputs
return info
def get_quant_info(self):
info = dict()
info['quant_config'] = self.quant_config
return info
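# Persist the network parameters, optimizer state, model.yml metadata, optional
# evaluation details, and any pruning/quantization configs, then write a
# '.success' flag file to mark the save as complete.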
def save_model(self, save_dir):
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
model_info = self.get_model_info(get_raw_params=True)
model_info['status'] = self.status
paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
paddle.save(self.optimizer.state_dict(),
osp.join(save_dir, 'model.pdopt'))
with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(model_info, f)
# Save evaluation details
if hasattr(self, 'eval_details'):
with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
json.dump(self.eval_details, f)
if self.status == 'Pruned' and self.pruner is not None:
pruning_info = self.get_pruning_info()
with open(
osp.join(save_dir, 'prune.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(pruning_info, f)
if self.status == 'Quantized' and self.quantizer is not None:
quant_info = self.get_quant_info()
with open(
osp.join(save_dir, 'quant.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(quant_info, f)
# Success flag
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model saved in {}.".format(save_dir))
def build_data_loader(self, dataset, batch_size, mode='train'):
if dataset.num_samples < batch_size:
raise ValueError(
'The number of samples in the dataset ({}) must be no less than the batch size ({}).'
.format(dataset.num_samples, batch_size))
batch_size_each_card = get_single_card_bs(batch_size=batch_size)
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size_each_card,
shuffle=dataset.shuffle,
drop_last=mode == 'train')
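# Use shared memory for worker-to-main-process data transfer only when worker
# processes are enabled and /dev/shm reports at least 1024 MiB; otherwise fall
# back to the default path to avoid exhausting shared memory.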
if dataset.num_workers > 0:
shm_size = _get_shared_memory_size_in_M()
if shm_size is None or shm_size < 1024.:
use_shared_memory = False
else:
use_shared_memory = True
else:
use_shared_memory = False
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=dataset.batch_transforms,
num_workers=dataset.num_workers,
return_list=True,
use_shared_memory=use_shared_memory)
return loader
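# Main training loop: wraps the network for multi-card training when needed,
# logs step-level metrics (optionally to VisualDL), evaluates and saves the
# model every save_interval_epochs epochs, tracks the best checkpoint, and
# supports EMA weights and early stopping.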
def train_loop(self,
num_epochs,
train_dataset,
train_batch_size,
eval_dataset=None,
save_interval_epochs=1,
log_interval_steps=10,
save_dir='output',
ema=None,
early_stop=False,
early_stop_patience=5,
use_vdl=True):
self._check_transforms(train_dataset.transforms, 'train')
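# RCNN-style detectors whose dataset contains negative (background-only)
# samples, i.e. pos_num < len(file_list), are restricted to single-card training.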
if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
train_dataset.file_list):
nranks = 1
else:
nranks = paddle.distributed.get_world_size()
local_rank = paddle.distributed.get_rank()
if nranks > 1:
find_unused_parameters = getattr(self, 'find_unused_parameters',
False)
# Initialize parallel environment if not done.
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
):
paddle.distributed.init_parallel_env()
ddp_net = to_data_parallel(
self.net, find_unused_parameters=find_unused_parameters)
else:
ddp_net = to_data_parallel(
self.net, find_unused_parameters=find_unused_parameters)
if use_vdl:
from visualdl import LogWriter
vdl_logdir = osp.join(save_dir, 'vdl_log')
log_writer = LogWriter(vdl_logdir)
# task_id: optional identifier read from the paddlers package, used to prefix VisualDL metric tags.
task_id = getattr(paddlers, "task_id", "")
thresh = .0001
if early_stop:
earlystop = EarlyStop(early_stop_patience, thresh)
self.train_data_loader = self.build_data_loader(
train_dataset, batch_size=train_batch_size, mode='train')
if eval_dataset is not None:
self.test_transforms = copy.deepcopy(eval_dataset.transforms)
start_epoch = self.completed_epochs
train_step_time = SmoothedValue(log_interval_steps)
train_step_each_epoch = math.floor(train_dataset.num_samples /
train_batch_size)
train_total_step = train_step_each_epoch * (num_epochs - start_epoch)
if eval_dataset is not None:
eval_batch_size = train_batch_size
eval_epoch_time = 0
current_step = 0
for i in range(start_epoch, num_epochs):
self.net.train()
if callable(
getattr(self.train_data_loader.dataset, 'set_epoch', None)):
self.train_data_loader.dataset.set_epoch(i)
train_avg_metrics = TrainingStats()
step_time_tic = time.time()
for step, data in enumerate(self.train_data_loader()):
if nranks > 1:
outputs = self.train_step(step, data, ddp_net)
else:
outputs = self.train_step(step, data, self.net)
scheduler_step(self.optimizer, outputs['loss'])
train_avg_metrics.update(outputs)
lr = self.optimizer.get_lr()
outputs['lr'] = lr
if ema is not None:
ema.update(self.net)
step_time_toc = time.time()
train_step_time.update(step_time_toc - step_time_tic)
step_time_tic = step_time_toc
current_step += 1
# Log loss info every log_interval_steps
if current_step % log_interval_steps == 0 and local_rank == 0:
if use_vdl:
for k, v in outputs.items():
log_writer.add_scalar(
'{}-Metrics/Training(Step): {}'.format(
task_id, k), v, current_step)
# Estimate the remaining training time
avg_step_time = train_step_time.avg()
eta = avg_step_time * (train_total_step - current_step)
if eval_dataset is not None:
eval_num_epochs = math.ceil(
(num_epochs - i - 1) / save_interval_epochs)
if eval_epoch_time == 0:
eta += avg_step_time * math.ceil(
eval_dataset.num_samples / eval_batch_size)
else:
eta += eval_epoch_time * eval_num_epochs
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
.format(i + 1, num_epochs, step + 1,
train_step_each_epoch,
dict2str(outputs),
round(avg_step_time, 2), seconds_to_hms(eta)))
logging.info('[TRAIN] Epoch {} finished, {} .'
.format(i + 1, train_avg_metrics.log()))
self.completed_epochs += 1
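# When EMA is enabled, back up the raw weights and evaluate/save with the
# EMA-averaged ones; the raw weights are restored afterwards.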
if ema is not None:
weight = copy.deepcopy(self.net.state_dict())
self.net.set_state_dict(ema.apply())
eval_epoch_tic = time.time()
# Evaluate and save the model every save_interval_epochs epochs
if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
if eval_dataset is not None and eval_dataset.num_samples > 0:
eval_result = self.evaluate(
eval_dataset,
batch_size=eval_batch_size,
return_details=True)
# Save the optimal model
if local_rank == 0:
self.eval_metrics, self.eval_details = eval_result
if use_vdl:
for k, v in self.eval_metrics.items():
try:
log_writer.add_scalar(
'{}-Metrics/Eval(Epoch): {}'.format(
task_id, k), v, i + 1)
except TypeError:
pass
logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
i + 1, dict2str(self.eval_metrics)))
best_accuracy_key = list(self.eval_metrics.keys())[0]
current_accuracy = self.eval_metrics[best_accuracy_key]
if current_accuracy > self.best_accuracy:
self.best_accuracy = current_accuracy
self.best_model_epoch = i + 1
best_model_dir = osp.join(save_dir, "best_model")
self.save_model(save_dir=best_model_dir)
if self.best_model_epoch > 0:
logging.info(
'The best model evaluated on eval_dataset so far is from epoch_{}, {}={}'
.format(self.best_model_epoch,
best_accuracy_key, self.best_accuracy))
eval_epoch_time = time.time() - eval_epoch_tic
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if local_rank == 0:
self.save_model(save_dir=current_save_dir)
if eval_dataset is not None and early_stop:
if earlystop(current_accuracy):
break
if ema is not None:
self.net.set_state_dict(weight)
def analyze_sensitivity(self,
dataset,
batch_size=8,
criterion='l1_norm',
save_dir='output'):
pass
def prune(self, pruned_flops, save_dir=None):
pass
def _prepare_qat(self, quant_config):
pass
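# Describe a minimal Source -> Decode -> Predict -> Sink deployment pipeline
# for the exported model; the resulting dict is dumped to pipeline.yml on export.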
def _get_pipeline_info(self, save_dir):
pipeline_info = {}
pipeline_info["pipeline_name"] = self.model_type
nodes = [{
"src0": {
"type": "Source",
"next": "decode0"
}
}, {
"decode0": {
"type": "Decode",
"next": "predict0"
}
}, {
"predict0": {
"type": "Predict",
"init_params": {
"use_gpu": False,
"gpu_id": 0,
"use_trt": False,
"model_dir": save_dir,
},
"next": "sink0"
}
}, {
"sink0": {
"type": "Sink"
}
}]
pipeline_info["pipeline_nodes"] = nodes
pipeline_info["version"] = "1.0.0"
return pipeline_info
def _build_inference_net(self):
raise NotImplementedError
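# Export a deployment model: save a quantized model if the trainer has been
# quantized, otherwise convert the inference network to a static graph with
# paddle.jit.to_static, then write the accompanying YAML metadata files.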
def _export_inference_model(self, save_dir, image_shape=None):
self.test_inputs = self._get_test_inputs(image_shape)
infer_net = self._build_inference_net()
if self.status == 'Quantized':
self.quantizer.save_quantized_model(infer_net,
osp.join(save_dir, 'model'),
self.test_inputs)
quant_info = self.get_quant_info()
with open(
osp.join(save_dir, 'quant.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(quant_info, f)
else:
static_net = paddle.jit.to_static(
infer_net, input_spec=self.test_inputs)
paddle.jit.save(static_net, osp.join(save_dir, 'model'))
if self.status == 'Pruned':
pruning_info = self.get_pruning_info()
with open(
osp.join(save_dir, 'prune.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(pruning_info, f)
model_info = self.get_model_info()
model_info['status'] = 'Infer'
with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(model_info, f)
pipeline_info = self._get_pipeline_info(save_dir)
with open(
osp.join(save_dir, 'pipeline.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(pipeline_info, f)
# Success flag
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("The inference model for deployment is saved in {}.".
format(save_dir))
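# One optimization step: forward pass in train mode, backward pass on the
# returned loss, optimizer update, and gradient reset.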
def train_step(self, step, data, net):
outputs = self.run(net, data, mode='train')
loss = outputs['loss']
loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
return outputs
def _check_transforms(self, transforms, mode):
# NOTE: Validate transforms and transforms.arrange, giving user-friendly error messages.
if not isinstance(transforms, paddlers.transforms.Compose):
raise TypeError("`transforms` must be paddlers_slim.transforms.Compose.")
arrange_obj = transforms.arrange
if not isinstance(arrange_obj, paddlers.transforms.operators.Arrange):
raise TypeError("`transforms.arrange` must be an Arrange object.")
if arrange_obj.mode != mode:
raise ValueError(
f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
)
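# The hooks below are task-specific and must be implemented by subclasses.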
def run(self, net, inputs, mode):
raise NotImplementedError
def train(self, *args, **kwargs):
raise NotImplementedError
def evaluate(self, *args, **kwargs):
raise NotImplementedError
def preprocess(self, images, transforms, to_tensor):
raise NotImplementedError
def postprocess(self, *args, **kwargs):
raise NotImplementedError
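# A minimal sketch (illustrative only, not part of this module) of how a
# task-specific trainer is expected to plug into BaseModel; the network
# builder and batch layout shown here are hypothetical placeholders.
#
#     class ToyClassifier(BaseModel):
#         def __init__(self, num_classes=2):
#             super().__init__(model_type='classifier')
#             self.model_name = 'ToyClassifier'
#             self.num_classes = num_classes
#             self.net = build_toy_network(num_classes)  # hypothetical helper
#
#         def run(self, net, inputs, mode):
#             logits = net(inputs[0])
#             if mode == 'train':
#                 loss = paddle.nn.functional.cross_entropy(logits, inputs[1])
#                 return {'loss': loss}
#             return {'logits': logits}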