You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

54 lines
2.0 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddleslim
# Parameter axis that PaddleSlim's dygraph filter pruner prunes along.
# Re-exported here so the helpers in this module can pass it as the
# ``dims`` / axis argument to the pruner (see ``sensitive_prune``).
FILTER_DIM = paddleslim.dygraph.prune.filter_pruner.FILTER_DIM
def _pruner_eval_fn(model, eval_dataset, batch_size):
metric = model.evaluate(eval_dataset, batch_size=batch_size)
return metric[list(metric.keys())[0]]
def _pruner_template_input(sample, model_type):
if model_type == 'detector':
template_input = [{
"image": paddle.ones(
shape=[1, 3] + list(sample["image"].shape[:2]),
dtype='float32'),
"im_shape": paddle.full(
[1, 2], 640, dtype='float32'),
"scale_factor": paddle.ones(
shape=[1, 2], dtype='float32')
}]
else:
template_input = [1] + list(sample[0].shape)
return template_input
def sensitive_prune(pruner, pruned_flops, skip_vars=None, align=None):
    """Prune ``pruner.model`` with per-parameter ratios from sensitivity analysis.

    Parameters of grouped convolutions are always excluded from pruning.
    Then ``pruner.restore()`` is called, ratios are requested from
    ``pruner.get_ratios_by_sensitivity`` along ``FILTER_DIM``, and those ratios
    are applied with ``pruner.prune_vars``.

    Args:
        pruner: PaddleSlim-style pruner exposing ``model``, ``restore()``,
            ``get_ratios_by_sensitivity(...)`` and ``prune_vars(...)``.
        pruned_flops: Target FLOPs reduction forwarded to
            ``get_ratios_by_sensitivity`` (presumably a fraction in [0, 1] --
            confirm against PaddleSlim docs).
        skip_vars: Optional iterable of parameter names that must not be
            pruned. It is copied and never mutated. ``None`` means none.
        align: Forwarded unchanged to ``get_ratios_by_sensitivity``.

    Returns:
        ``(plan, ratios)``. ``plan`` is the object returned by
        ``pruner.prune_vars``; it is also stored as ``pruner.plan`` and has
        ``_pruned_flops`` set to the FLOPs value returned by the sensitivity
        search. ``ratios`` is the ratio mapping returned by that search.
    """
    # Work on a private copy. The former ``skip_vars=[]`` default was appended
    # to below, so names accumulated across calls and callers' lists were
    # mutated in place.
    skip_vars = [] if skip_vars is None else list(skip_vars)

    # Exclude grouped convolutions (groups > 1, which includes depthwise convs).
    for layer in pruner.model.sublayers():
        if isinstance(layer, paddle.nn.layer.conv.Conv2D) and layer._groups > 1:
            skip_vars.extend(
                param.name
                for param in layer.parameters(include_sublayers=False))

    # Presumably resets the model to its unpruned parameters before planning
    # afresh -- confirm against PaddleSlim's Pruner.restore.
    pruner.restore()
    ratios, pruned_flops = pruner.get_ratios_by_sensitivity(
        pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars)
    pruner.plan = pruner.prune_vars(ratios, FILTER_DIM)
    # Attach the FLOPs value reported by the sensitivity search to the plan.
    pruner.plan._pruned_flops = pruned_flops
    return pruner.plan, ratios