You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

147 lines
4.8 KiB

#!/usr/bin/env python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py
import argparse
import os
import os.path as osp
import sys
import paddle
import numpy as np
import paddlers
from paddle.hapi.dynamic_flops import (count_parameters, register_hooks,
count_io_info)
from paddle.hapi.static_flops import Table
_dir = osp.dirname(osp.abspath(__file__))
sys.path.append(osp.abspath(osp.join(_dir, '../')))
import bootstrap
def parse_args(argv=None):
    """Parse the command-line options of the model analysis script.

    Args:
        argv (list[str]|None, optional): Arguments to parse. When None (the
            default), argparse reads them from `sys.argv[1:]`.

    Returns:
        argparse.Namespace: Parsed options with attributes `model_dir` (str)
            and `input_shape` (list[int]).
    """
    parser = argparse.ArgumentParser()
    # The script loads the model from this path right away, so a missing
    # value is reported by argparse up front instead of failing later inside
    # the model loader.
    parser.add_argument(
        "--model_dir", required=True, type=str, help="Path of saved model.")
    parser.add_argument(
        "--input_shape",
        nargs='+',
        type=int,
        default=[1, 3, 256, 256],
        help="Shape of each input tensor.")
    return parser.parse_args(argv)
def analyze(model, inputs, custom_ops=None, print_detail=False):
    """Count the FLOPs and parameters of a dynamic-graph model.

    Forward post-hooks and counting buffers are attached to every leaf
    sub-layer. One forward pass is then run on `inputs` with gradients
    disabled, and the per-layer counters are summed. Hooks and temporary
    buffers are always removed, and the model's original training mode is
    restored, even if the forward pass raises.

    Args:
        model (paddle.nn.Layer): Model to analyze.
        inputs (list|tuple): Positional inputs; the model is called as
            `model(*inputs)`.
        custom_ops (dict|None, optional): Mapping from layer type to a
            FLOPs-counting forward post-hook. Entries here take precedence
            over the built-in `register_hooks`. Defaults to None.
        print_detail (bool, optional): Whether to print a per-layer table.
            Defaults to False.

    Returns:
        int: Total FLOPs summed over all executed leaf layers.
    """
    if custom_ops is None:
        custom_ops = {}
    # Buffers created by add_hooks (total_*) and by count_io_info (*_shape).
    counting_buffers = {
        'total_ops', 'total_params', 'input_shape', 'output_shape'
    }
    handler_collection = []
    types_collection = set()
    # Leaf layers that received hooks/buffers, in visiting order. Used for
    # both the statistics and the cleanup, so every instrumented layer
    # (including ones that never run in forward) is cleaned up.
    hooked_layers = []
    hooked_ids = set()

    def add_hooks(m):
        # Only leaf layers are instrumented; containers would double count.
        if len(list(m.children())) > 0:
            return
        # `apply` may reach a layer shared by several parents more than once;
        # hooking it twice would double its FLOPs.
        if id(m) in hooked_ids:
            return
        hooked_ids.add(id(m))
        hooked_layers.append(m)

        m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
        m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
        m_type = type(m)
        flops_fn = None
        if m_type in custom_ops:
            flops_fn = custom_ops[m_type]
            if m_type not in types_collection:
                print("Customized function has been applied to {}".format(
                    m_type))
        elif m_type in register_hooks:
            flops_fn = register_hooks[m_type]
            if m_type not in types_collection:
                print("{}'s FLOPs metric has been counted".format(m_type))
        else:
            if m_type not in types_collection:
                print(
                    "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
                    .format(m_type))
        if flops_fn is not None:
            handler_collection.append(m.register_forward_post_hook(flops_fn))
        handler_collection.append(
            m.register_forward_post_hook(count_parameters))
        handler_collection.append(m.register_forward_post_hook(count_io_info))
        types_collection.add(m_type)

    training = model.training
    model.eval()
    try:
        model.apply(add_hooks)
        with paddle.framework.no_grad():
            model(*inputs)

        total_ops = 0
        total_params = 0
        table = Table([
            "Layer Name", "Input Shape", "Output Shape", "Params(M)",
            "FLOPs(G)"
        ])
        for m in hooked_layers:
            # A layer that did not run in forward has no I/O shape buffers.
            if not counting_buffers.issubset(m._buffers.keys()):
                continue
            total_ops += m.total_ops
            total_params += m.total_params
            table.add_row([
                m.full_name(), list(m.input_shape.numpy()),
                list(m.output_shape.numpy()),
                round(float(m.total_params / 1e6), 3),
                round(float(m.total_ops / 1e9), 3)
            ])
    finally:
        # Undo all instrumentation regardless of how the forward pass ended.
        for handler in handler_collection:
            handler.remove()
        for m in hooked_layers:
            for name in counting_buffers:
                m._buffers.pop(name, None)
        if training:
            model.train()

    if print_detail:
        table.print_table()
    print('Total FLOPs: {}G Total Params: {}M'.format(
        round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3)))
    return int(total_ops)
if __name__ == '__main__':
    args = parse_args()
    # Enforce the use of CPU
    paddle.set_device('cpu')
    model = paddlers.tasks.load_model(args.model_dir)
    net = model.net
    # Construct bi-temporal inputs: two random tensors of the same shape,
    # passed to the network as net(t1, t2).
    inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]
    analyze(net, inputs)