You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
239 lines
7.4 KiB
239 lines
7.4 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
import sys |
|
import os |
|
import os.path as osp |
|
import time |
|
import math |
|
import imghdr |
|
import chardet |
|
import json |
|
|
|
import numpy as np |
|
|
|
from . import logging |
|
import platform |
|
import paddlers |
|
|
|
|
|
def seconds_to_hms(seconds):
    """Format a duration in seconds as an ``hours:minutes:seconds`` string.

    The fields are not zero-padded and the seconds field is truncated to an
    integer, e.g. ``3661`` becomes ``"1:1:1"``.

    Args:
        seconds (int|float): Duration in seconds.

    Returns:
        str: The duration rendered as ``"{h}:{m}:{s}"``.
    """
    hours = math.floor(seconds / 3600)
    # What is left once the whole hours have been taken out.
    remainder = seconds - hours * 3600
    minutes = math.floor(remainder / 60)
    whole_seconds = int(remainder - minutes * 60)
    return ":".join(str(part) for part in (hours, minutes, whole_seconds))
|
|
|
|
|
def get_encoding(path):
    """Guess the text encoding of a file from its raw bytes.

    Args:
        path (str): Path of the file to inspect.

    Returns:
        str|None: The value stored under the ``'encoding'`` key of the
            ``chardet.detect`` result, or ``None`` if that key is absent.
    """
    # Use a context manager so the handle is released even when reading
    # raises; the original only closed the file on the success path.
    with open(path, 'rb') as f:
        data = f.read()
    return chardet.detect(data).get('encoding')
|
|
|
|
|
def get_single_card_bs(batch_size):
    """Derive the per-card batch size from the requested batch size.

    The card count and the place are read from ``paddlers.env_info``.

    Args:
        batch_size (int): Requested batch size.

    Returns:
        int: ``batch_size // card_num`` when ``batch_size`` is divisible by
            the card count, otherwise ``batch_size`` itself when it equals 1.

    Raises:
        Exception: If ``batch_size`` is neither divisible by the card count
            nor equal to 1.
    """
    card_num = paddlers.env_info['num']
    place = paddlers.env_info['place']
    if batch_size % card_num == 0:
        return int(batch_size // card_num)
    if batch_size == 1:
        # Evaluation of detection task only supports single card with batch size 1
        return batch_size
    # Adjacent literals are concatenated at compile time. The previous
    # backslash continuation inside the string leaked the source indentation
    # into the middle of the message.
    raise Exception(
        "Please support correct batch_size, "
        "which can be divided by available cards({}) in {}".format(card_num,
                                                                  place))
|
|
|
|
|
def dict2str(dict_input):
    """Render a dict as a ``"key=value, key=value"`` string.

    Values that can be converted with ``float()`` are formatted with six
    decimal places; every other value is rendered with its default string
    formatting.

    Args:
        dict_input (dict): Mapping to render.

    Returns:
        str: The ``key=value`` pairs joined by ``", "``; an empty string for
            an empty dict.
    """
    parts = []
    for k, v in dict_input.items():
        try:
            v = '{:8.6f}'.format(float(v))
        except Exception:
            # Best effort: keep values that are not float-convertible as they
            # are. ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt / SystemExit propagate.
            pass
        parts.append('{}={}'.format(k, v))
    # Joining avoids the old ``strip(', ')``, which also removed commas and
    # spaces that belonged to the first key or the last value.
    return ', '.join(parts)
|
|
|
|
|
def norm_path(path):
    """Rewrite the separators in ``path`` to match the running platform.

    When ``platform.system()`` reports ``"Windows"`` every ``/`` becomes a
    backslash; otherwise every backslash becomes ``/``.

    Args:
        path (str): Path that may mix both separator styles.

    Returns:
        str: The path using a single separator style.
    """
    if platform.system() == "Windows":
        unwanted, wanted = "/", "\\"
    else:
        unwanted, wanted = "\\", "/"
    return wanted.join(path.split(unwanted))
|
|
|
|
|
def is_pic(img_path):
    """Tell whether ``img_path`` looks like a supported image file.

    A path is accepted when the text after its last ``.`` is one of the known
    suffixes, when its extension is ``.img``, or when the file starts with a
    TIFF byte-order mark. Only the last check opens the file.

    Args:
        img_path (str): Path of the candidate file.

    Returns:
        bool: True if the path is accepted by one of the checks above.

    Raises:
        OSError: If the file has to be opened for the TIFF check and cannot
            be read (for example, it does not exist).
    """
    valid_suffix = ('JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png',
                    'npy')
    suffix = img_path.split('.')[-1]
    if suffix in valid_suffix:
        return True
    # Decide on the extension before touching the disk, like the suffix
    # check above.
    _, ext = osp.splitext(img_path)
    if ext == '.img':
        return True
    # ``imghdr`` was removed from the standard library in Python 3.13, so the
    # TIFF test it performed (first two bytes are b'MM' or b'II') is done
    # here directly.
    with open(img_path, 'rb') as f:
        header = f.read(2)
    return header in (b'MM', b'II')
|
|
|
|
|
class MyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy scalars and arrays."""

    def default(self, obj):
        """Convert NumPy objects to built-in Python equivalents.

        Args:
            obj: Object the base encoder could not serialize on its own.

        Returns:
            int|float|list: ``int`` for NumPy integers, ``float`` for NumPy
                floating-point scalars, a nested list for arrays. Any other
                object is handed to ``json.JSONEncoder.default``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # Not a NumPy type handled here: defer to the base implementation.
        return super().default(obj)
|
|
|
|
|
class EarlyStop:
    """Callable that signals when training should stop early.

    Each call receives the latest score. A counter grows while the score
    stalls or drops and is cleared when the score improves; once the counter
    reaches ``patience`` the call returns True.

    Args:
        patience (int): Number of consecutive non-improving calls tolerated
            before stopping. Must be at least 1.
        thresh (float): A change smaller than this (in absolute value)
            relative to the previous score counts as no improvement.

    Raises:
        Exception: If ``patience`` is smaller than 1.
    """

    def __init__(self, patience, thresh):
        if patience < 1:
            raise Exception("Argument patience should be a positive integer.")
        self.patience = patience
        self.counter = 0
        # Score seen on the previous call; None until the first call.
        self.score = None
        # Best score seen so far.
        self.max = 0
        self.thresh = thresh

    def __call__(self, current_score):
        """Record ``current_score`` and report whether to stop.

        Args:
            current_score (float): Latest evaluation score.

        Returns:
            bool: True once ``patience`` consecutive calls brought no
                improvement, False otherwise.
        """
        if self.score is None:
            self.score = current_score
            # The first score is also the best so far. Leaving ``max`` at its
            # initial 0 made a worse second score look like a new best and
            # never recognised a best among all-negative scores.
            self.max = current_score
            return False

        if current_score > self.max:
            # New best: forget any accumulated stalls.
            self.counter = 0
            self.score = current_score
            self.max = current_score
            return False

        stalled = abs(self.score - current_score) < self.thresh
        if stalled or current_score < self.score:
            self.counter += 1
            self.score = current_score
            logging.debug("EarlyStopping: %i / %i" %
                          (self.counter, self.patience))
            if self.counter >= self.patience:
                logging.info("EarlyStopping: Stop training")
                return True
            return False

        # Clearly better than the previous score, though not a new best.
        self.counter = 0
        self.score = current_score
        return False
|
|
|
|
|
class DisablePrint(object):
    """Context manager that redirects ``sys.stdout`` to ``os.devnull``.

    On entry the current ``sys.stdout`` is remembered and replaced by a
    handle on ``os.devnull``; on exit that handle is closed and the
    remembered stream is put back.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so __exit__ closes exactly this handle.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Close the handle opened in __enter__ rather than whatever
            # ``sys.stdout`` points at now: the body may have rebound it, in
            # which case the wrong stream would be closed and the devnull
            # handle leaked.
            self._devnull.close()
        finally:
            sys.stdout = self._original_stdout
|
|
|
|
|
class Times(object):
    """Minimal stopwatch that stores or accumulates elapsed wall-clock time."""

    def __init__(self):
        # time: recorded duration; st: start timestamp; et: end timestamp.
        self.time = self.st = self.et = 0.

    def start(self):
        """Store the current ``time.time()`` as the start timestamp."""
        self.st = time.time()

    def end(self, iter_num=1, accumulative=True):
        """Store the end timestamp and update the recorded duration.

        Args:
            iter_num (int): Divisor applied to the elapsed time.
            accumulative (bool): If True, add the result to the recorded
                duration; otherwise overwrite it.
        """
        self.et = time.time()
        elapsed = (self.et - self.st) / iter_num
        self.time = self.time + elapsed if accumulative else elapsed

    def reset(self):
        """Set the recorded duration and both timestamps back to zero."""
        self.time = self.st = self.et = 0.

    def value(self):
        """Return the recorded duration rounded to four decimal places."""
        return round(self.time, 4)
|
|
|
|
|
class Timer(Times):
    """Collects preprocess / inference / postprocess timings and reports them.

    Each stage has its own ``Times`` stopwatch. ``img_num`` and ``repeats``
    start at 0; nothing in this class assigns them another value, so they are
    expected to be filled in by the caller before ``info`` or ``report``.
    """

    def __init__(self):
        super(Timer, self).__init__()
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.img_num = 0
        self.repeats = 0

    def _repeat_count(self):
        """Return ``repeats`` clamped to at least 1 so it is a safe divisor."""
        return max(1, self.repeats)

    def _stage_time(self, stage, average):
        """Return a stage's time, divided by the repeat count if ``average``."""
        if average:
            return round(stage.value() / self._repeat_count(), 4)
        return stage.value()

    def info(self, average=False):
        """Print a summary of the collected timings to stdout.

        Args:
            average (bool): If True, per-stage times are divided by the
                number of repeats before being printed.
        """
        total_time = self.preprocess_time_s.value(
        ) * self.img_num + self.inference_time_s.value(
        ) + self.postprocess_time_s.value() * self.img_num
        total_time = round(total_time, 4)
        print("------------------ Inference Time Info ----------------------")
        # NOTE(review): batch_size is printed as img_num here, whereas
        # report() computes it as img_num / repeats -- confirm which one is
        # intended.
        print("total_time(ms): {}, img_num: {}, batch_size: {}".format(
            total_time * 1000, self.img_num, self.img_num))
        preprocess_time = self._stage_time(self.preprocess_time_s, average)
        postprocess_time = self._stage_time(self.postprocess_time_s, average)
        inference_time = self._stage_time(self.inference_time_s, average)

        # Clamp the divisor: repeats is 0 right after reset(), which used to
        # raise ZeroDivisionError.
        average_latency = total_time / self._repeat_count()
        # Stage times are rounded to 4 decimals, so the total can be exactly
        # 0; report a QPS of 0 rather than dividing by zero.
        qps = 1 / average_latency if average_latency > 0 else 0
        # "{:.2f}" replaces the former "{:2f}", which set a minimum width of
        # 2 instead of two decimal places.
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))
        print("preprocess_time_per_im(ms): {:.2f}, "
              "inference_time_per_batch(ms): {:.2f}, "
              "postprocess_time_per_im(ms): {:.2f}".format(
                  preprocess_time * 1000, inference_time * 1000,
                  postprocess_time * 1000))

    def report(self, average=False):
        """Return the collected timings as a dict.

        Args:
            average (bool): If True, per-stage times are divided by the
                number of repeats.

        Returns:
            dict: Keys ``preprocess_time_s``, ``postprocess_time_s``,
                ``inference_time_s``, ``img_num``, ``total_time_s`` and
                ``batch_size``.
        """
        dic = {}
        dic['preprocess_time_s'] = self._stage_time(self.preprocess_time_s,
                                                    average)
        dic['postprocess_time_s'] = self._stage_time(self.postprocess_time_s,
                                                     average)
        dic['inference_time_s'] = self._stage_time(self.inference_time_s,
                                                   average)
        dic['img_num'] = self.img_num
        total_time = self.preprocess_time_s.value(
        ) + self.inference_time_s.value() + self.postprocess_time_s.value()
        dic['total_time_s'] = round(total_time, 4)
        # Same clamp as in info(): avoid ZeroDivisionError when repeats is 0.
        dic['batch_size'] = self.img_num / self._repeat_count()
        return dic

    def reset(self):
        """Reset the three stage stopwatches and zero the counters."""
        self.preprocess_time_s.reset()
        self.inference_time_s.reset()
        self.postprocess_time_s.reset()
        self.img_num = 0
        self.repeats = 0
|
|
|