You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

742 lines
30 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import copy
import time
import numpy as np
import typing
from PIL import Image, ImageOps
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
from paddlers.models.ppdet.optimizer import ModelEMA
from paddlers.models.ppdet.core.workspace import create
from paddlers.models.ppdet.modeling.architectures.meta_arch import BaseArch
from paddlers.models.ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from paddlers.models.ppdet.utils.visualizer import visualize_results, save_result
from paddlers.models.ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from paddlers.models.ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from paddlers.models.ppdet.data.source.sniper_coco import SniperCOCODataSet
from paddlers.models.ppdet.data.source.category import get_categories
from paddlers.models.ppdet.utils import stats
from paddlers.models.ppdet.utils import profiler
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec
from paddlers.models.ppdet.utils.logger import setup_logger
# Module-level logger for the detection engine.
logger = setup_logger('ppdet.engine')

# Public API of this module.
__all__ = ['Trainer']

# Multi-object-tracking architectures that need dedicated MOT datasets/readers.
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']
class Trainer(object):
def __init__(self, cfg, mode='train'):
    """Build dataset, loader, model, optimizer, callbacks and metrics from `cfg`.

    Args:
        cfg: ppdet config object (dict-like, also supports attribute access).
        mode (str): one of 'train', 'eval' or 'test' (case-insensitive).
    """
    self.cfg = cfg
    assert mode.lower() in ['train', 'eval', 'test'], \
        "mode should be 'train', 'eval' or 'test'"
    self.mode = mode.lower()
    self.optimizer = None
    self.is_loaded_weights = False

    # build data loader
    # MOT architectures use dedicated EvalMOTDataset/TestMOTDataset configs.
    if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
        self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
    else:
        self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

    if cfg.architecture == 'DeepSORT' and self.mode == 'train':
        logger.error('DeepSORT has no need of training on mot dataset.')
        sys.exit(1)

    if self.mode == 'train':
        self.loader = create('{}Reader'.format(self.mode.capitalize()))(
            self.dataset, cfg.worker_num)

    if cfg.architecture == 'JDE' and self.mode == 'train':
        cfg['JDEEmbeddingHead'][
            'num_identities'] = self.dataset.num_identities_dict[0]
        # JDE only support single class MOT now.

    if cfg.architecture == 'FairMOT' and self.mode == 'train':
        cfg['FairMOTEmbeddingHead'][
            'num_identities_dict'] = self.dataset.num_identities_dict
        # FairMOT support single class and multi-class MOT now.

    # build model
    if 'model' not in self.cfg:
        self.model = create(cfg.architecture)
    else:
        # a pre-built model was injected via the config; skip weight loading
        self.model = self.cfg.model
        self.is_loaded_weights = True

    # normalize params for deploy
    self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

    self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
    if self.use_ema:
        ema_decay = self.cfg.get('ema_decay', 0.9998)
        cycle_epoch = self.cfg.get('cycle_epoch', -1)
        self.ema = ModelEMA(
            self.model,
            decay=ema_decay,
            use_thres_step=True,
            cycle_epoch=cycle_epoch)

    # EvalDataset build with BatchSampler to evaluate in single device
    # TODO: multi-device evaluate
    if self.mode == 'eval':
        self._eval_batch_sampler = paddle.io.BatchSampler(
            self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
        reader_name = '{}Reader'.format(self.mode.capitalize())
        # If metric is VOC, need to be set collate_batch=False.
        if cfg.metric == 'VOC':
            cfg[reader_name]['collate_batch'] = False
        self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                          self._eval_batch_sampler)
    # TestDataset build after user set images, skip loader creation here

    # build optimizer in train mode
    if self.mode == 'train':
        steps_per_epoch = len(self.loader)
        self.lr = create('LearningRate')(steps_per_epoch)
        self.optimizer = create('OptimizerBuilder')(self.lr, self.model)
        if self.cfg.get('unstructured_prune'):
            self.pruner = create('UnstructuredPruner')(self.model,
                                                       steps_per_epoch)

    self._nranks = dist.get_world_size()
    self._local_rank = dist.get_rank()

    self.status = {}

    self.start_epoch = 0
    self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

    # initial default callbacks
    self._init_callbacks()

    # initial default metrics
    self._init_metrics()
    self._reset_metrics()
def _init_callbacks(self):
    """Install the default callback set for the current run mode."""
    if self.mode == 'train':
        default_cbs = [LogPrinter(self), Checkpointer(self)]
        if self.cfg.get('use_vdl', False):
            default_cbs.append(VisualDLWriter(self))
        if self.cfg.get('save_proposals', False):
            default_cbs.append(SniperProposalsGenerator(self))
        self._callbacks = default_cbs
        self._compose_callback = ComposeCallback(self._callbacks)
    elif self.mode == 'eval':
        default_cbs = [LogPrinter(self)]
        if self.cfg.metric == 'WiderFace':
            default_cbs.append(WiferFaceEval(self))
        self._callbacks = default_cbs
        self._compose_callback = ComposeCallback(self._callbacks)
    elif self.mode == 'test' and self.cfg.get('use_vdl', False):
        self._callbacks = [VisualDLWriter(self)]
        self._compose_callback = ComposeCallback(self._callbacks)
    else:
        # no callbacks registered for this configuration
        self._callbacks = []
        self._compose_callback = None
def _init_metrics(self, validate=False):
    """Build `self._metrics` according to `cfg.metric` and the current mode.

    Args:
        validate (bool): True when metrics are being built for validation
            during training; in that case annotations come from EvalDataset
            rather than the (training) `self.dataset`.
    """
    # No metrics in pure test mode, or in train mode without validation.
    if self.mode == 'test' or (self.mode == 'train' and not validate):
        self._metrics = []
        return
    classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
    if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
        # TODO: bias should be unified
        bias = self.cfg['bias'] if 'bias' in self.cfg else 0
        output_eval = self.cfg['output_eval'] \
            if 'output_eval' in self.cfg else None
        save_prediction_only = self.cfg.get('save_prediction_only', False)

        # pass clsid2catid info to metric instance to avoid multiple loading
        # annotation file
        clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
            if self.mode == 'eval' else None

        # when do validation in train, annotation file should be get from
        # EvalReader instead of self.dataset(which is TrainReader)
        anno_file = self.dataset.get_anno()
        dataset = self.dataset
        if self.mode == 'train' and validate:
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            dataset = eval_dataset

        IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
        if self.cfg.metric == "COCO":
            self._metrics = [
                COCOMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    IouType=IouType,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == "SNIPERCOCO":  # sniper
            self._metrics = [
                SNIPERCOCOMetric(
                    anno_file=anno_file,
                    dataset=dataset,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    IouType=IouType,
                    save_prediction_only=save_prediction_only)
            ]
    elif self.cfg.metric == 'RBOX':
        # TODO: bias should be unified
        bias = self.cfg['bias'] if 'bias' in self.cfg else 0
        output_eval = self.cfg['output_eval'] \
            if 'output_eval' in self.cfg else None
        save_prediction_only = self.cfg.get('save_prediction_only', False)

        # pass clsid2catid info to metric instance to avoid multiple loading
        # annotation file
        clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
            if self.mode == 'eval' else None

        # when do validation in train, annotation file should be get from
        # EvalReader instead of self.dataset(which is TrainReader)
        anno_file = self.dataset.get_anno()
        if self.mode == 'train' and validate:
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()

        self._metrics = [
            RBoxMetric(
                anno_file=anno_file,
                clsid2catid=clsid2catid,
                classwise=classwise,
                output_eval=output_eval,
                bias=bias,
                save_prediction_only=save_prediction_only)
        ]
    elif self.cfg.metric == 'VOC':
        self._metrics = [
            VOCMetric(
                label_list=self.dataset.get_label_list(),
                class_num=self.cfg.num_classes,
                map_type=self.cfg.map_type,
                classwise=classwise)
        ]
    elif self.cfg.metric == 'WiderFace':
        multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
        self._metrics = [
            WiderFaceMetric(
                image_dir=os.path.join(self.dataset.dataset_dir,
                                       self.dataset.image_dir),
                anno_file=self.dataset.get_anno(),
                multi_scale=multi_scale)
        ]
    elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
        eval_dataset = self.cfg['EvalDataset']
        eval_dataset.check_or_download_dataset()
        anno_file = eval_dataset.get_anno()
        save_prediction_only = self.cfg.get('save_prediction_only', False)
        self._metrics = [
            KeyPointTopDownCOCOEval(
                anno_file,
                len(eval_dataset),
                self.cfg.num_joints,
                self.cfg.save_dir,
                save_prediction_only=save_prediction_only)
        ]
    elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
        eval_dataset = self.cfg['EvalDataset']
        eval_dataset.check_or_download_dataset()
        anno_file = eval_dataset.get_anno()
        save_prediction_only = self.cfg.get('save_prediction_only', False)
        self._metrics = [
            KeyPointTopDownMPIIEval(
                anno_file,
                len(eval_dataset),
                self.cfg.num_joints,
                self.cfg.save_dir,
                save_prediction_only=save_prediction_only)
        ]
    elif self.cfg.metric == 'MOTDet':
        self._metrics = [JDEDetMetric(), ]
    else:
        # unknown metric type: warn and fall back to no metrics
        logger.warning("Metric not support for metric type {}".format(
            self.cfg.metric))
        self._metrics = []
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
def register_callbacks(self, callbacks):
    """Append user-supplied callbacks and rebuild the composed dispatcher.

    Args:
        callbacks (Iterable[Callback]): callbacks to register; `None`
            entries are silently dropped.

    Raises:
        AssertionError: if any entry is not a `Callback` instance.
    """
    callbacks = [c for c in list(callbacks) if c is not None]
    for c in callbacks:
        # BUGFIX: message previously referred to metrics/Metric (copy-paste
        # from register_metrics) and misspelled "should".
        assert isinstance(c, Callback), \
            "callbacks should be instances of subclass of Callback"
    self._callbacks.extend(callbacks)
    self._compose_callback = ComposeCallback(self._callbacks)
def register_metrics(self, metrics):
    """Append user-supplied metrics to the evaluation metric list.

    Args:
        metrics (Iterable[Metric]): metrics to register; `None` entries
            are silently dropped.

    Raises:
        AssertionError: if any entry is not a `Metric` instance.
    """
    metrics = [m for m in list(metrics) if m is not None]
    for m in metrics:
        # BUGFIX: fixed "shoule" typo in the assertion message.
        assert isinstance(m, Metric), \
            "metrics should be instances of subclass of Metric"
    self._metrics.extend(metrics)
def load_weights(self, weights):
    """Load pretrained weights, unless the model came pre-loaded via cfg."""
    if not self.is_loaded_weights:
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))
def load_weights_sde(self, det_weights, reid_weights):
    """Load detector and ReID weights for SDE-style MOT models."""
    # ReID weights are loaded unconditionally; detector weights only when
    # a detector sub-model exists.
    if self.model.detector:
        load_weight(self.model.detector, det_weights)
    load_weight(self.model.reid, reid_weights)
def resume_weights(self, weights):
    """Resume model + optimizer state from a checkpoint; sets start_epoch."""
    # Distillation wraps the trainable network in `student_model`;
    # resume into that sub-model when present.
    if hasattr(self.model, 'student_model'):
        target = self.model.student_model
    else:
        target = self.model
    self.start_epoch = load_weight(target, weights, self.optimizer)
    logger.debug("Resume weights of epoch {}".format(self.start_epoch))
def train(self, validate=False):
    """Run the full training loop, optionally validating every snapshot epoch.

    Args:
        validate (bool): if True, run evaluation on EvalDataset every
            `snapshot_epoch` epochs (and on the final epoch), on rank 0 only.
    """
    assert self.mode == 'train', "Model not in 'train' mode"
    # Init_mark ensures validation metrics are initialized only once.
    Init_mark = False

    sync_bn = (
        getattr(self.cfg, 'norm_type', None) in [None, 'sync_bn'] and
        self.cfg.use_gpu and self._nranks > 1)
    if sync_bn:
        self.model = BaseArch.convert_sync_batchnorm(self.model)

    model = self.model
    if self.cfg.get('fleet', False):
        # fleet-based distributed training wraps both model and optimizer
        model = fleet.distributed_model(model)
        self.optimizer = fleet.distributed_optimizer(self.optimizer)
    elif self._nranks > 1:
        find_unused_parameters = self.cfg[
            'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
        model = paddle.DataParallel(
            self.model, find_unused_parameters=find_unused_parameters)

    # initial fp16
    if self.cfg.get('fp16', False):
        scaler = amp.GradScaler(
            enable=self.cfg.use_gpu, init_loss_scaling=1024)

    self.status.update({
        'epoch_id': self.start_epoch,
        'step_id': 0,
        'steps_per_epoch': len(self.loader)
    })

    self.status['batch_time'] = stats.SmoothedValue(
        self.cfg.log_iter, fmt='{avg:.4f}')
    self.status['data_time'] = stats.SmoothedValue(
        self.cfg.log_iter, fmt='{avg:.4f}')
    self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

    if self.cfg.get('print_flops', False):
        flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
            self.dataset, self.cfg.worker_num)
        self._flops(flops_loader)
    profiler_options = self.cfg.get('profiler_options', None)

    self._compose_callback.on_train_begin(self.status)

    for epoch_id in range(self.start_epoch, self.cfg.epoch):
        self.status['mode'] = 'train'
        self.status['epoch_id'] = epoch_id
        self._compose_callback.on_epoch_begin(self.status)
        self.loader.dataset.set_epoch(epoch_id)
        model.train()
        iter_tic = time.time()
        for step_id, data in enumerate(self.loader):
            self.status['data_time'].update(time.time() - iter_tic)
            self.status['step_id'] = step_id
            profiler.add_profiler_step(profiler_options)
            self._compose_callback.on_step_begin(self.status)
            data['epoch_id'] = epoch_id

            if self.cfg.get('fp16', False):
                with amp.auto_cast(enable=self.cfg.use_gpu):
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']

                # model backward
                scaled_loss = scaler.scale(loss)
                scaled_loss.backward()
                # in dygraph mode, optimizer.minimize is equal to optimizer.step
                scaler.minimize(self.optimizer, scaled_loss)
            else:
                # model forward
                outputs = model(data)
                loss = outputs['loss']
                # model backward
                loss.backward()
                self.optimizer.step()
            curr_lr = self.optimizer.get_lr()
            self.lr.step()
            if self.cfg.get('unstructured_prune'):
                self.pruner.step()
            self.optimizer.clear_grad()
            self.status['learning_rate'] = curr_lr

            # only rank 0 (or single card) accumulates training statistics
            if self._nranks < 2 or self._local_rank == 0:
                self.status['training_staus'].update(outputs)

            self.status['batch_time'].update(time.time() - iter_tic)
            self._compose_callback.on_step_end(self.status)
            if self.use_ema:
                self.ema.update(self.model)
            iter_tic = time.time()

        # apply ema weight on model
        if self.use_ema:
            # keep a copy of the raw weights so they can be restored below
            weight = copy.deepcopy(self.model.state_dict())
            self.model.set_dict(self.ema.apply())
        if self.cfg.get('unstructured_prune'):
            self.pruner.update_params()

        self._compose_callback.on_epoch_end(self.status)

        if validate and (self._nranks < 2 or self._local_rank == 0) \
                and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                         or epoch_id == self.end_epoch - 1):
            if not hasattr(self, '_eval_loader'):
                # build evaluation dataset and loader
                self._eval_dataset = self.cfg.EvalDataset
                self._eval_batch_sampler = \
                    paddle.io.BatchSampler(
                        self._eval_dataset,
                        batch_size=self.cfg.EvalReader['batch_size'])
                # If metric is VOC, need to be set collate_batch=False.
                if self.cfg.metric == 'VOC':
                    self.cfg['EvalReader']['collate_batch'] = False
                self._eval_loader = create('EvalReader')(
                    self._eval_dataset,
                    self.cfg.worker_num,
                    batch_sampler=self._eval_batch_sampler)
            # if validation in training is enabled, metrics should be re-init
            # Init_mark makes sure this code will only execute once
            if validate and Init_mark == False:
                Init_mark = True
                self._init_metrics(validate=validate)
                self._reset_metrics()
            with paddle.no_grad():
                self.status['save_best_model'] = True
                self._eval_with_loader(self._eval_loader)

        # restore origin weight on model
        if self.use_ema:
            self.model.set_dict(weight)

    self._compose_callback.on_train_end(self.status)
def _eval_with_loader(self, loader):
    """Run one full evaluation pass over `loader` and log metric results.

    Args:
        loader: an iterable data loader yielding eval batches (or a
            sequence of batches for multi-scale inputs).
    """
    sample_num = 0
    tic = time.time()
    self._compose_callback.on_epoch_begin(self.status)
    self.status['mode'] = 'eval'
    self.model.eval()
    if self.cfg.get('print_flops', False):
        flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
            self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
        self._flops(flops_loader)
    for step_id, data in enumerate(loader):
        self.status['step_id'] = step_id
        self._compose_callback.on_step_begin(self.status)
        # forward
        outs = self.model(data)

        # update metrics
        for metric in self._metrics:
            metric.update(data, outs)

        # multi-scale inputs: all inputs have same im_id
        if isinstance(data, typing.Sequence):
            sample_num += data[0]['im_id'].numpy().shape[0]
        else:
            sample_num += data['im_id'].numpy().shape[0]
        self._compose_callback.on_step_end(self.status)

    self.status['sample_num'] = sample_num
    self.status['cost_time'] = time.time() - tic

    # accumulate metric to log out
    for metric in self._metrics:
        metric.accumulate()
        metric.log()
    self._compose_callback.on_epoch_end(self.status)
    # reset metric states for metric may performed multiple times
    self._reset_metrics()
def evaluate(self):
    """Evaluate the model on `self.loader` with gradient tracking disabled."""
    with paddle.no_grad():
        self._eval_with_loader(self.loader)
def predict(self,
            images,
            draw_threshold=0.5,
            output_dir='output',
            save_txt=False):
    """Run inference on `images`, draw and save the visualized results.

    Args:
        images: image path(s) handed to the test dataset via `set_images`.
        draw_threshold (float): score threshold for drawing detections.
        output_dir (str): directory where annotated images are written.
        save_txt (bool): also dump per-image results to a sibling .txt file.
    """
    self.dataset.set_images(images)
    loader = create('TestReader')(self.dataset, 0)

    imid2path = self.dataset.get_imid2path()

    anno_file = self.dataset.get_anno()
    clsid2catid, catid2name = get_categories(
        self.cfg.metric, anno_file=anno_file)

    # Run Infer
    self.status['mode'] = 'test'
    self.model.eval()
    if self.cfg.get('print_flops', False):
        flops_loader = create('TestReader')(self.dataset, 0)
        self._flops(flops_loader)
    results = []
    for step_id, data in enumerate(loader):
        self.status['step_id'] = step_id
        # forward
        outs = self.model(data)

        # multi-scale inputs share meta tensors with the first input
        for key in ['im_shape', 'scale_factor', 'im_id']:
            if isinstance(data, typing.Sequence):
                outs[key] = data[0][key]
            else:
                outs[key] = data[key]
        for key, value in outs.items():
            if hasattr(value, 'numpy'):
                outs[key] = value.numpy()
        results.append(outs)
    # sniper
    if type(self.dataset) == SniperCOCODataSet:
        results = self.dataset.anno_cropper.aggregate_chips_detections(
            results)

    for outs in results:
        batch_res = get_infer_results(outs, clsid2catid)
        bbox_num = outs['bbox_num']

        start = 0
        for i, im_id in enumerate(outs['im_id']):
            image_path = imid2path[int(im_id)]
            image = Image.open(image_path).convert('RGB')
            image = ImageOps.exif_transpose(image)
            self.status['original_image'] = np.array(image.copy())

            end = start + bbox_num[i]
            bbox_res = batch_res['bbox'][start:end] \
                if 'bbox' in batch_res else None
            mask_res = batch_res['mask'][start:end] \
                if 'mask' in batch_res else None
            segm_res = batch_res['segm'][start:end] \
                if 'segm' in batch_res else None
            keypoint_res = batch_res['keypoint'][start:end] \
                if 'keypoint' in batch_res else None
            image = visualize_results(
                image, bbox_res, mask_res, segm_res, keypoint_res,
                int(im_id), catid2name, draw_threshold)
            self.status['result_image'] = np.array(image.copy())
            if self._compose_callback:
                self._compose_callback.on_step_end(self.status)
            # save image with detection
            save_name = self._get_save_image_name(output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(
                save_name))
            image.save(save_name, quality=95)
            if save_txt:
                save_path = os.path.splitext(save_name)[0] + '.txt'
                # BUGFIX: previously this dict was named `results`, which
                # shadowed the outer `results` list being iterated above.
                txt_results = {}
                txt_results["im_id"] = im_id
                if bbox_res:
                    txt_results["bbox_res"] = bbox_res
                if keypoint_res:
                    txt_results["keypoint_res"] = keypoint_res
                save_result(save_path, txt_results, catid2name,
                            draw_threshold)
            start = end
def _get_save_image_name(self, output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
    """Dump infer_cfg.yml and build the static model plus its input spec.

    Args:
        save_dir (str): directory to write `infer_cfg.yml` into.
        prune_input (bool): if True, convert to static graph and prune the
            input spec to what the program actually uses; if False, return
            `None` for the static model and the unpruned spec.

    Returns:
        tuple: (static_model or None, pruned_input_spec).
    """
    image_shape = None
    im_shape = [None, 2]
    scale_factor = [None, 2]
    if self.cfg.architecture in MOT_ARCH:
        test_reader_name = 'TestMOTReader'
    else:
        test_reader_name = 'TestReader'
    if 'inputs_def' in self.cfg[test_reader_name]:
        inputs_def = self.cfg[test_reader_name]['inputs_def']
        image_shape = inputs_def.get('image_shape', None)
    # set image_shape=[None, 3, -1, -1] as default
    if image_shape is None:
        image_shape = [None, 3, -1, -1]

    if len(image_shape) == 3:
        # prepend the (dynamic) batch dimension
        image_shape = [None] + image_shape
    else:
        # batch size is fixed by the configured image_shape
        im_shape = [image_shape[0], 2]
        scale_factor = [image_shape[0], 2]

    if hasattr(self.model, 'deploy'):
        self.model.deploy = True
    if hasattr(self.model, 'fuse_norm'):
        self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                          False)

    # Save infer cfg
    _dump_infer_config(self.cfg,
                       os.path.join(save_dir, 'infer_cfg.yml'),
                       image_shape, self.model)

    input_spec = [{
        "image": InputSpec(
            shape=image_shape, name='image'),
        "im_shape": InputSpec(
            shape=im_shape, name='im_shape'),
        "scale_factor": InputSpec(
            shape=scale_factor, name='scale_factor')
    }]
    if self.cfg.architecture == 'DeepSORT':
        input_spec[0].update({
            "crops": InputSpec(
                shape=[None, 3, 192, 64], name='crops')
        })

    if prune_input:
        static_model = paddle.jit.to_static(
            self.model, input_spec=input_spec)
        # NOTE: dy2st do not pruned program, but jit.save will prune program
        # input spec, prune input spec here and save with pruned input spec
        pruned_input_spec = _prune_input_spec(
            input_spec, static_model.forward.main_program,
            static_model.forward.outputs)
    else:
        static_model = None
        pruned_input_spec = input_spec

    # TODO: Hard code, delete it when support prune input_spec.
    if self.cfg.architecture == 'PicoDet':
        pruned_input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image')
        }]

    return static_model, pruned_input_spec
def export(self, output_dir='output_inference'):
    """Export the model for static-graph inference under `output_dir`.

    Args:
        output_dir (str): root directory; the model is saved in a
            subdirectory named after the config file.
    """
    self.model.eval()
    model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
    save_dir = os.path.join(output_dir, model_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
        save_dir)

    # dy2st and save model
    if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
        paddle.jit.save(
            static_model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
    else:
        # QAT-slimmed models must be saved through the slim helper
        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
    logger.info("Export model and saved in {}".format(save_dir))
def post_quant(self, output_dir='output_inference'):
    """Run post-training quantization calibration and save the model.

    Args:
        output_dir (str): root directory; the quantized model is saved in
            a subdirectory named after the config file.
    """
    model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
    save_dir = os.path.join(output_dir, model_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # feed a few batches through the model to collect activation statistics
    for idx, data in enumerate(self.loader):
        self.model(data)
        if idx == int(self.cfg.get('quant_batch_num', 10)):
            break

    # TODO: support prune input_spec
    _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
        save_dir, prune_input=False)

    self.cfg.slim.save_quantized_model(
        self.model,
        os.path.join(save_dir, 'model'),
        input_spec=pruned_input_spec)
    logger.info("Export Post-Quant model and saved in {}".format(save_dir))
def _flops(self, loader):
    """Log the model's FLOPs measured on one batch from `loader`.

    Requires paddleslim; logs a warning and returns if it is missing.
    """
    self.model.eval()
    try:
        import paddleslim
    except Exception:
        logger.warning(
            'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
        )
        return

    from paddleslim.analysis import dygraph_flops as flops
    # grab a single batch to derive the input spec
    sample_batch = None
    for batch in loader:
        sample_batch = batch
        break
    input_spec = [{
        "image": sample_batch['image'][0].unsqueeze(0),
        "im_shape": sample_batch['im_shape'][0].unsqueeze(0),
        "scale_factor": sample_batch['scale_factor'][0].unsqueeze(0)
    }]
    # keep the result in a separate name so the `flops` callable isn't shadowed
    flop_count = flops(self.model, input_spec) / (1000**3)
    logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
        flop_count, sample_batch['image'][0].unsqueeze(0).shape))