You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
4.8 KiB
147 lines
4.8 KiB
#!/usr/bin/env python |
|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py |
|
|
|
import argparse |
|
import os |
|
import os.path as osp |
|
import sys |
|
|
|
import paddle |
|
import numpy as np |
|
import paddlers |
|
from paddle.hapi.dynamic_flops import (count_parameters, register_hooks, |
|
count_io_info) |
|
from paddle.hapi.static_flops import Table |
|
|
|
# Put the repository root (the parent of this script's directory) on the
# module search path so that the local `bootstrap` module can be found. The
# path must be extended before the import below, so the order matters here.
_dir = osp.split(osp.abspath(__file__))[0]
sys.path.append(osp.abspath(osp.join(_dir, osp.pardir)))

# NOTE(review): `bootstrap` is not referenced elsewhere in the visible code;
# presumably it is imported for its side effects — confirm.
import bootstrap
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the model-analysis script.

    Args:
        argv (list[str]|None, optional): Argument strings to parse. If None
            (the default), `sys.argv[1:]` is used, as `argparse` does.

    Returns:
        argparse.Namespace: Namespace with
            - `model_dir` (str): directory handed to
              `paddlers.tasks.load_model`.
            - `input_shape` (list[int]): shape of each randomly generated
              input tensor, `[1, 3, 256, 256]` by default.

    Raises:
        SystemExit: If `--model_dir` is missing or an option is malformed
            (raised by `argparse` after printing a usage message).
    """
    parser = argparse.ArgumentParser()
    # The script cannot do anything without a model. Requiring the option
    # turns a missing value into a clear usage error instead of a confusing
    # failure deep inside the model loader.
    parser.add_argument(
        "--model_dir", required=True, type=str, help="Path of saved model.")
    parser.add_argument(
        "--input_shape",
        nargs='+',
        type=int,
        default=[1, 3, 256, 256],
        help="Shape of each input tensor.")
    return parser.parse_args(argv)
|
|
|
|
|
# Buffers attached to each hooked leaf layer while the model is profiled:
# `total_ops`/`total_params` are registered by `add_hooks` in `analyze`; the
# shape buffers are expected to be added by paddle's `count_io_info` hook.
_STAT_BUFFER_NAMES = ('total_ops', 'total_params', 'input_shape',
                      'output_shape')


def _is_leaf(layer):
    # A leaf layer has no child layers; only leaves are hooked and counted.
    return not list(layer.children())


def _leaf_layers(model):
    # Leaf layers of `model` in traversal order. `include_self=True` keeps the
    # root when it is itself a leaf, matching `model.apply`, which also
    # visits the root.
    return [m for m in model.sublayers(include_self=True) if _is_leaf(m)]


def _has_stats(layer):
    # True when every profiling buffer is present. The shape buffers come from
    # a forward post-hook, so hooked layers that never ran lack them.
    return all(name in layer._buffers for name in _STAT_BUFFER_NAMES)


def analyze(model, inputs, custom_ops=None, print_detail=False):
    """Count the FLOPs and parameters of `model` with one forward pass.

    Forward post-hooks are attached to every leaf layer (a layer without
    children). The FLOPs hook for a layer type comes from `custom_ops` if the
    type is listed there, else from paddle's built-in `register_hooks` table;
    otherwise the layer is treated as zero FLOPs. Parameter counts and
    input/output shapes are recorded through temporary per-layer buffers.

    The model's training/eval mode is restored, and all hooks and temporary
    buffers are removed, before returning, even if the forward pass raises.

    Args:
        model: Network to analyze. Anything providing `training`, `eval`,
            `train`, `apply` and `sublayers` works; presumably a
            `paddle.nn.Layer`.
        inputs (list|tuple): Positional inputs, passed as `model(*inputs)`.
        custom_ops (dict|None, optional): Maps a layer type to a FLOPs
            counting forward post-hook. It takes precedence over paddle's
            built-in counters. Default: None.
        print_detail (bool, optional): If True, print a per-layer table of
            shapes, parameters and FLOPs. Default: False.

    Returns:
        int: Sum of the per-layer `total_ops` counters, i.e. the reported
            total FLOPs.
    """
    handler_collection = []
    types_collection = set()
    # Layers that received temporary buffers; they are cleaned up at the end.
    hooked_layers = []
    if custom_ops is None:
        custom_ops = {}

    def add_hooks(m):
        if not _is_leaf(m):
            return
        hooked_layers.append(m)
        m.register_buffer('total_ops', paddle.zeros([1], dtype='int64'))
        m.register_buffer('total_params', paddle.zeros([1], dtype='int64'))
        m_type = type(m)

        # `types_collection` ensures each message is printed once per type.
        flops_fn = None
        if m_type in custom_ops:
            flops_fn = custom_ops[m_type]
            if m_type not in types_collection:
                print("Customized function has been applied to {}".format(
                    m_type))
        elif m_type in register_hooks:
            flops_fn = register_hooks[m_type]
            if m_type not in types_collection:
                print("{}'s FLOPs metric has been counted".format(m_type))
        else:
            if m_type not in types_collection:
                print(
                    "Cannot find suitable counting function for {}. Treat it as zero FLOPs."
                    .format(m_type))

        if flops_fn is not None:
            handler_collection.append(m.register_forward_post_hook(flops_fn))
        handler_collection.append(
            m.register_forward_post_hook(count_parameters))
        handler_collection.append(m.register_forward_post_hook(count_io_info))
        types_collection.add(m_type)

    training = model.training
    model.eval()
    try:
        model.apply(add_hooks)
        with paddle.framework.no_grad():
            model(*inputs)

        total_ops = 0
        total_params = 0
        table = Table(
            ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"])
        for m in _leaf_layers(model):
            if not _has_stats(m):
                continue
            total_ops += m.total_ops
            total_params += m.total_params
            table.add_row([
                m.full_name(), list(m.input_shape.numpy()),
                list(m.output_shape.numpy()),
                round(float(m.total_params / 1e6), 3),
                round(float(m.total_ops / 1e9), 3)
            ])
    finally:
        # Always undo the instrumentation, so that a failing forward pass does
        # not leave hooks, eval mode or temporary buffers on the network.
        if training:
            model.train()
        for handler in handler_collection:
            handler.remove()
        # Cleaning every hooked layer (not only those that ran) prevents stale
        # profiling buffers from remaining attached to the network.
        for m in hooked_layers:
            for name in _STAT_BUFFER_NAMES:
                m._buffers.pop(name, None)

    if print_detail:
        table.print_table()
    print('Total FLOPs: {}G Total Params: {}M'.format(
        round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3)))
    return int(total_ops)
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()

    # Enforce the use of CPU
    paddle.set_device('cpu')

    model = paddlers.tasks.load_model(args.model_dir)
    net = model.net

    # Construct bi-temporal inputs: two independent random tensors of the
    # same shape, passed to the network as separate positional arguments.
    inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)]

    analyze(net, inputs)
|
|
|