You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.0 KiB
54 lines
2.0 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import paddle |
|
import paddleslim |
|
|
|
FILTER_DIM = paddleslim.dygraph.prune.filter_pruner.FILTER_DIM |
|
|
|
|
|
def _pruner_eval_fn(model, eval_dataset, batch_size): |
|
metric = model.evaluate(eval_dataset, batch_size=batch_size) |
|
return metric[list(metric.keys())[0]] |
|
|
|
|
|
def _pruner_template_input(sample, model_type): |
|
if model_type == 'detector': |
|
template_input = [{ |
|
"image": paddle.ones( |
|
shape=[1, 3] + list(sample["image"].shape[:2]), |
|
dtype='float32'), |
|
"im_shape": paddle.full( |
|
[1, 2], 640, dtype='float32'), |
|
"scale_factor": paddle.ones( |
|
shape=[1, 2], dtype='float32') |
|
}] |
|
else: |
|
template_input = [1] + list(sample[0].shape) |
|
|
|
return template_input |
|
|
|
|
|
def sensitive_prune(pruner, pruned_flops, skip_vars=[], align=None): |
|
# Skip depthwise convolutions |
|
for layer in pruner.model.sublayers(): |
|
if isinstance(layer, paddle.nn.layer.conv.Conv2D) and layer._groups > 1: |
|
for param in layer.parameters(include_sublayers=False): |
|
skip_vars.append(param.name) |
|
pruner.restore() |
|
ratios, pruned_flops = pruner.get_ratios_by_sensitivity( |
|
pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars) |
|
pruner.plan = pruner.prune_vars(ratios, FILTER_DIM) |
|
pruner.plan._pruned_flops = pruned_flops |
|
return pruner.plan, ratios
|
|
|