fix conflicts

own
michaelowenliu 3 years ago
commit 104b6150f8
  1. 61
      paddlers/tasks/base.py
  2. 34
      paddlers/tasks/load_model.py
  3. 4
      tutorials/train/semantic_segmentation/deeplabv3p.py
  4. 60
      tutorials/train/semantic_segmentation/run_with_clean_log.py
  5. 4
      tutorials/train/semantic_segmentation/unet.py

@ -14,13 +14,14 @@
import os
import os.path as osp
from functools import partial
import time
import copy
import math
import yaml
import json
from functools import partial, wraps
from inspect import signature
import yaml
import paddle
from paddle.io import DataLoader, DistributedBatchSampler
from paddleslim import QAT
@ -28,17 +29,40 @@ from paddleslim.analysis import flops
from paddleslim import L1NormFilterPruner, FPGMFilterPruner
import paddlers
import paddlers.utils.logging as logging
from paddlers.transforms import arrange_transforms
from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
get_pretrain_weights, load_pretrain_weights,
load_checkpoint, SmoothedValue, TrainingStats,
_get_shared_memory_size_in_M, EarlyStop)
import paddlers.utils.logging as logging
from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
from .utils.infer_nets import InferNet, InferCDNet
class BaseModel:
class ModelMeta(type):
def __new__(cls, name, bases, attrs):
def _deco(init_func):
@wraps(init_func)
def _wrapper(self, *args, **kwargs):
if hasattr(self, '_raw_params'):
ret = init_func(self, *args, **kwargs)
else:
sig = signature(init_func)
bnd_args = sig.bind(self, *args, **kwargs)
raw_params = bnd_args.arguments
raw_params.pop('self')
self._raw_params = raw_params
ret = init_func(self, *args, **kwargs)
return ret
return _wrapper
old_init_func = attrs['__init__']
attrs['__init__'] = _deco(old_init_func)
return type.__new__(cls, name, bases, attrs)
class BaseModel(metaclass=ModelMeta):
def __init__(self, model_type):
self.model_type = model_type
self.in_channels = None
@ -128,7 +152,11 @@ class BaseModel:
model_name=self.model_name,
checkpoint=resume_checkpoint)
def get_model_info(self):
def get_model_info(self, get_raw_params=False, inplace=True):
if inplace:
init_params = self.init_params
else:
init_params = copy.deepcopy(self.init_params)
info = dict()
info['version'] = paddlers.__version__
info['Model'] = self.__class__.__name__
@ -138,16 +166,19 @@ class BaseModel:
('fixed_input_shape', self.fixed_input_shape),
('best_accuracy', self.best_accuracy),
('best_model_epoch', self.best_model_epoch)])
if 'self' in self.init_params:
del self.init_params['self']
if '__class__' in self.init_params:
del self.init_params['__class__']
if 'model_name' in self.init_params:
del self.init_params['model_name']
if 'params' in self.init_params:
del self.init_params['params']
if 'self' in init_params:
del init_params['self']
if '__class__' in init_params:
del init_params['__class__']
if 'model_name' in init_params:
del init_params['model_name']
if 'params' in init_params:
del init_params['params']
info['_init_params'] = init_params
info['_init_params'] = self.init_params
if get_raw_params:
info['raw_params'] = self._raw_params
try:
primary_metric_key = list(self.eval_metrics.keys())[0]
@ -191,7 +222,7 @@ class BaseModel:
if osp.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
model_info = self.get_model_info()
model_info = self.get_model_info(get_raw_params=True)
model_info['status'] = self.status
paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))

@ -50,27 +50,31 @@ def load_rcnn_inference_model(model_dir):
def load_model(model_dir, **params):
"""
Load saved model from a given directory.
Args:
model_dir(str): The directory where the model is saved.
Returns:
The model loaded from the directory.
"""
if not osp.exists(model_dir):
logging.error("model_dir '{}' does not exists!".format(model_dir))
logging.error("Directory '{}' does not exist!".format(model_dir))
if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's no model.yml in {}".format(model_dir))
raise Exception("There is no file named model.yml in {}.".format(
model_dir))
with open(osp.join(model_dir, "model.yml")) as f:
model_info = yaml.load(f.read(), Loader=yaml.Loader)
f.close()
status = model_info['status']
with_net = params.get('with_net', True)
if not with_net:
assert status == 'Infer', \
"Only exported inference models can be deployed, current model status is {}".format(status)
"Only exported models can be deployed for inference, but current model status is {}.".format(status)
if not hasattr(paddlers.tasks, model_info['Model']):
raise Exception("There's no attribute {} in paddlers.tasks".format(
raise Exception("There is no {} attribute in paddlers.tasks.".format(
model_info['Model']))
if 'model_name' in model_info['_init_params']:
del model_info['_init_params']['model_name']
@ -78,8 +82,9 @@ def load_model(model_dir, **params):
model_info['_init_params'].update({'with_net': with_net})
with paddle.utils.unique_name.guard():
model = getattr(paddlers.tasks, model_info['Model'])(
**model_info['_init_params'])
params = model_info.pop('raw_params', {})
params.update(model_info['_init_params'])
model = getattr(paddlers.tasks, model_info['Model'])(**params)
if with_net:
if status == 'Pruned' or osp.exists(
osp.join(model_dir, "prune.yml")):
@ -110,18 +115,19 @@ def load_model(model_dir, **params):
if status == 'Infer':
if osp.exists(osp.join(model_dir, "quant.yml")):
logging.error(
"Exported quantized model can not be loaded, only deployment is supported.",
"Exported quantized model can not be loaded, because quant.yml is not found.",
exit=True)
model.net = model._build_inference_net()
if model_info['Model'] in ['FasterRCNN', 'MaskRCNN']:
net_state_dict = load_rcnn_inference_model(model_dir)
else:
net_state_dict = paddle.load(osp.join(model_dir, 'model'))
if model.model_type in ['classifier', 'segmenter'
] and 'rc' in version:
# When exporting a classifier and segmenter,
# InferNet is defined to append softmax and argmax operators to the model,
# so parameter name starts with 'net.'
if model.model_type in [
'classifier', 'segmenter', 'changedetector'
]:
# When exporting a classifier, segmenter, or changedetector,
# InferNet (or InferCDNet) is defined to append softmax and argmax operators to the model,
# so the parameter names all start with 'net.'
new_net_state_dict = {}
for k, v in net_state_dict.items():
new_net_state_dict['net.' + k] = v
@ -139,6 +145,8 @@ def load_model(model_dir, **params):
for k, v in model_info['_Attributes'].items():
if k in model.__dict__:
model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(model_info['Model']))
model.status = status
return model

@ -79,10 +79,10 @@ model.train(
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=50,
log_interval_steps=4,
save_dir=EXP_DIR,
# 初始学习率大小
learning_rate=0.01,
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能

@ -0,0 +1,60 @@
#!/usr/bin/env python
import sys
import subprocess
from io import RawIOBase
class StreamFilter(RawIOBase):
def __init__(self, conds, stream):
super().__init__()
self.conds = conds
self.stream = stream
def readinto(self, _):
pass
def write(self, msg):
if all(cond(msg) for cond in self.conds):
self.stream.write(msg)
else:
pass
class CleanLog(object):
def __init__(self, filter_, stream_name):
self.filter = filter_
self.stream_name = stream_name
self.old_stream = getattr(sys, stream_name)
def __enter__(self):
setattr(sys, self.stream_name, self.filter)
def __exit__(self, exc_type, exc_value, traceback):
setattr(sys, self.stream_name, self.old_stream)
if __name__ == '__main__':
if len(sys.argv) < 2:
raise TypeError("请指定需要运行的脚本!")
tar_file = sys.argv[1]
gdal_filter = StreamFilter([
lambda msg: "Sum of Photometric type-related color channels and ExtraSamples doesn't match SamplesPerPixel." not in msg
], sys.stdout)
with CleanLog(gdal_filter, 'stdout'):
proc = subprocess.Popen(
["python", tar_file],
stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
text=True)
while True:
try:
out_line = proc.stdout.readline()
if out_line == '' and proc.poll() is not None:
break
if out_line:
print(out_line, end='')
except KeyboardInterrupt:
import signal
proc.send_signal(signal.SIGINT)

@ -77,10 +77,10 @@ model.train(
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=50,
log_interval_steps=4,
save_dir=EXP_DIR,
# 初始学习率大小
learning_rate=0.01,
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能

Loading…
Cancel
Save