|
|
|
@@ -15,7 +15,6 @@
|
|
|
|
import math |
|
|
|
|
import os.path as osp |
|
|
|
|
import numpy as np |
|
|
|
|
import cv2 |
|
|
|
|
from collections import OrderedDict |
|
|
|
|
import paddle |
|
|
|
|
import paddle.nn.functional as F |
|
|
|
@@ -26,8 +25,10 @@ from paddlers.transforms import arrange_transforms
|
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint |
|
|
|
|
import paddlers.utils.logging as logging |
|
|
|
|
from .base import BaseModel |
|
|
|
|
from .utils import seg_metrics as metrics |
|
|
|
|
from paddlers.utils.checkpoint import seg_pretrain_weights_dict |
|
|
|
|
from paddlers.models.ppcls.metric import build_metrics |
|
|
|
|
from paddlers.models.ppcls.loss import build_loss |
|
|
|
|
from paddlers.models.ppcls.data.postprocess import build_postprocess |
|
|
|
|
from paddlers.utils.checkpoint import imagenet_weights |
|
|
|
|
from paddlers.transforms import Decode, Resize |
|
|
|
|
|
|
|
|
|
__all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"] |
|
|
|
@@ -49,8 +50,10 @@ class BaseClassifier(BaseModel):
|
|
|
|
self.model_name = model_name |
|
|
|
|
self.num_classes = num_classes |
|
|
|
|
self.use_mixed_loss = use_mixed_loss |
|
|
|
|
self.metrics = None |
|
|
|
|
self.losses = None |
|
|
|
|
self.labels = None |
|
|
|
|
self._postprocess = None |
|
|
|
|
if params.get('with_net', True): |
|
|
|
|
params.pop('with_net', None) |
|
|
|
|
self.net = self.build_net(**params) |
|
|
|
@@ -97,95 +100,35 @@ class BaseClassifier(BaseModel):
|
|
|
|
] |
|
|
|
|
return input_spec |
|
|
|
|
|
|
|
|
|
# FIXME: use ppcls instead of ppseg in infer, metrics, etc.
|
|
|
|
def run(self, net, inputs, mode): |
|
|
|
|
net_out = net(inputs[0]) |
|
|
|
|
logit = net_out[0] |
|
|
|
|
label = paddle.to_tensor(inputs[1], dtype="int64") |
|
|
|
|
outputs = OrderedDict() |
|
|
|
|
if mode == 'test': |
|
|
|
|
origin_shape = inputs[1] |
|
|
|
|
if self.status == 'Infer': |
|
|
|
|
label_map_list, score_map_list = self._postprocess( |
|
|
|
|
net_out, origin_shape, transforms=inputs[2]) |
|
|
|
|
else: |
|
|
|
|
logit_list = self._postprocess( |
|
|
|
|
logit, origin_shape, transforms=inputs[2]) |
|
|
|
|
label_map_list = [] |
|
|
|
|
score_map_list = [] |
|
|
|
|
for logit in logit_list: |
|
|
|
|
logit = paddle.transpose(logit, perm=[0, 2, 3, 1]) # NHWC |
|
|
|
|
label_map_list.append( |
|
|
|
|
paddle.argmax( |
|
|
|
|
logit, axis=-1, keepdim=False, dtype='int32') |
|
|
|
|
.squeeze().numpy()) |
|
|
|
|
score_map_list.append( |
|
|
|
|
F.softmax( |
|
|
|
|
logit, axis=-1).squeeze().numpy().astype( |
|
|
|
|
'float32')) |
|
|
|
|
outputs['label_map'] = label_map_list |
|
|
|
|
outputs['score_map'] = score_map_list |
|
|
|
|
result = self._postprocess(net_out) |
|
|
|
|
outputs = result[0] |
|
|
|
|
|
|
|
|
|
if mode == 'eval': |
|
|
|
|
if self.status == 'Infer': |
|
|
|
|
pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW |
|
|
|
|
else: |
|
|
|
|
pred = paddle.argmax( |
|
|
|
|
logit, axis=1, keepdim=True, dtype='int32') |
|
|
|
|
label = inputs[1] |
|
|
|
|
origin_shape = [label.shape[-2:]] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
pred, origin_shape, transforms=inputs[2])[0] # NCHW |
|
|
|
|
intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( |
|
|
|
|
pred, label, self.num_classes) |
|
|
|
|
outputs['intersect_area'] = intersect_area |
|
|
|
|
outputs['pred_area'] = pred_area |
|
|
|
|
outputs['label_area'] = label_area |
|
|
|
|
outputs['conf_mat'] = metrics.confusion_matrix(pred, label, |
|
|
|
|
self.num_classes) |
|
|
|
|
# print(self._postprocess(net_out)[0]) # for test |
|
|
|
|
label = paddle.unsqueeze(label, axis=-1) |
|
|
|
|
metric_dict = self.metrics(net_out, label) |
|
|
|
|
outputs['top1'] = metric_dict["top1"] |
|
|
|
|
outputs['top5'] = metric_dict["top5"] |
|
|
|
|
|
|
|
|
|
if mode == 'train': |
|
|
|
|
loss_list = metrics.loss_computation( |
|
|
|
|
logits_list=net_out, labels=inputs[1], losses=self.losses) |
|
|
|
|
loss = sum(loss_list) |
|
|
|
|
outputs['loss'] = loss |
|
|
|
|
loss_list = self.losses(net_out, label) |
|
|
|
|
outputs['loss'] = loss_list['loss'] |
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
# FIXME: use ppcls instead of ppseg in loss.
|
|
|
|
def default_metric(self): |
|
|
|
|
# TODO: other metrics |
|
|
|
|
default_config = [{"TopkAcc":{"topk": [1, 5]}}] |
|
|
|
|
return build_metrics(default_config) |
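
# A minimal sketch of how the built metric is consumed (mirrors run() above,
# which calls self.metrics(net_out, label) and reads "top1"/"top5"; the
# shapes are illustrative):
#
#   metrics = build_metrics([{"TopkAcc": {"topk": [1, 5]}}])
#   logits = paddle.rand([8, 1000])            # N x num_classes
#   labels = paddle.randint(0, 1000, [8, 1])   # N x 1 int64 labels
#   metric_dict = metrics(logits, labels)      # -> {"top1": ..., "top5": ...}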
|
|
|
|
|
|
|
|
|
def default_loss(self): |
|
|
|
|
if isinstance(self.use_mixed_loss, bool): |
|
|
|
|
if self.use_mixed_loss: |
|
|
|
|
losses = [ |
|
|
|
|
paddleseg.models.CrossEntropyLoss(), |
|
|
|
|
paddleseg.models.LovaszSoftmaxLoss() |
|
|
|
|
] |
|
|
|
|
coef = [.8, .2] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=coef), |
|
|
|
|
] |
|
|
|
|
else: |
|
|
|
|
loss_type = [paddleseg.models.CrossEntropyLoss()] |
|
|
|
|
else: |
|
|
|
|
losses, coef = list(zip(*self.use_mixed_loss)) |
|
|
|
|
if not set(losses).issubset( |
|
|
|
|
['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']): |
|
|
|
|
raise ValueError( |
|
|
|
|
"Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." |
|
|
|
|
) |
|
|
|
|
losses = [getattr(paddleseg.models, loss)() for loss in losses] |
|
|
|
|
loss_type = [ |
|
|
|
|
paddleseg.models.MixedLoss( |
|
|
|
|
losses=losses, coef=list(coef)) |
|
|
|
|
] |
|
|
|
|
if self.model_name == 'FastSCNN': |
|
|
|
|
loss_type *= 2 |
|
|
|
|
loss_coef = [1.0, 0.4] |
|
|
|
|
elif self.model_name == 'BiSeNetV2': |
|
|
|
|
loss_type *= 5 |
|
|
|
|
loss_coef = [1.0] * 5 |
|
|
|
|
else: |
|
|
|
|
loss_coef = [1.0] |
|
|
|
|
losses = {'types': loss_type, 'coef': loss_coef} |
|
|
|
|
return losses |
|
|
|
|
# TODO: mixed_loss |
|
|
|
|
default_config = [{"CELoss":{"weight": 1.0}}] |
|
|
|
|
return build_loss(default_config) |
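
# Once the mixed-loss TODO above is addressed, the config could combine
# several weighted entries. A sketch under the assumption that ppcls sums
# entries by their "weight" fields (the second entry is hypothetical):
#
#   default_config = [{"CELoss": {"weight": 0.8}},
#                     {"MixCELoss": {"weight": 0.2}}]
#   loss_fn = build_loss(default_config)
#   out = loss_fn(net_out, label)   # run() above reads out['loss'], the weighted total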
|
|
|
|
|
|
|
|
|
def default_optimizer(self, |
|
|
|
|
parameters, |
|
|
|
@@ -203,6 +146,14 @@ class BaseClassifier(BaseModel):
|
|
|
|
weight_decay=4e-5) |
|
|
|
|
return optimizer |
|
|
|
|
|
|
|
|
|
def default_postprocess(self, class_id_map_file): |
|
|
|
|
default_config = { |
|
|
|
|
"name": "Topk", |
|
|
|
|
"topk": 1, |
|
|
|
|
"class_id_map_file": class_id_map_file |
|
|
|
|
} |
|
|
|
|
return build_postprocess(default_config) |
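
# A sketch of what the built Topk postprocessor yields, grounded in run()
# and predict() in this class, which read 'class_ids', 'scores' and
# 'label_names' from each per-image dict (the values here are illustrative):
#
#   result = self._postprocess(net_out)
#   result[0]   # e.g. {'class_ids': [3], 'scores': [0.97], 'label_names': ['forest']}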
|
|
|
|
|
|
|
|
|
def train(self, |
|
|
|
|
num_epochs, |
|
|
|
|
train_dataset, |
|
|
|
@@ -212,7 +163,7 @@ class BaseClassifier(BaseModel):
|
|
|
|
save_interval_epochs=1, |
|
|
|
|
log_interval_steps=2, |
|
|
|
|
save_dir='output', |
|
|
|
|
pretrain_weights='CITYSCAPES', # FIXME: fix clas's pretrain weights |
|
|
|
|
pretrain_weights='IMAGENET', |
|
|
|
|
learning_rate=0.01, |
|
|
|
|
lr_decay_power=0.9, |
|
|
|
|
early_stop=False, |
|
|
|
@@ -255,6 +206,9 @@ class BaseClassifier(BaseModel):
|
|
|
|
self.labels = train_dataset.labels |
|
|
|
|
if self.losses is None: |
|
|
|
|
self.losses = self.default_loss() |
|
|
|
|
self.metrics = self.default_metric() |
|
|
|
|
self._postprocess = self.default_postprocess(train_dataset.label_list) |
|
|
|
|
# print(self._postprocess.class_id_map) |
|
|
|
|
|
|
|
|
|
if optimizer is None: |
|
|
|
|
num_steps_each_epoch = train_dataset.num_samples // train_batch_size |
|
|
|
@@ -265,7 +219,7 @@ class BaseClassifier(BaseModel):
|
|
|
|
self.optimizer = optimizer |
|
|
|
|
|
|
|
|
|
if pretrain_weights is not None and not osp.exists(pretrain_weights): |
|
|
|
|
if pretrain_weights not in seg_pretrain_weights_dict[ |
|
|
|
|
if pretrain_weights not in imagenet_weights[ |
|
|
|
|
self.model_name]: |
|
|
|
|
logging.warning( |
|
|
|
|
"Path of pretrain_weights('{}') does not exist!".format( |
|
|
|
@@ -273,9 +227,9 @@ class BaseClassifier(BaseModel):
|
|
|
|
logging.warning("Pretrain_weights is forcibly set to '{}'. " |
|
|
|
|
"If don't want to use pretrain weights, " |
|
|
|
|
"set pretrain_weights to be None.".format( |
|
|
|
|
seg_pretrain_weights_dict[self.model_name][ |
|
|
|
|
imagenet_weights[self.model_name][ |
|
|
|
|
0])) |
|
|
|
|
pretrain_weights = seg_pretrain_weights_dict[self.model_name][ |
|
|
|
|
pretrain_weights = imagenet_weights[self.model_name][ |
|
|
|
|
0] |
|
|
|
|
elif pretrain_weights is not None and osp.exists(pretrain_weights): |
|
|
|
|
if osp.splitext(pretrain_weights)[-1] != '.pdparams': |
|
|
|
@@ -370,12 +324,8 @@ class BaseClassifier(BaseModel):
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
collections.OrderedDict with key-value pairs: |
|
|
|
|
{"miou": `mean intersection over union`, |
|
|
|
|
"category_iou": `category-wise mean intersection over union`, |
|
|
|
|
"oacc": `overall accuracy`, |
|
|
|
|
"category_acc": `category-wise accuracy`, |
|
|
|
|
"kappa": ` kappa coefficient`, |
|
|
|
|
"category_F1-score": `F1 score`}. |
|
|
|
|
{"top1": `acc of top1`, |
|
|
|
|
"top5": `acc of top5`}. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
arrange_transforms( |
|
|
|
@@ -403,73 +353,26 @@ class BaseClassifier(BaseModel):
|
|
|
|
self.eval_data_loader = self.build_data_loader( |
|
|
|
|
eval_dataset, batch_size=batch_size, mode='eval') |
|
|
|
|
|
|
|
|
|
intersect_area_all = 0 |
|
|
|
|
pred_area_all = 0 |
|
|
|
|
label_area_all = 0 |
|
|
|
|
conf_mat_all = [] |
|
|
|
|
logging.info( |
|
|
|
|
"Start to evaluate(total_samples={}, total_steps={})...".format( |
|
|
|
|
eval_dataset.num_samples, |
|
|
|
|
math.ceil(eval_dataset.num_samples * 1.0 / batch_size))) |
|
|
|
|
|
|
|
|
|
top1s = [] |
|
|
|
|
top5s = [] |
|
|
|
|
with paddle.no_grad(): |
|
|
|
|
for step, data in enumerate(self.eval_data_loader): |
|
|
|
|
data.append(eval_dataset.transforms.transforms) |
|
|
|
|
outputs = self.run(self.net, data, 'eval') |
|
|
|
|
pred_area = outputs['pred_area'] |
|
|
|
|
label_area = outputs['label_area'] |
|
|
|
|
intersect_area = outputs['intersect_area'] |
|
|
|
|
conf_mat = outputs['conf_mat'] |
|
|
|
|
|
|
|
|
|
# Gather from all ranks |
|
|
|
|
if nranks > 1: |
|
|
|
|
intersect_area_list = [] |
|
|
|
|
pred_area_list = [] |
|
|
|
|
label_area_list = [] |
|
|
|
|
conf_mat_list = [] |
|
|
|
|
paddle.distributed.all_gather(intersect_area_list, |
|
|
|
|
intersect_area) |
|
|
|
|
paddle.distributed.all_gather(pred_area_list, pred_area) |
|
|
|
|
paddle.distributed.all_gather(label_area_list, label_area) |
|
|
|
|
paddle.distributed.all_gather(conf_mat_list, conf_mat) |
|
|
|
|
|
|
|
|
|
# Some image has been evaluated and should be eliminated in last iter |
|
|
|
|
if (step + 1) * nranks > len(eval_dataset): |
|
|
|
|
valid = len(eval_dataset) - step * nranks |
|
|
|
|
intersect_area_list = intersect_area_list[:valid] |
|
|
|
|
pred_area_list = pred_area_list[:valid] |
|
|
|
|
label_area_list = label_area_list[:valid] |
|
|
|
|
conf_mat_list = conf_mat_list[:valid] |
|
|
|
|
|
|
|
|
|
intersect_area_all += sum(intersect_area_list) |
|
|
|
|
pred_area_all += sum(pred_area_list) |
|
|
|
|
label_area_all += sum(label_area_list) |
|
|
|
|
conf_mat_all.extend(conf_mat_list) |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
intersect_area_all = intersect_area_all + intersect_area |
|
|
|
|
pred_area_all = pred_area_all + pred_area |
|
|
|
|
label_area_all = label_area_all + label_area |
|
|
|
|
conf_mat_all.append(conf_mat) |
|
|
|
|
# FIXME: fix metrics |
|
|
|
|
class_iou, miou = paddleseg.utils.metrics.mean_iou( |
|
|
|
|
intersect_area_all, pred_area_all, label_area_all) |
|
|
|
|
# TODO: confirm whether to use oacc or macc
|
|
|
|
class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all, |
|
|
|
|
pred_area_all) |
|
|
|
|
kappa = paddleseg.utils.metrics.kappa(intersect_area_all, |
|
|
|
|
pred_area_all, label_area_all) |
|
|
|
|
category_f1score = metrics.f1_score(intersect_area_all, pred_area_all, |
|
|
|
|
label_area_all) |
|
|
|
|
eval_metrics = OrderedDict( |
|
|
|
|
zip([ |
|
|
|
|
'miou', 'category_iou', 'oacc', 'category_acc', 'kappa', |
|
|
|
|
'category_F1-score' |
|
|
|
|
], [miou, class_iou, oacc, class_acc, kappa, category_f1score])) |
|
|
|
|
top1s.append(outputs["top1"]) |
|
|
|
|
top5s.append(outputs["top5"]) |
|
|
|
|
|
|
|
|
|
top1 = np.mean(top1s) |
|
|
|
|
top5 = np.mean(top5s) |
|
|
|
|
eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5])) |
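
# NOTE: np.mean over per-batch accuracies weights every batch equally; if
# the final batch is smaller than the rest, top1/top5 can deviate slightly
# from the exact sample-level accuracy.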
|
|
|
|
if return_details: |
|
|
|
|
conf_mat = sum(conf_mat_all) |
|
|
|
|
eval_details = {'confusion_matrix': conf_mat.tolist()} |
|
|
|
|
return eval_metrics, eval_details |
|
|
|
|
# TODO: add details |
|
|
|
|
return eval_metrics, None |
|
|
|
|
return eval_metrics |
|
|
|
|
|
|
|
|
|
def predict(self, img_file, transforms=None): |
|
|
|
@@ -485,10 +388,11 @@ class BaseClassifier(BaseModel):
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
If img_file is a string or np.array, the result is a dict with key-value pairs: |
|
|
|
|
{"label map": `label map`, "score_map": `score map`}. |
|
|
|
|
{"label map": `class_ids_map`, "scores_map": `label_names_map`}. |
|
|
|
|
If img_file is a list, the result is a list composed of dicts with the corresponding fields: |
|
|
|
|
label_map(np.ndarray): the predicted label map (HW) |
|
|
|
|
score_map(np.ndarray): the prediction score map (HWC) |
|
|
|
|
class_ids_map(np.ndarray): the predicted class IDs
|
|
|
|
scores_map(np.ndarray): the prediction scores
|
|
|
|
label_names_map(np.ndarray): the predicted label names
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
if transforms is None and not hasattr(self, 'test_transforms'): |
|
|
|
@@ -504,21 +408,23 @@ class BaseClassifier(BaseModel):
|
|
|
|
self.net.eval() |
|
|
|
|
data = (batch_im, batch_origin_shape, transforms.transforms) |
|
|
|
|
outputs = self.run(self.net, data, 'test') |
|
|
|
|
label_map_list = outputs['label_map'] |
|
|
|
|
score_map_list = outputs['score_map'] |
|
|
|
|
label_list = outputs['class_ids'] |
|
|
|
|
score_list = outputs['scores'] |
|
|
|
|
name_list = outputs['label_names'] |
|
|
|
|
if isinstance(img_file, list): |
|
|
|
|
prediction = [{ |
|
|
|
|
'label_map': l, |
|
|
|
|
'score_map': s |
|
|
|
|
} for l, s in zip(label_map_list, score_map_list)] |
|
|
|
|
'class_ids_map': l, |
|
|
|
|
'scores_map': s, |
|
|
|
|
'label_names_map': n, |
|
|
|
|
} for l, s, n in zip(label_list, score_list, name_list)] |
|
|
|
|
else: |
|
|
|
|
prediction = { |
|
|
|
|
'label_map': label_map_list[0], |
|
|
|
|
'score_map': score_map_list[0] |
|
|
|
|
'class_ids': label_list[0], |
|
|
|
|
'scores': score_list[0], |
|
|
|
|
'label_names': name_list[0] |
|
|
|
|
} |
|
|
|
|
return prediction |
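
# A minimal usage sketch (the file name and transforms are hypothetical):
#
#   model = ResNet50_vd(num_classes=10)
#   # ... train the model or load trained weights first ...
#   pred = model.predict('demo.jpg', transforms=test_transforms)
#   print(pred['class_ids'], pred['scores'], pred['label_names'])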
|
|
|
|
|
|
|
|
|
# FIXME: adapt to classification
|
|
|
|
def _preprocess(self, images, transforms, to_tensor=True): |
|
|
|
|
arrange_transforms( |
|
|
|
|
model_type=self.model_type, transforms=transforms, mode='test') |
|
|
|
@@ -587,84 +493,6 @@ class BaseClassifier(BaseModel):
|
|
|
|
batch_restore_list.append(restore_list) |
|
|
|
|
return batch_restore_list |
|
|
|
|
|
|
|
|
|
# FIXME: adapt to classification
|
|
|
|
def _postprocess(self, batch_pred, batch_origin_shape, transforms): |
|
|
|
|
batch_restore_list = BaseClassifier.get_transforms_shape_info( |
|
|
|
|
batch_origin_shape, transforms) |
|
|
|
|
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer': |
|
|
|
|
return self._infer_postprocess( |
|
|
|
|
batch_label_map=batch_pred[0], |
|
|
|
|
batch_score_map=batch_pred[1], |
|
|
|
|
batch_restore_list=batch_restore_list) |
|
|
|
|
results = [] |
|
|
|
|
if batch_pred.dtype == paddle.float32: |
|
|
|
|
mode = 'bilinear' |
|
|
|
|
else: |
|
|
|
|
mode = 'nearest' |
|
|
|
|
for pred, restore_list in zip(batch_pred, batch_restore_list): |
|
|
|
|
pred = paddle.unsqueeze(pred, axis=0) |
|
|
|
|
for item in restore_list[::-1]: |
|
|
|
|
h, w = item[1][0], item[1][1] |
|
|
|
|
if item[0] == 'resize': |
|
|
|
|
pred = F.interpolate( |
|
|
|
|
pred, (h, w), mode=mode, data_format='NCHW') |
|
|
|
|
elif item[0] == 'padding': |
|
|
|
|
x, y = item[2] |
|
|
|
|
pred = pred[:, :, y:y + h, x:x + w] |
|
|
|
|
else: |
|
|
|
|
pass |
|
|
|
|
results.append(pred) |
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
# FIXME: adapt to classification
|
|
|
|
def _infer_postprocess(self, batch_label_map, batch_score_map, |
|
|
|
|
batch_restore_list): |
|
|
|
|
label_maps = [] |
|
|
|
|
score_maps = [] |
|
|
|
|
for label_map, score_map, restore_list in zip( |
|
|
|
|
batch_label_map, batch_score_map, batch_restore_list): |
|
|
|
|
if not isinstance(label_map, np.ndarray): |
|
|
|
|
label_map = paddle.unsqueeze(label_map, axis=[0, 3]) |
|
|
|
|
score_map = paddle.unsqueeze(score_map, axis=0) |
|
|
|
|
for item in restore_list[::-1]: |
|
|
|
|
h, w = item[1][0], item[1][1] |
|
|
|
|
if item[0] == 'resize': |
|
|
|
|
if isinstance(label_map, np.ndarray): |
|
|
|
|
label_map = cv2.resize( |
|
|
|
|
label_map, (w, h), interpolation=cv2.INTER_NEAREST) |
|
|
|
|
score_map = cv2.resize( |
|
|
|
|
score_map, (w, h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
else: |
|
|
|
|
label_map = F.interpolate( |
|
|
|
|
label_map, (h, w), |
|
|
|
|
mode='nearest', |
|
|
|
|
data_format='NHWC') |
|
|
|
|
score_map = F.interpolate( |
|
|
|
|
score_map, (h, w), |
|
|
|
|
mode='bilinear', |
|
|
|
|
data_format='NHWC') |
|
|
|
|
elif item[0] == 'padding': |
|
|
|
|
x, y = item[2] |
|
|
|
|
if isinstance(label_map, np.ndarray): |
|
|
|
|
label_map = label_map[..., y:y + h, x:x + w] |
|
|
|
|
score_map = score_map[..., y:y + h, x:x + w] |
|
|
|
|
else: |
|
|
|
|
label_map = label_map[:, :, y:y + h, x:x + w] |
|
|
|
|
score_map = score_map[:, :, y:y + h, x:x + w] |
|
|
|
|
else: |
|
|
|
|
pass |
|
|
|
|
label_map = label_map.squeeze() |
|
|
|
|
score_map = score_map.squeeze() |
|
|
|
|
if not isinstance(label_map, np.ndarray): |
|
|
|
|
label_map = label_map.numpy() |
|
|
|
|
score_map = score_map.numpy() |
|
|
|
|
label_maps.append(label_map.squeeze()) |
|
|
|
|
score_maps.append(score_map.squeeze()) |
|
|
|
|
return label_maps, score_maps |
|
|
|
|
|
|
|
|
|
__all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet50_vd(BaseClassifier): |
|
|
|
|
def __init__(self, |
|
|
|
|
num_classes=2, |
|
|
|
|