From 22631cf72cad41e303e0d0f4ae1e4df82de77e90 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Tue, 16 Aug 2022 11:08:28 +0800 Subject: [PATCH 1/2] Normalize comments --- paddlers/datasets/clas_dataset.py | 2 - paddlers/datasets/coco.py | 8 +-- paddlers/datasets/seg_dataset.py | 2 - paddlers/datasets/voc.py | 3 +- paddlers/deploy/predictor.py | 4 +- paddlers/rs_models/cd/changeformer.py | 18 +++--- paddlers/rs_models/cd/layers/pd_timm.py | 6 +- paddlers/rs_models/cd/stanet.py | 2 +- paddlers/rs_models/clas/condensenet_v2.py | 14 ++--- paddlers/rs_models/res/generators/rcan.py | 12 ++-- paddlers/tasks/base.py | 20 +++---- paddlers/tasks/change_detector.py | 46 +++++++--------- paddlers/tasks/classifier.py | 12 ++-- paddlers/tasks/object_detector.py | 18 +++--- paddlers/tasks/segmenter.py | 55 ++++++++----------- .../tasks/utils/det_metrics/coco_utils.py | 23 ++++---- paddlers/tasks/utils/det_metrics/metrics.py | 8 +-- paddlers/tasks/utils/visualize.py | 4 +- paddlers/transforms/batch_operators.py | 18 +++--- paddlers/transforms/functions.py | 18 ++---- paddlers/transforms/operators.py | 14 ++--- paddlers/utils/checkpoint.py | 2 +- paddlers/utils/download.py | 5 +- paddlers/utils/utils.py | 4 +- tests/testing_utils.py | 2 +- tools/coco2mask.py | 2 +- tools/coco_tools/json_AnnoSta.py | 2 +- tools/coco_tools/json_Img2Json.py | 2 +- tools/coco_tools/json_ImgSta.py | 2 +- tools/coco_tools/json_InfoShow.py | 2 +- tools/coco_tools/json_Merge.py | 2 +- tools/coco_tools/json_Split.py | 2 +- tools/raster2geotiff.py | 11 ++-- tools/utils/__init__.py | 2 +- 34 files changed, 158 insertions(+), 189 deletions(-) diff --git a/paddlers/datasets/clas_dataset.py b/paddlers/datasets/clas_dataset.py index 58c0926..cb5ec60 100644 --- a/paddlers/datasets/clas_dataset.py +++ b/paddlers/datasets/clas_dataset.py @@ -48,8 +48,6 @@ class ClasDataset(BaseDataset): self.file_list = list() self.labels = list() - # TODO:非None时,让用户跳转数据集分析生成label_list - # 不要在此处分析label file if label_list is not None: with open(label_list, encoding=get_encoding(label_list)) as f: for line in f: diff --git a/paddlers/datasets/coco.py b/paddlers/datasets/coco.py index 1b2b9ac..f80ae13 100644 --- a/paddlers/datasets/coco.py +++ b/paddlers/datasets/coco.py @@ -58,11 +58,11 @@ class COCODetDataset(BaseDataset): allow_empty=False, empty_ratio=1.): # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, - # or matplotlib.backends is imported for the first time - # pycocotools import matplotlib + # or matplotlib.backends is imported for the first time. 
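An aside on the backend note above — a minimal standalone sketch (script context assumed) of why the call order matters:

```python
# Minimal sketch: select the non-interactive Agg backend before the first
# pyplot import; a matplotlib.use() call made after that import may be
# ignored, which is exactly the ordering constraint the comment describes.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # safe: Agg is already active
```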
import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO + super(COCODetDataset, self).__init__(data_dir, label_list, transforms, num_workers, shuffle) @@ -159,7 +159,7 @@ class COCODetDataset(BaseDataset): difficults = [] for inst in instances: - # check gt bbox + # Check gt bbox if inst.get('ignore', False): continue if 'bbox' not in inst.keys(): @@ -168,7 +168,7 @@ class COCODetDataset(BaseDataset): if not any(np.array(inst['bbox'])): continue - # read box + # Read the box x1, y1, box_w, box_h = inst['bbox'] x2 = x1 + box_w y2 = y1 + box_h diff --git a/paddlers/datasets/seg_dataset.py b/paddlers/datasets/seg_dataset.py index 8496b8b..58ff2c6 100644 --- a/paddlers/datasets/seg_dataset.py +++ b/paddlers/datasets/seg_dataset.py @@ -49,8 +49,6 @@ class SegDataset(BaseDataset): self.file_list = list() self.labels = list() - # TODO:非None时,让用户跳转数据集分析生成label_list - # 不要在此处分析label file if label_list is not None: with open(label_list, encoding=get_encoding(label_list)) as f: for line in f: diff --git a/paddlers/datasets/voc.py b/paddlers/datasets/voc.py index 34f7396..d12defd 100644 --- a/paddlers/datasets/voc.py +++ b/paddlers/datasets/voc.py @@ -58,8 +58,7 @@ class VOCDetDataset(BaseDataset): allow_empty=False, empty_ratio=1.): # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, - # or matplotlib.backends is imported for the first time - # pycocotools import matplotlib + # or matplotlib.backends is imported for the first time. import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO diff --git a/paddlers/deploy/predictor.py b/paddlers/deploy/predictor.py index 5c39305..87cff22 100644 --- a/paddlers/deploy/predictor.py +++ b/paddlers/deploy/predictor.py @@ -97,7 +97,7 @@ class Predictor(object): osp.join(self.model_dir, 'model.pdiparams')) if use_gpu: - # 设置GPU初始显存(单位M)和Device ID + # Set memory on GPUs (in MB) and device ID config.enable_use_gpu(200, gpu_id) config.switch_ir_optim(True) if use_trt: @@ -127,7 +127,7 @@ class Predictor(object): ) else: try: - # cache 10 different shapes for mkldnn to avoid memory leak + # Cache 10 different shapes for mkldnn to avoid memory leak. 
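For context, a hedged sketch of the MKL-DNN setup this comment belongs to, using only Paddle Inference calls that appear in this file (the model file names are hypothetical):

```python
# Sketch of the CPU/MKL-DNN configuration path; model file names are
# hypothetical. Capping the shape cache bounds memory when input shapes vary.
from paddle.inference import Config

config = Config("model.pdmodel", "model.pdiparams")
config.set_mkldnn_cache_capacity(10)  # keep at most 10 cached input shapes
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(4)
```

Without the cap, every new input shape can add a cached kernel, which appears to be the leak the comment warns about.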
config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() config.set_cpu_math_library_num_threads(mkl_thread_num) diff --git a/paddlers/rs_models/cd/changeformer.py b/paddlers/rs_models/cd/changeformer.py index e4b74ed..465cbb8 100644 --- a/paddlers/rs_models/cd/changeformer.py +++ b/paddlers/rs_models/cd/changeformer.py @@ -259,7 +259,7 @@ class EncoderTransformer_v3(nn.Layer): self.depths = depths self.embed_dims = embed_dims - # patch embedding definitions + # Patch embedding definitions self.patch_embed1 = OverlapPatchEmbed( img_size=img_size, patch_size=7, @@ -406,7 +406,7 @@ class EncoderTransformer_v3(nn.Layer): B = x.shape[0] outs = [] - # stage 1 + # Stage 1 x1, H1, W1 = self.patch_embed1(x) for i, blk in enumerate(self.block1): x1 = blk(x1, H1, W1) @@ -416,7 +416,7 @@ class EncoderTransformer_v3(nn.Layer): [0, 3, 1, 2]) outs.append(x1) - # stage 2 + # Stage 2 x1, H1, W1 = self.patch_embed2(x1) for i, blk in enumerate(self.block2): x1 = blk(x1, H1, W1) @@ -426,7 +426,7 @@ class EncoderTransformer_v3(nn.Layer): [0, 3, 1, 2]) outs.append(x1) - # stage 3 + # Stage 3 x1, H1, W1 = self.patch_embed3(x1) for i, blk in enumerate(self.block3): x1 = blk(x1, H1, W1) @@ -436,7 +436,7 @@ class EncoderTransformer_v3(nn.Layer): [0, 3, 1, 2]) outs.append(x1) - # stage 4 + # Stage 4 x1, H1, W1 = self.patch_embed4(x1) for i, blk in enumerate(self.block4): x1 = blk(x1, H1, W1) @@ -467,11 +467,11 @@ class DecoderTransformer_v3(nn.Layer): decoder_softmax=False, feature_strides=[2, 4, 8, 16]): super(DecoderTransformer_v3, self).__init__() - # assert + assert len(feature_strides) == len(in_channels) assert min(feature_strides) == feature_strides[0] - # settings + # Settings self.feature_strides = feature_strides self.input_transform = input_transform self.in_index = in_index @@ -491,7 +491,7 @@ class DecoderTransformer_v3(nn.Layer): self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) - # convolutional Difference Layers + # Convolutional Difference Layers self.diff_c4 = conv_diff( in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) self.diff_c3 = conv_diff( @@ -501,7 +501,7 @@ class DecoderTransformer_v3(nn.Layer): self.diff_c1 = conv_diff( in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) - # taking outputs from middle of the encoder + # Take outputs from middle of the encoder self.make_pred_c4 = make_prediction( in_channels=self.embedding_dim, out_channels=self.output_nc) self.make_pred_c3 = make_prediction( diff --git a/paddlers/rs_models/cd/layers/pd_timm.py b/paddlers/rs_models/cd/layers/pd_timm.py index fbb9a5e..1502155 100644 --- a/paddlers/rs_models/cd/layers/pd_timm.py +++ b/paddlers/rs_models/cd/layers/pd_timm.py @@ -38,7 +38,7 @@ class DropPath(nn.Layer): Returns: output: output tensor after drop path """ - # if prob is 0 or eval mode, return original input + # If prob is 0 or is in eval mode, return original input. if self.drop_prob == 0. or not self.training: return inputs keep_prob = 1 - self.drop_prob @@ -47,8 +47,8 @@ class DropPath(nn.Layer): ) # shape=(N, 1, 1, 1) random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype) random_tensor = random_tensor.floor() # mask - output = inputs.divide( - keep_prob) * random_tensor # divide to keep same output expectation + # Make division to keep output expectation same. 
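The expectation-preserving division can be checked numerically; a small NumPy-only sketch:

```python
# Numerical check of the comment above: the row mask keeps a path with
# probability keep_prob, and dividing by keep_prob restores the mean.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, size=(100_000, 8))
keep_prob = 0.9
mask = (rng.random((100_000, 1)) < keep_prob).astype(x.dtype)
out = x / keep_prob * mask
print(abs(out.mean() - x.mean()) < 0.01)  # True: expectations match
```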
+ output = inputs.divide(keep_prob) * random_tensor return output def forward(self, inputs): diff --git a/paddlers/rs_models/cd/stanet.py b/paddlers/rs_models/cd/stanet.py index f2ef649..423ba43 100644 --- a/paddlers/rs_models/cd/stanet.py +++ b/paddlers/rs_models/cd/stanet.py @@ -263,7 +263,7 @@ class PAMBlock(nn.Layer): def _attend(self, query, key, value): energy = paddle.bmm(query.transpose((0, 2, 1)), - key) # batch matrix multiplication + key) # Batched matrix multiplication energy = (self.key_ch**(-0.5)) * energy attention = F.softmax(energy, axis=-1) out = paddle.bmm(value, attention.transpose((0, 2, 1))) diff --git a/paddlers/rs_models/clas/condensenet_v2.py b/paddlers/rs_models/clas/condensenet_v2.py index 2ca1073..53bb5aa 100644 --- a/paddlers/rs_models/clas/condensenet_v2.py +++ b/paddlers/rs_models/clas/condensenet_v2.py @@ -87,11 +87,11 @@ class Conv(nn.Sequential): def ShuffleLayer(x, groups): batchsize, num_channels, height, width = x.shape channels_per_group = num_channels // groups - # reshape + # Reshape x = x.reshape((batchsize, groups, channels_per_group, height, width)) - # transpose + # Transpose x = x.transpose((0, 2, 1, 3, 4)) - # reshape + # Reshape x = x.reshape((batchsize, groups * channels_per_group, height, width)) return x @@ -99,11 +99,11 @@ def ShuffleLayer(x, groups): def ShuffleLayerTrans(x, groups): batchsize, num_channels, height, width = x.shape channels_per_group = num_channels // groups - # reshape + # Reshape x = x.reshape((batchsize, channels_per_group, groups, height, width)) - # transpose + # Transpose x = x.transpose((0, 2, 1, 3, 4)) - # reshape + # Reshape x = x.reshape((batchsize, channels_per_group * groups, height, width)) return x @@ -385,7 +385,7 @@ class CondenseNetV2(nn.Layer): return out def _initialize(self): - # initialize + # Initialize for m in self.sublayers(): if isinstance(m, nn.Conv2D): nn.initializer.KaimingNormal()(m.weight) diff --git a/paddlers/rs_models/res/generators/rcan.py b/paddlers/rs_models/res/generators/rcan.py index 9de30c7..17f9ee8 100644 --- a/paddlers/rs_models/res/generators/rcan.py +++ b/paddlers/rs_models/res/generators/rcan.py @@ -1,4 +1,4 @@ -# base on https://github.com/kongdebug/RCAN-Paddle +# Based on https://github.com/kongdebug/RCAN-Paddle import math import paddle @@ -37,9 +37,9 @@ class MeanShift(nn.Conv2D): class CALayer(nn.Layer): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() - # global average pooling: feature --> point + # Global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2D(1) - # feature channel downscale and upscale --> channel weight + # Feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2D( channel, channel // reduction, 1, padding=0, bias_attr=True), @@ -157,10 +157,10 @@ class RCAN(nn.Layer): rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std) - # define head module + # Define head module modules_head = [conv(n_colors, n_feats, kernel_size)] - # define body module + # Define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale= 1, n_resblocks=n_resblocks) \ @@ -168,7 +168,7 @@ class RCAN(nn.Layer): modules_body.append(conv(n_feats, n_feats, kernel_size)) - # define tail module + # Define tail module modules_tail = [ Upsampler( conv, scale, n_feats, act=False), diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py index 86f29c5..5250058 100644 --- a/paddlers/tasks/base.py +++ b/paddlers/tasks/base.py @@ -76,10 
+76,10 @@ class BaseModel(metaclass=ModelMeta): self.eval_metrics = None self.best_accuracy = -1. self.best_model_epoch = -1 - # 是否使用多卡间同步BatchNorm均值和方差 + # Whether to use synchronized BN self.sync_bn = False self.status = 'Normal' - # 已完成迭代轮数,为恢复训练时的起始轮数 + # The initial epoch when training is resumed self.completed_epochs = 0 self.pruner = None self.pruning_ratios = None @@ -239,7 +239,7 @@ class BaseModel(metaclass=ModelMeta): mode='w') as f: yaml.dump(model_info, f) - # 评估结果保存 + # Save evaluation details if hasattr(self, 'eval_details'): with open(osp.join(save_dir, 'eval_details.json'), 'w') as f: json.dump(self.eval_details, f) @@ -258,7 +258,7 @@ class BaseModel(metaclass=ModelMeta): mode='w') as f: yaml.dump(quant_info, f) - # 模型保存成功的标志 + # Success flag open(osp.join(save_dir, '.success'), 'w').close() logging.info("Model saved in {}.".format(save_dir)) @@ -391,7 +391,7 @@ class BaseModel(metaclass=ModelMeta): step_time_tic = step_time_toc current_step += 1 - # 每间隔log_interval_steps,输出loss信息 + # Log loss info every log_interval_steps if current_step % log_interval_steps == 0 and local_rank == 0: if use_vdl: for k, v in outputs.items(): @@ -399,7 +399,7 @@ class BaseModel(metaclass=ModelMeta): '{}-Metrics/Training(Step): {}'.format( task_id, k), v, current_step) - # 估算剩余时间 + # Estimation remaining time avg_step_time = train_step_time.avg() eta = avg_step_time * (train_total_step - current_step) if eval_dataset is not None: @@ -427,14 +427,14 @@ class BaseModel(metaclass=ModelMeta): self.net.set_state_dict(ema.apply()) eval_epoch_tic = time.time() - # 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存 + # Every save_interval_epochs, evaluate and save the model if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1: if eval_dataset is not None and eval_dataset.num_samples > 0: eval_result = self.evaluate( eval_dataset, batch_size=eval_batch_size, return_details=True) - # 保存最优模型 + # Save the optimial model if local_rank == 0: self.eval_metrics, self.eval_details = eval_result if use_vdl: @@ -548,7 +548,7 @@ class BaseModel(metaclass=ModelMeta): "Exported inference model does not support quantization-aware training.", exit=True) if quant_config is None: - # default quantization configuration + # Default quantization configuration quant_config = { # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed. 'weight_preprocess_type': None, @@ -669,7 +669,7 @@ class BaseModel(metaclass=ModelMeta): mode='w') as f: yaml.dump(pipeline_info, f) - # 模型保存成功的标志 + # Success flag open(osp.join(save_dir, '.success'), 'w').close() logging.info("The inference model for deployment is saved in {}.". 
format(save_dir)) diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index fa4a127..c6bd1e9 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -25,9 +25,9 @@ import paddle.nn.functional as F from paddle.static import InputSpec import paddlers +import paddlers.models.ppseg as ppseg import paddlers.rs_models.cd as cmcd import paddlers.utils.logging as logging -import paddlers.models.ppseg as paddleseg from paddlers.transforms import Resize, decode_image from paddlers.utils import get_single_card_bs, DisablePrint from paddlers.utils.checkpoint import seg_pretrain_weights_dict @@ -144,7 +144,7 @@ class BaseChangeDetector(BaseModel): origin_shape = [label.shape[-2:]] pred = self._postprocess( pred, origin_shape, transforms=inputs[3])[0] # NCHW - intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( + intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area( pred, label, self.num_classes) outputs['intersect_area'] = intersect_area outputs['pred_area'] = pred_area @@ -178,16 +178,13 @@ class BaseChangeDetector(BaseModel): if isinstance(self.use_mixed_loss, bool): if self.use_mixed_loss: losses = [ - paddleseg.models.CrossEntropyLoss(), - paddleseg.models.LovaszSoftmaxLoss() + ppseg.models.CrossEntropyLoss(), + ppseg.models.LovaszSoftmaxLoss() ] coef = [.8, .2] - loss_type = [ - paddleseg.models.MixedLoss( - losses=losses, coef=coef), - ] + loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ] else: - loss_type = [paddleseg.models.CrossEntropyLoss()] + loss_type = [ppseg.models.CrossEntropyLoss()] else: losses, coef = list(zip(*self.use_mixed_loss)) if not set(losses).issubset( @@ -195,11 +192,8 @@ class BaseChangeDetector(BaseModel): raise ValueError( "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." 
) - losses = [getattr(paddleseg.models, loss)() for loss in losses] - loss_type = [ - paddleseg.models.MixedLoss( - losses=losses, coef=list(coef)) - ] + losses = [getattr(ppseg.models, loss)() for loss in losses] + loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))] loss_coef = [1.0] losses = {'types': loss_type, 'coef': loss_coef} return losses @@ -492,13 +486,13 @@ class BaseChangeDetector(BaseModel): pred_area_all = pred_area_all + pred_area label_area_all = label_area_all + label_area conf_mat_all.append(conf_mat) - class_iou, miou = paddleseg.utils.metrics.mean_iou( + class_iou, miou = ppseg.utils.metrics.mean_iou( intersect_area_all, pred_area_all, label_area_all) # TODO 确认是按oacc还是macc - class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all, - pred_area_all) - kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all, - label_area_all) + class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all, + pred_area_all) + kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all, + label_area_all) category_f1score = metrics.f1_score(intersect_area_all, pred_area_all, label_area_all) @@ -643,7 +637,7 @@ class BaseChangeDetector(BaseModel): int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) im2 = src2_data.ReadAsArray( int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - # fill + # Fill h, w = im1.shape[:2] im1_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im1.dtype) @@ -651,10 +645,10 @@ class BaseChangeDetector(BaseModel): im1_fill[:h, :w, :] = im1 im2_fill[:h, :w, :] = im2 im_fill = (im1_fill, im2_fill) - # predict + # Predict pred = self.predict(im_fill, transforms)["label_map"].astype("uint8") - # overlap + # Overlap rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) mask = (rd_block == pred[:h, :w]) | (rd_block == 255) temp = pred[:h, :w].copy() @@ -966,7 +960,7 @@ class DSIFN(BaseChangeDetector): if self.use_mixed_loss is False: return { # XXX: make sure the shallow copy works correctly here. - 'types': [paddleseg.models.CrossEntropyLoss()] * 5, + 'types': [ppseg.models.CrossEntropyLoss()] * 5, 'coef': [1.0] * 5 } else: @@ -998,8 +992,8 @@ class DSAMNet(BaseChangeDetector): if self.use_mixed_loss is False: return { 'types': [ - paddleseg.models.CrossEntropyLoss(), - paddleseg.models.DiceLoss(), paddleseg.models.DiceLoss() + ppseg.models.CrossEntropyLoss(), ppseg.models.DiceLoss(), + ppseg.models.DiceLoss() ], 'coef': [1.0, 0.05, 0.05] } @@ -1034,7 +1028,7 @@ class ChangeStar(BaseChangeDetector): if self.use_mixed_loss is False: return { # XXX: make sure the shallow copy works correctly here. 
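The XXX note above concerns Python list repetition semantics; a minimal sketch of the pitfall it guards against:

```python
# [obj] * n repeats one reference n times; it does not copy the object.
class Loss:  # stand-in for a real loss layer
    pass

shared = [Loss()] * 5
print(shared[0] is shared[4])  # True: five references, one instance

independent = [Loss() for _ in range(5)]
print(independent[0] is independent[4])  # False: five distinct instances
```

Sharing is harmless for a stateless loss, which is presumably why the repetition is acceptable here.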
- 'types': [paddleseg.models.CrossEntropyLoss()] * 4, + 'types': [ppseg.models.CrossEntropyLoss()] * 4, 'coef': [1.0] * 4 } else: diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 7e6c109..ad39a5d 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -22,17 +22,17 @@ import paddle import paddle.nn.functional as F from paddle.static import InputSpec -import paddlers.models.ppcls as paddleclas -import paddlers.rs_models.clas as cmcls import paddlers -from paddlers.utils import get_single_card_bs, DisablePrint +import paddlers.models.ppcls as ppcls +import paddlers.rs_models.clas as cmcls import paddlers.utils.logging as logging -from .base import BaseModel +from paddlers.utils import get_single_card_bs, DisablePrint from paddlers.models.ppcls.metric import build_metrics from paddlers.models.ppcls.loss import build_loss from paddlers.models.ppcls.data.postprocess import build_postprocess from paddlers.utils.checkpoint import cls_pretrain_weights_dict from paddlers.transforms import Resize, decode_image +from .base import BaseModel __all__ = [ "ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b" @@ -50,7 +50,7 @@ class BaseClassifier(BaseModel): if 'with_net' in self.init_params: del self.init_params['with_net'] super(BaseClassifier, self).__init__('classifier') - if not hasattr(paddleclas.arch.backbone, model_name) and \ + if not hasattr(ppcls.arch.backbone, model_name) and \ not hasattr(cmcls, model_name): raise ValueError("ERROR: There is no model named {}.".format( model_name)) @@ -69,7 +69,7 @@ class BaseClassifier(BaseModel): def build_net(self, **params): with paddle.utils.unique_name.guard(): - model = dict(paddleclas.arch.backbone.__dict__, + model = dict(ppcls.arch.backbone.__dict__, **cmcls.__dict__)[self.model_name] # TODO: Determine whether there is in_channels try: diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index 4dc8ae5..f232b01 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
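The import reshuffling in the hunk below follows the usual grouping convention; a sketch (assumes an installed Paddle/PaddleRS; the relative group is shown as a comment because it needs a package context to run):

```python
# Sketch of the import grouping this patch normalizes toward:
# 1. standard library, 2. third party, 3. first-party absolute,
# 4. package-relative imports last (shown as a comment; they require
#    a package context to run).
import collections
import copy

import numpy as np
import paddle

import paddlers
import paddlers.utils.logging as logging
# from .base import BaseModel
```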
-from __future__ import absolute_import - import collections import copy import os @@ -23,18 +21,18 @@ import numpy as np import paddle from paddle.static import InputSpec +import paddlers import paddlers.models.ppdet as ppdet from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner -import paddlers -import paddlers.utils.logging as logging from paddlers.transforms import decode_image from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad from paddlers.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \ _BatchPad, _Gt2YoloTarget -from .base import BaseModel -from .utils.det_metrics import VOCMetric, COCOMetric from paddlers.models.ppdet.optimizer import ModelEMA +import paddlers.utils.logging as logging from paddlers.utils.checkpoint import det_pretrain_weights_dict +from .base import BaseModel +from .utils.det_metrics import VOCMetric, COCOMetric __all__ = [ "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN" @@ -291,7 +289,7 @@ class BaseDetector(BaseModel): train_dataset.batch_transforms = self._compose_batch_transform( train_dataset.transforms, mode='train') - # build optimizer if not defined + # Build optimizer if not defined if optimizer is None: num_steps_each_epoch = len(train_dataset) // train_batch_size self.optimizer = self.default_optimizer( @@ -305,7 +303,7 @@ class BaseDetector(BaseModel): else: self.optimizer = optimizer - # initiate weights + # Initiate weights if pretrain_weights is not None and not osp.exists(pretrain_weights): if pretrain_weights not in det_pretrain_weights_dict['_'.join( [self.model_name, self.backbone_name])]: @@ -335,7 +333,7 @@ class BaseDetector(BaseModel): ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True) else: ema = None - # start train loop + # Start train loop self.train_loop( num_epochs=num_epochs, train_dataset=train_dataset, @@ -822,7 +820,7 @@ class PicoDet(BaseDetector): if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)): if mode != 'train': raise ValueError( - "{} cannot be present in the {} transforms. 
".format( + "{} cannot be present in the {} transforms.".format( op.__class__.__name__, mode) + "Please check the {} transforms.".format(mode)) custom_batch_transforms.insert(0, copy.deepcopy(op)) diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 900f481..da17a5c 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -23,15 +23,15 @@ import paddle import paddle.nn.functional as F from paddle.static import InputSpec -import paddlers.models.ppseg as paddleseg -import paddlers.rs_models.seg as cmseg import paddlers +import paddlers.models.ppseg as ppseg +import paddlers.rs_models.seg as cmseg from paddlers.utils import get_single_card_bs, DisablePrint import paddlers.utils.logging as logging -from .base import BaseModel -from .utils import seg_metrics as metrics from paddlers.utils.checkpoint import seg_pretrain_weights_dict from paddlers.transforms import Resize, decode_image +from .base import BaseModel +from .utils import seg_metrics as metrics __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -46,7 +46,7 @@ class BaseSegmenter(BaseModel): if 'with_net' in self.init_params: del self.init_params['with_net'] super(BaseSegmenter, self).__init__('segmenter') - if not hasattr(paddleseg.models, model_name) and \ + if not hasattr(ppseg.models, model_name) and \ not hasattr(cmseg, model_name): raise ValueError("ERROR: There is no model named {}.".format( model_name)) @@ -63,9 +63,8 @@ class BaseSegmenter(BaseModel): def build_net(self, **params): # TODO: when using paddle.utils.unique_name.guard, # DeepLabv3p and HRNet will raise a error - net = dict(paddleseg.models.__dict__, - **cmseg.__dict__)[self.model_name]( - num_classes=self.num_classes, **params) + net = dict(ppseg.models.__dict__, **cmseg.__dict__)[self.model_name]( + num_classes=self.num_classes, **params) return net def _fix_transforms_shape(self, image_shape): @@ -143,7 +142,7 @@ class BaseSegmenter(BaseModel): origin_shape = [label.shape[-2:]] pred = self._postprocess( pred, origin_shape, transforms=inputs[2])[0] # NCHW - intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area( + intersect_area, pred_area, label_area = ppseg.utils.metrics.calculate_area( pred, label, self.num_classes) outputs['intersect_area'] = intersect_area outputs['pred_area'] = pred_area @@ -161,16 +160,13 @@ class BaseSegmenter(BaseModel): if isinstance(self.use_mixed_loss, bool): if self.use_mixed_loss: losses = [ - paddleseg.models.CrossEntropyLoss(), - paddleseg.models.LovaszSoftmaxLoss() + ppseg.models.CrossEntropyLoss(), + ppseg.models.LovaszSoftmaxLoss() ] coef = [.8, .2] - loss_type = [ - paddleseg.models.MixedLoss( - losses=losses, coef=coef), - ] + loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ] else: - loss_type = [paddleseg.models.CrossEntropyLoss()] + loss_type = [ppseg.models.CrossEntropyLoss()] else: losses, coef = list(zip(*self.use_mixed_loss)) if not set(losses).issubset( @@ -178,11 +174,8 @@ class BaseSegmenter(BaseModel): raise ValueError( "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." 
) - losses = [getattr(paddleseg.models, loss)() for loss in losses] - loss_type = [ - paddleseg.models.MixedLoss( - losses=losses, coef=list(coef)) - ] + losses = [getattr(ppseg.models, loss)() for loss in losses] + loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))] if self.model_name == 'FastSCNN': loss_type *= 2 loss_coef = [1.0, 0.4] @@ -475,13 +468,13 @@ class BaseSegmenter(BaseModel): pred_area_all = pred_area_all + pred_area label_area_all = label_area_all + label_area conf_mat_all.append(conf_mat) - class_iou, miou = paddleseg.utils.metrics.mean_iou( + class_iou, miou = ppseg.utils.metrics.mean_iou( intersect_area_all, pred_area_all, label_area_all) # TODO 确认是按oacc还是macc - class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all, - pred_area_all) - kappa = paddleseg.utils.metrics.kappa(intersect_area_all, pred_area_all, - label_area_all) + class_acc, oacc = ppseg.utils.metrics.accuracy(intersect_area_all, + pred_area_all) + kappa = ppseg.utils.metrics.kappa(intersect_area_all, pred_area_all, + label_area_all) category_f1score = metrics.f1_score(intersect_area_all, pred_area_all, label_area_all) eval_metrics = OrderedDict( @@ -613,15 +606,15 @@ class BaseSegmenter(BaseModel): ysize = int(height - yoff) im = src_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0)) - # fill + # Fill h, w = im.shape[:2] im_fill = np.zeros( (block_size[1], block_size[0], bands), dtype=im.dtype) im_fill[:h, :w, :] = im - # predict + # Predict pred = self.predict(im_fill, transforms)["label_map"].astype("uint8") - # overlap + # Overlap rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize) mask = (rd_block == pred[:h, :w]) | (rd_block == 255) temp = pred[:h, :w].copy() @@ -818,7 +811,7 @@ class DeepLabV3P(BaseSegmenter): "{'ResNet50_vd', 'ResNet101_vd'}.".format(backbone)) if params.get('with_net', True): with DisablePrint(): - backbone = getattr(paddleseg.models, backbone)( + backbone = getattr(ppseg.models, backbone)( input_channel=input_channel, output_stride=output_stride) else: backbone = None @@ -864,7 +857,7 @@ class HRNet(BaseSegmenter): self.backbone_name = 'HRNet_W{}'.format(width) if params.get('with_net', True): with DisablePrint(): - backbone = getattr(paddleseg.models, self.backbone_name)( + backbone = getattr(ppseg.models, self.backbone_name)( align_corners=align_corners) else: backbone = None diff --git a/paddlers/tasks/utils/det_metrics/coco_utils.py b/paddlers/tasks/utils/det_metrics/coco_utils.py index d70fce2..130e497 100644 --- a/paddlers/tasks/utils/det_metrics/coco_utils.py +++ b/paddlers/tasks/utils/det_metrics/coco_utils.py @@ -142,20 +142,19 @@ def cocoapi_eval(anns, logging.info('Per-category of {} AP: \n{}'.format(style, table.table)) logging.info("per-category PR curve has output to {} folder.".format( style + '_pr_curve')) - # flush coco evaluation result + # Flush coco evaluation result sys.stdout.flush() return coco_eval.stats def loadRes(coco_obj, anns): # This function has the same functionality as pycocotools.COCO.loadRes, - # except that the input anns is list of results rather than a json file. + # excepting that the input anns is list of results rather than a json file. 
# Refer to # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305, # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, - # or matplotlib.backends is imported for the first time - # pycocotools import matplotlib + # or matplotlib.backends is imported for the first time. import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO @@ -192,7 +191,7 @@ def loadRes(coco_obj, anns): res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[ 'categories']) for id, ann in enumerate(anns): - # now only support compressed RLE format as segmentation results + # Now only supports compressed RLE format as segmentation results. ann['area'] = maskUtils.area(ann['segmentation']) if not 'bbox' in ann: ann['bbox'] = maskUtils.toBbox(ann['segmentation']) @@ -291,8 +290,7 @@ def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None): """ # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, - # or matplotlib.backends is imported for the first time - # pycocotools import matplotlib + # or matplotlib.backends is imported for the first time. import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO @@ -311,7 +309,7 @@ def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None): select_dt_anns.append(ann) dt.dataset['annotations'] = select_dt_anns dt.createIndex() - # compute precision but ignore superclass confusion + # Compute precision but ignore superclass confusion. gt = copy.deepcopy(cocoGt) child_catIds = gt.getCatIds(supNms=[nm['supercategory']]) for idx, ann in enumerate(gt.dataset['annotations']): @@ -379,8 +377,7 @@ def coco_error_analysis(eval_details_file=None, import multiprocessing as mp # matplotlib.use() must be called *before* pylab, matplotlib.pyplot, - # or matplotlib.backends is imported for the first time - # pycocotools import matplotlib + # or matplotlib.backends is imported for the first time. import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO @@ -446,11 +443,11 @@ def coco_error_analysis(eval_details_file=None, assert k == analyze_result[0], "" ps_supercategory = analyze_result[1]['ps_supercategory'] ps_allcategory = analyze_result[1]['ps_allcategory'] - # compute precision but ignore superclass confusion + # Compute precision but ignore superclass confusion. ps[3, :, k, :, :] = ps_supercategory - # compute precision but ignore any class confusion + # Compute precision but ignore any class confusion. ps[4, :, k, :, :] = ps_allcategory - # fill in background and false negative errors and plot + # Fill in background and false negative errors and plot. 
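For readers of this hunk, a hedged legend for the first axis of `ps`: indices 3-6 match the assignments in this function, while 0-2 are assumed from the usual pycocotools-style error-analysis convention.

```python
# Hedged legend for ps[k]: entries 3..6 match the assignments in this
# function; 0..2 follow the usual C75/C50/Loc convention and are assumptions.
PR_SLICE_LABELS = [
    "C75: PR at IoU=0.75",
    "C50: PR at IoU=0.50",
    "Loc: PR at IoU=0.10 (localization errors mostly ignored)",
    "Sim: supercategory confusion also ignored",
    "Oth: all class confusion also ignored",
    "BG : background false positives also ignored",
    "FN : false negatives also ignored (precision forced to 1)",
]
for k, label in enumerate(PR_SLICE_LABELS):
    print(f"ps[{k}] -> {label}")
```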
ps[ps == -1] = 0 ps[5, :, k, :, :] = ps[4, :, k, :, :] > 0 ps[6, :, k, :, :] = 1.0 diff --git a/paddlers/tasks/utils/det_metrics/metrics.py b/paddlers/tasks/utils/det_metrics/metrics.py index 942b1b8..e818a8f 100644 --- a/paddlers/tasks/utils/det_metrics/metrics.py +++ b/paddlers/tasks/utils/det_metrics/metrics.py @@ -41,11 +41,11 @@ class Metric(paddle.metric.Metric): # paddle.metric.Metric defined :metch:`update`, :meth:`accumulate` # :metch:`reset`, in ppdet, we also need following 2 methods: - # abstract method for logging metric results + # Abstract method for logging metric results def log(self): pass - # abstract method for getting metric results + # Abstract method for getting metric results def get_results(self): pass @@ -162,7 +162,7 @@ class COCOMetric(Metric): self.reset() def reset(self): - # only bbox and mask evaluation support currently + # Only bbox and mask evaluation are supported currently. self.details = { 'gt': copy.deepcopy(self.coco_gt.dataset), 'bbox': [], @@ -172,7 +172,7 @@ class COCOMetric(Metric): def update(self, inputs, outputs): outs = {} - # outputs Tensor -> numpy.ndarray + # Tensor -> numpy.ndarray for k, v in outputs.items(): outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v diff --git a/paddlers/tasks/utils/visualize.py b/paddlers/tasks/utils/visualize.py index 43e5a65..313648d 100644 --- a/paddlers/tasks/utils/visualize.py +++ b/paddlers/tasks/utils/visualize.py @@ -309,7 +309,7 @@ def draw_pr_curve(eval_details_file=None, aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] if ap == 1: - # dimension of precision: [TxRxKxAxM] + # Dimension of precision: [TxRxKxAxM] s = coco_gt.eval['precision'] # IoU if iouThr is not None: @@ -317,7 +317,7 @@ def draw_pr_curve(eval_details_file=None, s = s[t] s = s[:, :, :, aind, mind] else: - # dimension of recall: [TxKxAxM] + # Dimension of recall: [TxKxAxM] s = coco_gt.eval['recall'] if iouThr is not None: t = np.where(iouThr == p.iouThrs)[0] diff --git a/paddlers/transforms/batch_operators.py b/paddlers/transforms/batch_operators.py index 12ad8a8..98c7f73 100644 --- a/paddlers/transforms/batch_operators.py +++ b/paddlers/transforms/batch_operators.py @@ -229,7 +229,7 @@ class _Gt2YoloTarget(Transform): if gw <= 0. or gh <= 0. or score <= 0.: continue - # find best match anchor index + # Find best matched anchor index best_iou = 0. best_idx = -1 for an_idx in range(an_hw.shape[0]): @@ -243,8 +243,8 @@ class _Gt2YoloTarget(Transform): gi = int(gx * grid_w) gj = int(gy * grid_h) - # gtbox should be regresed in this layes if best match - # anchor index in anchor mask of this layer + # gtbox should be regressed in this layer if best matched + # anchor index is in the anchor mask of this layer. if best_idx in mask: best_n = mask.index(best_idx) @@ -257,14 +257,14 @@ class _Gt2YoloTarget(Transform): gh * h / self.anchors[best_idx][1]) target[best_n, 4, gj, gi] = 2.0 - gw * gh - # objectness record gt_score + # Record gt_score target[best_n, 5, gj, gi] = score - # classification + # Do classification target[best_n, 6 + cls, gj, gi] = 1. # For non-matched anchors, calculate the target if the iou - # between anchor and gt is larger than iou_thresh + # between anchor and gt is larger than iou_thresh. 
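The anchor matching above compares shapes only; a minimal sketch of that width/height IoU (an assumption based on how `an_hw` is used):

```python
# Shape-only IoU for anchor matching: boxes are treated as concentric,
# so only widths and heights enter the computation (assumed from an_hw).
def wh_iou(gw: float, gh: float, aw: float, ah: float) -> float:
    inter = min(gw, aw) * min(gh, ah)
    union = gw * gh + aw * ah - inter
    return inter / union

print(round(wh_iou(0.30, 0.40, 0.25, 0.50), 2))  # 0.69
```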
if self.iou_thresh < 1: for idx, mask_i in enumerate(mask): if mask_i == best_idx: continue @@ -282,14 +282,14 @@ class _Gt2YoloTarget(Transform): gh * h / self.anchors[mask_i][1]) target[idx, 4, gj, gi] = 2.0 - gw * gh - # objectness record gt_score + # Record gt_score target[idx, 5, gj, gi] = score - # classification + # Do classification target[idx, 5 + cls, gj, gi] = 1. sample['target{}'.format(i)] = target - # remove useless gt_class and gt_score after target calculated + # Remove useless gt_class and gt_score items after target has been calculated. sample.pop('gt_class') sample.pop('gt_score') diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py index 467cf6e..12c3e9a 100644 --- a/paddlers/transforms/functions.py +++ b/paddlers/transforms/functions.py @@ -55,7 +55,6 @@ def center_crop(im, crop_size=224): return im -# region flip def img_flip(im, method=0): """ Flip an image. @@ -168,10 +167,6 @@ def lt2rb_flip(im): return im -# endregion - - -# region rotation def img_simple_rotate(im, method=0): """ Rotate an image. @@ -255,9 +250,6 @@ def rot_270(im): return im -# endregion - - def rgb2bgr(im): return im[:, :, ::-1] @@ -405,7 +397,7 @@ def to_uint8(im, is_linear=False): # 2% linear stretch def _two_percent_linear(image, max_out=255, min_out=0): def _gray_process(gray, maxout=max_out, minout=min_out): - # get the corresponding gray level at 98% histogram + # Get the corresponding gray level at 98% in the histogram. high_value = np.percentile(gray, 98) low_value = np.percentile(gray, 2) truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value) @@ -422,7 +414,7 @@ def to_uint8(im, is_linear=False): result = _gray_process(image) return np.uint8(result) - # simple image standardization + # Simple image standardization def _sample_norm(image): stretches = [] if len(image.shape) == 3: @@ -456,7 +448,7 @@ def to_intensity(im): if len(im.shape) != 2: raise ValueError("`len(im.shape) must be 2.") - # the type is complex means this is a SAR data + # If the type is complex, this is SAR data. if isinstance(type(im[0, 0]), complex): im = abs(im) return im @@ -475,7 +467,7 @@ def select_bands(im, band_list=[1, 2, 3]): np.ndarray: Image with selected bands. """ - if len(im.shape) == 2: # just have one channel + if len(im.shape) == 2: # Image has only one channel return im if not isinstance(band_list, list) or len(band_list) == 0: raise TypeError("band_list must be non empty list.") @@ -517,7 +509,7 @@ def dehaze(im, gamma=False): return m_a * I + m_b def _dehaze(im, r, w, maxatmo_mask, eps): - # im is RGB and range[0, 1] + # im is a RGB image and the value ranges in [0, 1]. 
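A minimal sketch of the dark-channel step that follows (it mirrors the two lines below; the input image is synthetic):

```python
# Dark-channel sketch: per-pixel channel minimum, then a grayscale erosion
# over a 15x15 window, as in the code below; the input image is synthetic.
import cv2
import numpy as np

im = np.random.rand(64, 64, 3).astype(np.float32)  # stand-in RGB in [0, 1]
dark_channel = cv2.erode(np.min(im, axis=2), np.ones((15, 15), np.float32))
print(dark_channel.shape)  # (64, 64)
```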
atmo_mask = np.min(im, 2) dark_channel = cv2.erode(atmo_mask, np.ones((15, 15))) atmo_mask = _guided_filter(atmo_mask, dark_channel, r, eps) diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index 7204f27..fa8c4af 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -212,7 +212,7 @@ class DecodeImg(Transform): raise IOError('Can not open', img_path) im_data = dataset.ReadAsArray() if im_data.ndim == 2 and self.decode_sar: - im_data = to_intensity(im_data) # is read SAR + im_data = to_intensity(im_data) im_data = im_data[:, :, np.newaxis] else: if im_data.ndim == 3: @@ -1376,7 +1376,7 @@ class MixupImage(Transform): image = self.apply_im(sample[0]['image'], sample[1]['image'], factor) result = copy.deepcopy(sample[0]) result['image'] = image - # apply bbox and score + # Apply bbox and score if 'gt_bbox' in sample[0]: gt_bbox1 = sample[0]['gt_bbox'] gt_bbox2 = sample[1]['gt_bbox'] @@ -1469,7 +1469,7 @@ class RandomDistort(Transform): if np.random.uniform(0., 1.) < self.hue_prob: return image - # it works, but result differ from HSV version + # It works, but the result differs from HSV version. delta = np.random.uniform(low, high) u = np.cos(delta * np.pi) w = np.sin(delta * np.pi) @@ -1505,7 +1505,7 @@ class RandomDistort(Transform): for i in range(channel // 3): sub_img = image[:, :, 3 * i:3 * (i + 1)] sub_img = sub_img.astype(np.float32) - # it works, but result differ from HSV version + # It works, but the result differs from HSV version. gray = sub_img * np.array( [[[0.299, 0.587, 0.114]]], dtype=np.float32) gray = gray.sum(axis=2, keepdims=True) @@ -1720,9 +1720,9 @@ class _PadBox(Transform): if gt_num > 0: pad_score[:gt_num] = sample['gt_score'][:gt_num, 0] sample['gt_score'] = pad_score - # in training, for example in op ExpandImage, - # the bbox and gt_class is expanded, but the difficult is not, - # so, judging by it's length + # In training, for example in op ExpandImage, + # bbox and gt_class are expanded, but difficult is not, + # so judge by its length. if 'difficult' in sample: pad_diff = np.zeros((num_max, ), dtype=np.int32) if gt_num > 0: diff --git a/paddlers/utils/checkpoint.py b/paddlers/utils/checkpoint.py index 6948a04..82b5b12 100644 --- a/paddlers/utils/checkpoint.py +++ b/paddlers/utils/checkpoint.py @@ -441,7 +441,7 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None): if os.path.exists(pretrain_weights): param_state_dict = paddle.load(pretrain_weights) model_state_dict = model.state_dict() - # hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone' + # HACK: Fit for faster rcnn. Pretrain weights contain prefix of 'backbone' # while res5 module is located in bbox_head.head. Replace the prefix of # res5 with 'bbox_head.head' to load pretrain weights correctly. for k in param_state_dict.keys(): diff --git a/paddlers/utils/download.py b/paddlers/utils/download.py index 69340f4..1114960 100644 --- a/paddlers/utils/download.py +++ b/paddlers/utils/download.py @@ -98,7 +98,7 @@ def download(url, path, md5sum=None): # For protecting download interupted, download to # tmp_fullname firstly, move tmp_fullname to fullname - # after download finished + # after download finished. 
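The tmp-then-rename pattern described here generalizes; a minimal sketch (target path and payload are hypothetical):

```python
# Write-to-temp-then-move: an interrupted download can only ever leave a
# *_tmp file behind, never a truncated file at the final path.
import shutil

def atomic_write(fullname: str, data: bytes) -> None:
    tmp_fullname = fullname + "_tmp"
    with open(tmp_fullname, "wb") as f:
        f.write(data)
    shutil.move(tmp_fullname, fullname)  # reached only after a complete write

atomic_write("weights.pdparams", b"...")  # hypothetical path and payload
```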
tmp_fullname = fullname + "_tmp" total_size = req.headers.get('content-length') with open(tmp_fullname, 'wb') as f: @@ -181,8 +181,7 @@ def download_and_decompress(url, path='.'): local_rank = paddle.distributed.get_rank() fname = osp.split(url)[-1] fullname = osp.join(path, fname) - # if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')): - # fullname = osp.join(path, fname.split('.')[0]) + if nranks <= 1: dst_dir = url2dir(url, path) if dst_dir is not None: diff --git a/paddlers/utils/utils.py b/paddlers/utils/utils.py index cc30be3..692a1c6 100644 --- a/paddlers/utils/utils.py +++ b/paddlers/utils/utils.py @@ -154,9 +154,9 @@ class DisablePrint(object): class Times(object): def __init__(self): self.time = 0. - # start time + # Start time self.st = 0. - # end time + # End time self.et = 0. def start(self): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index f5b24d4..47e008a 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -82,7 +82,7 @@ class _CommonTestNamespace: def wrapper(self, *args, **kwargs): with warnings.catch_warnings(record=True) as w: warnings.resetwarnings() - # ignore specified warnings + # Ignore specified warnings warning_white_list = [UserWarning] for warning in warning_white_list: warnings.simplefilter("ignore", warning) diff --git a/tools/coco2mask.py b/tools/coco2mask.py index fae007d..b6d4e6c 100644 --- a/tools/coco2mask.py +++ b/tools/coco2mask.py @@ -96,7 +96,7 @@ def convert_data(raw_dir, end_dir): shutil.copy(img_path, img_save_path) if k in anns.keys(): _save_mask(anns[k], sizes[k], lab_save_path) - else: # have not anns + else: _save_palette(np.zeros(sizes[k], dtype="uint8"), \ lab_save_path) diff --git a/tools/coco_tools/json_AnnoSta.py b/tools/coco_tools/json_AnnoSta.py index cf94c6c..94d119f 100644 --- a/tools/coco_tools/json_AnnoSta.py +++ b/tools/coco_tools/json_AnnoSta.py @@ -166,7 +166,7 @@ def get_args(): parser = argparse.ArgumentParser( description='Json Images Infomation Statistic') - # parameters + # Parameters parser.add_argument( '--json_path', type=str, diff --git a/tools/coco_tools/json_Img2Json.py b/tools/coco_tools/json_Img2Json.py index ea54c24..076aa25 100644 --- a/tools/coco_tools/json_Img2Json.py +++ b/tools/coco_tools/json_Img2Json.py @@ -63,7 +63,7 @@ def js_test(test_image_path, js_train_path, js_test_path, image_keyname, def get_args(): parser = argparse.ArgumentParser(description='Get Test Json') - # parameters + # Parameters parser.add_argument('--test_image_path', type=str, help='test image path') parser.add_argument( '--json_train_path', diff --git a/tools/coco_tools/json_ImgSta.py b/tools/coco_tools/json_ImgSta.py index 8ac86f1..a9e8752 100644 --- a/tools/coco_tools/json_ImgSta.py +++ b/tools/coco_tools/json_ImgSta.py @@ -74,7 +74,7 @@ def get_args(): parser = argparse.ArgumentParser( description='Json Images Infomation Statistic') - # parameters + # Parameters parser.add_argument( '--json_path', type=str, diff --git a/tools/coco_tools/json_InfoShow.py b/tools/coco_tools/json_InfoShow.py index 40e1964..436649e 100644 --- a/tools/coco_tools/json_InfoShow.py +++ b/tools/coco_tools/json_InfoShow.py @@ -51,7 +51,7 @@ def js_show(js_path, show_num): def get_args(): parser = argparse.ArgumentParser(description='Json Infomation Show') - # parameters + # Parameters parser.add_argument( '--json_path', type=str, help='json path to show information') parser.add_argument( diff --git a/tools/coco_tools/json_Merge.py b/tools/coco_tools/json_Merge.py index dfc863a..cf43914 100644 --- a/tools/coco_tools/json_Merge.py 
+++ b/tools/coco_tools/json_Merge.py @@ -56,7 +56,7 @@ def js_merge(js1_path, js2_path, js_merge_path, merge_keys): def get_args(): parser = argparse.ArgumentParser(description='Json Merge') - # parameters + # Parameters parser.add_argument('--json1_path', type=str, help='json path1 to merge') parser.add_argument('--json2_path', type=str, help='json path2 to merge') parser.add_argument( diff --git a/tools/coco_tools/json_Split.py b/tools/coco_tools/json_Split.py index 45342cf..a343eec 100644 --- a/tools/coco_tools/json_Split.py +++ b/tools/coco_tools/json_Split.py @@ -87,7 +87,7 @@ def js_split(js_all_path, js_train_path, js_val_path, val_split_rate, def get_args(): parser = argparse.ArgumentParser(description='Json Merge') - # parameters + # Parameters parser.add_argument('--json_all_path', type=str, help='json path to split') parser.add_argument( '--json_train_path', diff --git a/tools/raster2geotiff.py b/tools/raster2geotiff.py index d7614c2..25646b6 100644 --- a/tools/raster2geotiff.py +++ b/tools/raster2geotiff.py @@ -25,7 +25,8 @@ from utils import Raster, save_geotiff, translate_vector, time_it def _gt_convert(x_geo, y_geo, geotf): a = np.array([[geotf[1], geotf[2]], [geotf[4], geotf[5]]]) b = np.array([x_geo - geotf[0], y_geo - geotf[3]]) - return np.round(np.linalg.solve(a, b)).tolist() # 解一元二次方程 + return np.round(np.linalg.solve(a, + b)).tolist() # Solve a quadratic equation @time_it @@ -36,13 +37,13 @@ def convert_data(image_path, geojson_path): # vector to EPSG from raster temp_geojson_path = translate_vector(geojson_path, raster.proj) geo_reader = codecs.open(temp_geojson_path, "r", encoding="utf-8") - feats = geojson.loads(geo_reader.read())["features"] # 所有图像块 + feats = geojson.loads(geo_reader.read())["features"] # All image patches geo_reader.close() for feat in tqdm(feats): geo = feat["geometry"] - if geo["type"] == "Polygon": # 多边形 + if geo["type"] == "Polygon": geo_points = geo["coordinates"][0] - elif geo["type"] == "MultiPolygon": # 多面 + elif geo["type"] == "MultiPolygon": geo_points = geo["coordinates"][0][0] else: raise TypeError( @@ -52,7 +53,7 @@ def convert_data(image_path, geojson_path): _gt_convert(point[0], point[1], raster.geot) for point in geo_points ]).astype(np.int32) # TODO: Label category - cv2.fillPoly(tmp_img, [xy_points], 1) # 多边形填充 + cv2.fillPoly(tmp_img, [xy_points], 1) # Fill with polygons ext = "." 
+ geojson_path.split(".")[-1] save_geotiff(tmp_img, geojson_path.replace(ext, ".tif"), raster.proj, raster.geot) diff --git a/tools/utils/__init__.py b/tools/utils/__init__.py index 8f499e8..065f7d7 100644 --- a/tools/utils/__init__.py +++ b/tools/utils/__init__.py @@ -14,7 +14,7 @@ import sys import os.path as osp -sys.path.insert(0, osp.abspath("..")) # add workspace +sys.path.insert(0, osp.abspath("..")) # Add workspace from .raster import Raster, raster2uint8, save_geotiff from .vector import translate_vector From 3f8ce38fb042e545beaa9b901f56e025c9869477 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Wed, 17 Aug 2022 13:46:50 +0800 Subject: [PATCH 2/2] [Feat] Add Interface to Set Losses (#18) * Add interfaces to set losses * Fix typo * Fix import bugs * import custom_models->rs_models --- paddlers/models/__init__.py | 3 ++ paddlers/tasks/change_detector.py | 49 ++++++++++++++++++++++++------- paddlers/tasks/classifier.py | 35 +++++++++++++++++----- paddlers/tasks/segmenter.py | 42 +++++++++++++++++++------- 4 files changed, 101 insertions(+), 28 deletions(-) diff --git a/paddlers/models/__init__.py b/paddlers/models/__init__.py index 345e589..952821f 100644 --- a/paddlers/models/__init__.py +++ b/paddlers/models/__init__.py @@ -13,3 +13,6 @@ # limitations under the License. from . import ppcls, ppdet, ppseg, ppgan +import paddlers.models.ppseg.models.losses as seg_losses +import paddlers.models.ppdet.modeling.losses as det_losses +import paddlers.models.ppcls.loss as clas_losses diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index c6bd1e9..9d631e6 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -28,6 +28,7 @@ import paddlers import paddlers.models.ppseg as ppseg import paddlers.rs_models.cd as cmcd import paddlers.utils.logging as logging +from paddlers.models import seg_losses from paddlers.transforms import Resize, decode_image from paddlers.utils import get_single_card_bs, DisablePrint from paddlers.utils.checkpoint import seg_pretrain_weights_dict @@ -45,6 +46,7 @@ class BaseChangeDetector(BaseModel): model_name, num_classes=2, use_mixed_loss=False, + losses=None, **params): self.init_params = locals() if 'with_net' in self.init_params: @@ -56,7 +58,7 @@ class BaseChangeDetector(BaseModel): self.model_name = model_name self.num_classes = num_classes self.use_mixed_loss = use_mixed_loss - self.losses = None + self.losses = losses self.labels = None if params.get('with_net', True): params.pop('with_net', None) @@ -178,13 +180,13 @@ class BaseChangeDetector(BaseModel): if isinstance(self.use_mixed_loss, bool): if self.use_mixed_loss: losses = [ - ppseg.models.CrossEntropyLoss(), - ppseg.models.LovaszSoftmaxLoss() + seg_losses.CrossEntropyLoss(), + seg_losses.LovaszSoftmaxLoss() ] coef = [.8, .2] - loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ] + loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ] else: - loss_type = [ppseg.models.CrossEntropyLoss()] + loss_type = [seg_losses.CrossEntropyLoss()] else: losses, coef = list(zip(*self.use_mixed_loss)) if not set(losses).issubset( @@ -192,8 +194,8 @@ class BaseChangeDetector(BaseModel): raise ValueError( "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." 
) - losses = [getattr(ppseg.models, loss)() for loss in losses] - loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))] + losses = [getattr(seg_losses, loss)() for loss in losses] + loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))] loss_coef = [1.0] losses = {'types': loss_type, 'coef': loss_coef} return losses @@ -810,11 +812,17 @@ class BaseChangeDetector(BaseModel): raise TypeError( "`transforms.arrange` must be an ArrangeChangeDetector object.") + def set_losses(self, losses, weights=None): + if weights is None: + weights = [1. for _ in range(len(losses))] + self.losses = {'types': losses, 'coef': weights} + class CDNet(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=6, **params): params.update({'in_channels': in_channels}) @@ -822,6 +830,7 @@ class CDNet(BaseChangeDetector): model_name='CDNet', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -829,6 +838,7 @@ class FCEarlyFusion(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=6, use_dropout=False, **params): @@ -837,6 +847,7 @@ class FCEarlyFusion(BaseChangeDetector): model_name='FCEarlyFusion', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -844,6 +855,7 @@ class FCSiamConc(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, use_dropout=False, **params): @@ -852,6 +864,7 @@ class FCSiamConc(BaseChangeDetector): model_name='FCSiamConc', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -859,6 +872,7 @@ class FCSiamDiff(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, use_dropout=False, **params): @@ -867,6 +881,7 @@ class FCSiamDiff(BaseChangeDetector): model_name='FCSiamDiff', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -874,6 +889,7 @@ class STANet(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, att_type='BAM', ds_factor=1, @@ -887,6 +903,7 @@ class STANet(BaseChangeDetector): model_name='STANet', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -894,6 +911,7 @@ class BIT(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, backbone='resnet18', n_stages=4, @@ -925,6 +943,7 @@ class BIT(BaseChangeDetector): model_name='BIT', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -932,6 +951,7 @@ class SNUNet(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, width=32, **params): @@ -940,6 +960,7 @@ class SNUNet(BaseChangeDetector): model_name='SNUNet', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -947,6 +968,7 @@ class DSIFN(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, use_dropout=False, **params): params.update({'use_dropout': use_dropout}) @@ -954,13 +976,14 @@ class DSIFN(BaseChangeDetector): model_name='DSIFN', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) def default_loss(self): if self.use_mixed_loss is False: return { # XXX: make sure the shallow copy works correctly here. 
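A hedged usage sketch for the set_losses() interface introduced in this commit; the model class and the loss combination are illustrative, not prescriptive:

```python
# Usage sketch for the new set_losses() interface (model and losses are
# illustrative; assumes PaddleRS with this patch applied).
from paddlers.models import seg_losses
from paddlers.tasks.change_detector import BIT

model = BIT(num_classes=2)
model.set_losses(
    losses=[seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss()],
    weights=[1.0, 0.5])
# With weights omitted, each loss defaults to a coefficient of 1.0.
```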
- 'types': [ppseg.models.CrossEntropyLoss()] * 5, + 'types': [seg_losses.CrossEntropyLoss()] * 5, 'coef': [1.0] * 5 } else: @@ -973,6 +996,7 @@ class DSAMNet(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, in_channels=3, ca_ratio=8, sa_kernel=7, @@ -986,14 +1010,15 @@ class DSAMNet(BaseChangeDetector): model_name='DSAMNet', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) def default_loss(self): if self.use_mixed_loss is False: return { 'types': [ - ppseg.models.CrossEntropyLoss(), ppseg.models.DiceLoss(), - ppseg.models.DiceLoss() + seg_losses.CrossEntropyLoss(), seg_losses.DiceLoss(), + seg_losses.DiceLoss() ], 'coef': [1.0, 0.05, 0.05] } @@ -1007,6 +1032,7 @@ class ChangeStar(BaseChangeDetector): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, mid_channels=256, inner_channels=16, num_convs=4, @@ -1022,13 +1048,14 @@ class ChangeStar(BaseChangeDetector): model_name='ChangeStar', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) def default_loss(self): if self.use_mixed_loss is False: return { # XXX: make sure the shallow copy works correctly here. - 'types': [ppseg.models.CrossEntropyLoss()] * 4, + 'types': [seglosses.CrossEntropyLoss()] * 4, 'coef': [1.0] * 4 } else: diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index ad39a5d..1fdecd0 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -28,7 +28,7 @@ import paddlers.rs_models.clas as cmcls import paddlers.utils.logging as logging from paddlers.utils import get_single_card_bs, DisablePrint from paddlers.models.ppcls.metric import build_metrics -from paddlers.models.ppcls.loss import build_loss +from paddlers.models import clas_losses from paddlers.models.ppcls.data.postprocess import build_postprocess from paddlers.utils.checkpoint import cls_pretrain_weights_dict from paddlers.transforms import Resize, decode_image @@ -45,6 +45,7 @@ class BaseClassifier(BaseModel): in_channels=3, num_classes=2, use_mixed_loss=False, + losses=None, **params): self.init_params = locals() if 'with_net' in self.init_params: @@ -59,7 +60,7 @@ class BaseClassifier(BaseModel): self.num_classes = num_classes self.use_mixed_loss = use_mixed_loss self.metrics = None - self.losses = None + self.losses = losses self.labels = None self._postprocess = None if params.get('with_net', True): @@ -145,7 +146,7 @@ class BaseClassifier(BaseModel): def default_loss(self): # TODO: use mixed loss and other loss default_config = [{"CELoss": {"weight": 1.0}}] - return build_loss(default_config) + return clas_losses.build_loss(default_config) def default_optimizer(self, parameters, @@ -556,36 +557,56 @@ class BaseClassifier(BaseModel): class ResNet50_vd(BaseClassifier): - def __init__(self, num_classes=2, use_mixed_loss=False, **params): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): super(ResNet50_vd, self).__init__( model_name='ResNet50_vd', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) class MobileNetV3_small_x1_0(BaseClassifier): - def __init__(self, num_classes=2, use_mixed_loss=False, **params): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): super(MobileNetV3_small_x1_0, self).__init__( model_name='MobileNetV3_small_x1_0', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) class HRNet_W18_C(BaseClassifier): - def __init__(self, 
num_classes=2, use_mixed_loss=False, **params): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): super(HRNet_W18_C, self).__init__( model_name='HRNet_W18_C', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) class CondenseNetV2_b(BaseClassifier): - def __init__(self, num_classes=2, use_mixed_loss=False, **params): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): super(CondenseNetV2_b, self).__init__( model_name='CondenseNetV2_b', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index da17a5c..335d05c 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -26,10 +26,11 @@ from paddle.static import InputSpec import paddlers import paddlers.models.ppseg as ppseg import paddlers.rs_models.seg as cmseg -from paddlers.utils import get_single_card_bs, DisablePrint import paddlers.utils.logging as logging -from paddlers.utils.checkpoint import seg_pretrain_weights_dict +from paddlers.models import seg_losses from paddlers.transforms import Resize, decode_image +from paddlers.utils import get_single_card_bs, DisablePrint +from paddlers.utils.checkpoint import seg_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics @@ -41,6 +42,7 @@ class BaseSegmenter(BaseModel): model_name, num_classes=2, use_mixed_loss=False, + losses=None, **params): self.init_params = locals() if 'with_net' in self.init_params: @@ -53,7 +55,7 @@ class BaseSegmenter(BaseModel): self.model_name = model_name self.num_classes = num_classes self.use_mixed_loss = use_mixed_loss - self.losses = None + self.losses = losses self.labels = None if params.get('with_net', True): params.pop('with_net', None) @@ -160,13 +162,13 @@ class BaseSegmenter(BaseModel): if isinstance(self.use_mixed_loss, bool): if self.use_mixed_loss: losses = [ - ppseg.models.CrossEntropyLoss(), - ppseg.models.LovaszSoftmaxLoss() + seg_losses.CrossEntropyLoss(), + seg_losses.LovaszSoftmaxLoss() ] coef = [.8, .2] - loss_type = [ppseg.models.MixedLoss(losses=losses, coef=coef), ] + loss_type = [seg_losses.MixedLoss(losses=losses, coef=coef), ] else: - loss_type = [ppseg.models.CrossEntropyLoss()] + loss_type = [seg_losses.CrossEntropyLoss()] else: losses, coef = list(zip(*self.use_mixed_loss)) if not set(losses).issubset( @@ -174,8 +176,8 @@ class BaseSegmenter(BaseModel): raise ValueError( "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported." ) - losses = [getattr(ppseg.models, loss)() for loss in losses] - loss_type = [ppseg.models.MixedLoss(losses=losses, coef=list(coef))] + losses = [getattr(seg_losses, loss)() for loss in losses] + loss_type = [seg_losses.MixedLoss(losses=losses, coef=list(coef))] if self.model_name == 'FastSCNN': loss_type *= 2 loss_coef = [1.0, 0.4] @@ -771,12 +773,18 @@ class BaseSegmenter(BaseModel): raise TypeError( "`transforms.arrange` must be an ArrangeSegmenter object.") + def set_losses(self, losses, weights=None): + if weights is None: + weights = [1. 
for _ in range(len(losses))] + self.losses = {'types': losses, 'coef': weights} + class UNet(BaseSegmenter): def __init__(self, input_channel=3, num_classes=2, use_mixed_loss=False, + losses=None, use_deconv=False, align_corners=False, **params): @@ -789,6 +797,7 @@ class UNet(BaseSegmenter): input_channel=input_channel, num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -798,6 +807,7 @@ class DeepLabV3P(BaseSegmenter): num_classes=2, backbone='ResNet50_vd', use_mixed_loss=False, + losses=None, output_stride=8, backbone_indices=(0, 3), aspp_ratios=(1, 12, 24, 36), @@ -826,6 +836,7 @@ class DeepLabV3P(BaseSegmenter): model_name='DeepLabV3P', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -833,6 +844,7 @@ class FastSCNN(BaseSegmenter): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, align_corners=False, **params): params.update({'align_corners': align_corners}) @@ -840,6 +852,7 @@ class FastSCNN(BaseSegmenter): model_name='FastSCNN', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) @@ -848,6 +861,7 @@ class HRNet(BaseSegmenter): num_classes=2, width=48, use_mixed_loss=False, + losses=None, align_corners=False, **params): if width not in (18, 48): @@ -867,6 +881,7 @@ class HRNet(BaseSegmenter): model_name='FCN', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) self.model_name = 'HRNet' @@ -875,6 +890,7 @@ class BiSeNetV2(BaseSegmenter): def __init__(self, num_classes=2, use_mixed_loss=False, + losses=None, align_corners=False, **params): params.update({'align_corners': align_corners}) @@ -882,13 +898,19 @@ class BiSeNetV2(BaseSegmenter): model_name='BiSeNetV2', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params) class FarSeg(BaseSegmenter): - def __init__(self, num_classes=2, use_mixed_loss=False, **params): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + losses=None, + **params): super(FarSeg, self).__init__( model_name='FarSeg', num_classes=num_classes, use_mixed_loss=use_mixed_loss, + losses=losses, **params)
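Finally, a sketch of how a use_mixed_loss specification is folded into the {'types', 'coef'} dictionary by default_loss(), mirroring the logic in this patch (the loss names are illustrative):

```python
# Mirrors default_loss(): (name, coef) pairs become one MixedLoss entry
# wrapped in the {'types', 'coef'} dict the trainer consumes.
from paddlers.models import seg_losses

use_mixed_loss = [("CrossEntropyLoss", 0.8), ("DiceLoss", 0.2)]
names, coef = zip(*use_mixed_loss)
losses = [getattr(seg_losses, name)() for name in names]
loss_dict = {
    "types": [seg_losses.MixedLoss(losses=losses, coef=list(coef))],
    "coef": [1.0],
}
print(loss_dict["coef"])  # [1.0]
```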