|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import os
|
|
|
|
import argparse
|
|
|
|
from ast import literal_eval
|
|
|
|
|
|
|
|
from paddlers.tasks import load_model
|
|
|
|
|
|
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the model export script.

    Every option is an optional string that defaults to ``None``:

    * ``--model_dir`` / ``-m``
    * ``--save_dir`` / ``-s``
    * ``--fixed_input_shape`` / ``-fs``

    Returns:
        argparse.ArgumentParser: Parser with the three options registered.
    """
    # (long flag, short flag, help text) for each supported option.
    option_specs = (
        ('--model_dir', '-m', 'model directory path'),
        ('--save_dir', '-s', 'path to save inference model'),
        ('--fixed_input_shape', '-fs',
         "export inference model with fixed input shape: [w,h] or [n,c,w,h]"),
    )

    cli = argparse.ArgumentParser()
    for long_flag, short_flag, description in option_specs:
        cli.add_argument(
            long_flag, short_flag, type=str, default=None, help=description)
    return cli
|
|
|
|
|
|
|
|
|
|
|
|
def parse_fixed_input_shape(shape_str):
    """Parse and validate the value given to ``--fixed_input_shape``.

    Args:
        shape_str (str|None): Text of a Python list literal holding either
            2 or 4 elements, or None when the option was not supplied.

    Returns:
        list|None: The parsed list, or None if `shape_str` is None.

    Raises:
        ValueError: If the text cannot be parsed as a literal, is not a
            list, does not have 2 or 4 elements, or if the last two
            elements (and the second element of a 4-element list) are not
            positive integers.
    """
    if shape_str is None:
        return None

    def is_positive_int(value):
        # bool is a subclass of int; reject it so that True is not read as 1.
        return (isinstance(value, int) and not isinstance(value, bool) and
                value > 0)

    # Try to interpret the string as a list. literal_eval() reports malformed
    # input as either ValueError or SyntaxError; unify them so that callers
    # get a single exception type naming the offending value.
    try:
        shape = literal_eval(shape_str)
    except (ValueError, SyntaxError) as err:
        raise ValueError(
            "fixed_input_shape could not be parsed as a list: {!r}".format(
                shape_str)) from err

    # Check validity
    if not isinstance(shape, list):
        raise ValueError("fixed_input_shape should be of None or list type.")
    if len(shape) not in (2, 4):
        raise ValueError(
            "fixed_input_shape contains an incorrect number of elements.")
    # Check the type before comparing, so that non-numeric elements raise
    # ValueError instead of TypeError and non-integers are rejected.
    if not all(is_positive_int(dim) for dim in shape[-2:]):
        raise ValueError("Input width and height must be positive integers.")
    # The first element of a 4-element shape is intentionally left unchecked.
    if len(shape) == 4 and not is_positive_int(shape[1]):
        raise ValueError(
            "The number of input channels must be a positive integer.")
    return shape


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Get input shape
    fixed_input_shape = parse_fixed_input_shape(args.fixed_input_shape)

    # Set environment variables (done before the model is loaded)
    os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
    os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'

    # Load model from directory
    model = load_model(args.model_dir)

    # Do dynamic-to-static cast
    # XXX: Invoke a protected (single underscore) method outside of subclasses.
    model._export_inference_model(args.save_dir, fixed_input_shape)
|