You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

109 lines
3.7 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import yaml
import numpy as np
import paddle
from .. import tasks
from ..transforms import build_transforms
def load_rcnn_inference_model(model_dir):
    """
    Build a structured-name state dict from an exported static inference model.

    The inference program is loaded from the path prefix `<model_dir>/model`
    in static graph mode, and its variables are renamed using the
    'structured_name' entries stored in `<model_dir>/model.pdiparams.info`.

    Args:
        model_dir(str): Directory that holds the exported inference model.

    Returns:
        dict: Maps each structured name to the variable value converted to a
            numpy array. Variables that have no entry in the info file, or
            whose entry has no 'structured_name', are skipped.
    """
    path_prefix = osp.join(model_dir, "model")

    paddle.enable_static()
    try:
        exe = paddle.static.Executor(paddle.CPUPlace())
        prog, _, _ = paddle.static.load_inference_model(path_prefix, exe)
    finally:
        # Leave static graph mode even when loading fails; otherwise the
        # whole process would stay in static mode after the exception.
        paddle.disable_static()

    extra_var_info = paddle.load(osp.join(model_dir, "model.pdiparams.info"))

    net_state_dict = dict()
    for var_name, var in prog.state_dict().items():
        if var_name not in extra_var_info:
            continue
        structured_name = extra_var_info[var_name].get('structured_name', None)
        if structured_name is None:
            continue
        # Only the variables that are kept get converted to numpy arrays.
        net_state_dict[structured_name] = np.array(var)
    return net_state_dict
def load_model(model_dir, **params):
    """
    Restore a model that was saved to `model_dir`.

    The model is described by `model.yml` inside the directory. Its class is
    looked up in `tasks` and constructed from the recorded arguments.

    Args:
        model_dir(str): Directory where the model is saved.
        **params: Only `with_net` (bool, default True) is read. When True the
            weights in `model.pdparams` are loaded into `model.net`. When
            False the saved model status must be 'Infer'.

    Returns:
        The model loaded from the directory.

    Raises:
        FileNotFoundError: If `model.yml` is not found in `model_dir`.
        ValueError: If the recorded model class is not found in the task module.
    """
    if not osp.exists(model_dir):
        print("Directory '{}' does not exist!".format(model_dir))

    yml_path = osp.join(model_dir, "model.yml")
    if not osp.exists(yml_path):
        raise FileNotFoundError(
            "There is no file named model.yml in {}.".format(model_dir))

    # NOTE: yaml.Loader can construct arbitrary Python objects, so only model
    # directories from trusted sources should be loaded.
    with open(yml_path) as f:
        model_info = yaml.load(f.read(), Loader=yaml.Loader)

    status = model_info['status']
    with_net = params.get('with_net', True)
    if not with_net:
        assert status == 'Infer', \
            "Only exported models can be deployed for inference, but current model status is {}.".format(status)

    mod = getattr(tasks, model_info['_Attributes']['model_type'])
    model_name = model_info['Model']
    if not hasattr(mod, model_name):
        raise ValueError("There is no {} attribute in {}.".format(model_name,
                                                                  mod))

    # 'model_name' is dropped from the recorded constructor arguments, and
    # the requested `with_net` value overrides the recorded one.
    init_params = model_info['_init_params']
    init_params.pop('model_name', None)
    init_params['with_net'] = with_net

    with paddle.utils.unique_name.guard():
        if 'raw_params' not in model_info:
            print(
                "Cannot find raw_params. Default arguments will be used to construct the model."
            )
        # `_init_params` takes precedence over `raw_params` on key clashes.
        ctor_kwargs = model_info.pop('raw_params', {})
        ctor_kwargs.update(init_params)
        model = getattr(mod, model_name)(**ctor_kwargs)

        if with_net:
            weights_path = osp.join(model_dir, 'model.pdparams')
            model.net.set_state_dict(paddle.load(weights_path))

        if 'Transforms' in model_info:
            model.test_transforms = build_transforms(model_info['Transforms'])

        if '_Attributes' in model_info:
            # Restore only the attributes the constructed model already has.
            instance_attrs = vars(model)
            for attr_name, attr_value in model_info['_Attributes'].items():
                if attr_name in instance_attrs:
                    instance_attrs[attr_name] = attr_value

    print("Model[{}] loaded.".format(model_name))
    model.status = status
    return model