parent 0f2d3c1bfe
commit 4b4d2675b2
40 changed files with 1351 additions and 54 deletions
@ -1,5 +1,5 @@ |
||||
from . import builtin # ensure the builtin datasets are registered |
||||
# from .dataset_mapper import DatasetMapperWithBasis |
||||
from .dataset_mapper import DatasetMapperWithBasis |
||||
|
||||
|
||||
# __all__ = ["DatasetMapperWithBasis"] |
||||
__all__ = ["DatasetMapperWithBasis"] |
||||
|
@ -0,0 +1,141 @@ |
||||
import copy |
||||
import numpy as np |
||||
import torch |
||||
from fvcore.common.file_io import PathManager |
||||
from PIL import Image |
||||
|
||||
from detectron2.data.dataset_mapper import DatasetMapper |
||||
from detectron2.data.detection_utils import SizeMismatchError |
||||
from detectron2.data import detection_utils as utils |
||||
from detectron2.data import transforms as T |
||||
|
||||
""" |
||||
This file contains the mapper that augments the default "dataset dicts" mapping with basis semantic labels. |
||||
""" |
||||
|
||||
__all__ = ["DatasetMapperWithBasis"] |
||||
|
||||
|
||||
class DatasetMapperWithBasis(DatasetMapper): |
||||
""" |
||||
This class extends the default Detectron2 mapper so that it also reads an additional basis semantic label map. |
||||
""" |
||||
|
||||
def __init__(self, cfg, is_train=True): |
||||
super().__init__(cfg, is_train) |
||||
|
||||
# fmt: off |
||||
self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON |
||||
self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET |
||||
# fmt: on |
||||
|
||||
def __call__(self, dataset_dict): |
||||
""" |
||||
Args: |
||||
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. |
||||
|
||||
Returns: |
||||
dict: a format that builtin models in detectron2 accept |
||||
""" |
||||
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below |
||||
# USER: Write your own image loading if it's not from a file |
||||
try: |
||||
image = utils.read_image(dataset_dict["file_name"], format=self.img_format) |
||||
except Exception as e: |
||||
print(dataset_dict["file_name"]) |
||||
print(e) |
||||
raise e |
||||
try: |
||||
utils.check_image_size(dataset_dict, image) |
||||
except SizeMismatchError as e: |
||||
expected_wh = (dataset_dict["width"], dataset_dict["height"]) |
||||
image_wh = (image.shape[1], image.shape[0]) |
||||
if (image_wh[1], image_wh[0]) == expected_wh: |
||||
print("transposing image {}".format(dataset_dict["file_name"])) |
||||
image = image.transpose(1, 0, 2) |
||||
else: |
||||
raise e |
||||
|
||||
if "annotations" not in dataset_dict or len(dataset_dict["annotations"]) == 0: |
||||
image, transforms = T.apply_transform_gens( |
||||
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image |
||||
) |
||||
else: |
||||
# Crop around an instance if there are instances in the image. |
||||
# USER: Remove if you don't use cropping |
||||
if self.crop_gen: |
||||
crop_tfm = utils.gen_crop_transform_with_instance( |
||||
self.crop_gen.get_crop_size(image.shape[:2]), |
||||
image.shape[:2], |
||||
np.random.choice(dataset_dict["annotations"]), |
||||
) |
||||
image = crop_tfm.apply_image(image) |
||||
image, transforms = T.apply_transform_gens(self.tfm_gens, image) |
||||
if self.crop_gen: |
||||
transforms = crop_tfm + transforms |
||||
|
||||
image_shape = image.shape[:2] # h, w |
||||
|
||||
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, |
||||
# but not efficient on large generic data structures due to the use of pickle & mp.Queue. |
||||
# Therefore it's important to use torch.Tensor. |
||||
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32")) |
||||
# Can use uint8 if it turns out to be slow some day |
||||
|
||||
# USER: Remove if you don't use pre-computed proposals. |
||||
if self.load_proposals: |
||||
utils.transform_proposals( |
||||
dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk |
||||
) |
||||
|
||||
if not self.is_train: |
||||
dataset_dict.pop("annotations", None) |
||||
dataset_dict.pop("sem_seg_file_name", None) |
||||
dataset_dict.pop("pano_seg_file_name", None) |
||||
return dataset_dict |
||||
|
||||
if "annotations" in dataset_dict: |
||||
# USER: Modify this if you want to keep them for some reason. |
||||
for anno in dataset_dict["annotations"]: |
||||
if not self.mask_on: |
||||
anno.pop("segmentation", None) |
||||
if not self.keypoint_on: |
||||
anno.pop("keypoints", None) |
||||
|
||||
# USER: Implement additional transformations if you have other types of data |
||||
annos = [ |
||||
utils.transform_instance_annotations( |
||||
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices |
||||
) |
||||
for obj in dataset_dict.pop("annotations") |
||||
if obj.get("iscrowd", 0) == 0 |
||||
] |
||||
instances = utils.annotations_to_instances( |
||||
annos, image_shape, mask_format=self.mask_format |
||||
) |
||||
# Create a tight bounding box from masks, useful when image is cropped |
||||
if self.crop_gen and instances.has("gt_masks"): |
||||
instances.gt_boxes = instances.gt_masks.get_bounding_boxes() |
||||
dataset_dict["instances"] = utils.filter_empty_instances(instances) |
||||
|
||||
# USER: Remove if you don't do semantic/panoptic segmentation. |
||||
if "sem_seg_file_name" in dataset_dict: |
||||
with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f: |
||||
sem_seg_gt = Image.open(f) |
||||
sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8") |
||||
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) |
||||
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) |
||||
dataset_dict["sem_seg"] = sem_seg_gt |
||||
|
||||
if self.basis_loss_on and self.is_train: |
||||
# load basis supervisions |
||||
if self.ann_set == "coco": |
||||
basis_sem_path = dataset_dict["file_name"].replace('train2017', 'thing_train2017').replace('image/train', 'thing_train') |
||||
else: |
||||
basis_sem_path = dataset_dict["file_name"].replace('coco', 'lvis').replace('train2017', 'thing_train') |
||||
basis_sem_path = basis_sem_path.replace('jpg', 'npz') |
||||
basis_sem_gt = np.load(basis_sem_path)["mask"] |
||||
basis_sem_gt = transforms.apply_segmentation(basis_sem_gt) |
||||
basis_sem_gt = torch.as_tensor(basis_sem_gt.astype("long")) |
||||
dataset_dict["basis_sem"] = basis_sem_gt |
||||
return dataset_dict |
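A minimal usage sketch for the mapper above, assuming detectron2's standard `build_detection_train_loader(cfg, mapper=...)` API and the `adet.data.dataset_mapper` module path added in this commit:

```python
from detectron2.data import build_detection_train_loader

from adet.data.dataset_mapper import DatasetMapperWithBasis


def build_train_loader(cfg):
    # The mapper reads MODEL.BASIS_MODULE.LOSS_ON / ANN_SET from cfg and,
    # during training, attaches the extra "basis_sem" tensor to each dict.
    mapper = DatasetMapperWithBasis(cfg, is_train=True)
    return build_detection_train_loader(cfg, mapper=mapper)
```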
@ -0,0 +1,2 @@ |
||||
from .basis_module import build_basis_module |
||||
from .blendmask import BlendMask |
@ -0,0 +1,104 @@ |
||||
from typing import Dict |
||||
from torch import nn |
||||
from torch.nn import functional as F |
||||
|
||||
from detectron2.utils.registry import Registry |
||||
from detectron2.layers import ShapeSpec |
||||
|
||||
from adet.layers import conv_with_kaiming_uniform |
||||
|
||||
|
||||
BASIS_MODULE_REGISTRY = Registry("BASIS_MODULE") |
||||
BASIS_MODULE_REGISTRY.__doc__ = """ |
||||
Registry for basis module, which produces global bases from feature maps. |
||||
|
||||
The registered object will be called with `obj(cfg, input_shape)`. |
||||
The call should return a `nn.Module` object. |
||||
""" |
||||
|
||||
|
||||
def build_basis_module(cfg, input_shape): |
||||
name = cfg.MODEL.BASIS_MODULE.NAME |
||||
return BASIS_MODULE_REGISTRY.get(name)(cfg, input_shape) |
||||
|
||||
|
||||
@BASIS_MODULE_REGISTRY.register() |
||||
class ProtoNet(nn.Module): |
||||
def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): |
||||
""" |
||||
TODO: support deconv and variable channel width |
||||
""" |
||||
# official protonet has a relu after each conv |
||||
super().__init__() |
||||
# fmt: off |
||||
mask_dim = cfg.MODEL.BASIS_MODULE.NUM_BASES |
||||
planes = cfg.MODEL.BASIS_MODULE.CONVS_DIM |
||||
self.in_features = cfg.MODEL.BASIS_MODULE.IN_FEATURES |
||||
self.loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON |
||||
norm = cfg.MODEL.BASIS_MODULE.NORM |
||||
num_convs = cfg.MODEL.BASIS_MODULE.NUM_CONVS |
||||
self.visualize = cfg.MODEL.BLENDMASK.VISUALIZE |
||||
# fmt: on |
||||
|
||||
feature_channels = {k: v.channels for k, v in input_shape.items()} |
||||
|
||||
conv_block = conv_with_kaiming_uniform(norm, True) # conv relu bn |
||||
self.refine = nn.ModuleList() |
||||
for in_feature in self.in_features: |
||||
self.refine.append(conv_block( |
||||
feature_channels[in_feature], planes, 3, 1)) |
||||
tower = [] |
||||
for i in range(num_convs): |
||||
tower.append( |
||||
conv_block(planes, planes, 3, 1)) |
||||
tower.append( |
||||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)) |
||||
tower.append( |
||||
conv_block(planes, planes, 3, 1)) |
||||
tower.append( |
||||
nn.Conv2d(planes, mask_dim, 1)) |
||||
self.add_module('tower', nn.Sequential(*tower)) |
||||
|
||||
if self.loss_on: |
||||
# fmt: off |
||||
self.common_stride = cfg.MODEL.BASIS_MODULE.COMMON_STRIDE |
||||
num_classes = cfg.MODEL.BASIS_MODULE.NUM_CLASSES + 1 |
||||
self.sem_loss_weight = cfg.MODEL.BASIS_MODULE.LOSS_WEIGHT |
||||
# fmt: on |
||||
|
||||
inplanes = feature_channels[self.in_features[0]] |
||||
self.seg_head = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=3, |
||||
stride=1, padding=1, bias=False), |
||||
nn.BatchNorm2d(planes), |
||||
nn.ReLU(), |
||||
nn.Conv2d(planes, planes, kernel_size=3, |
||||
stride=1, padding=1, bias=False), |
||||
nn.BatchNorm2d(planes), |
||||
nn.ReLU(), |
||||
nn.Conv2d(planes, num_classes, kernel_size=1, |
||||
stride=1)) |
||||
|
||||
def forward(self, features, targets=None): |
||||
for i, f in enumerate(self.in_features): |
||||
if i == 0: |
||||
x = self.refine[i](features[f]) |
||||
else: |
||||
x_p = self.refine[i](features[f]) |
||||
x_p = F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False) |
||||
# x_p = aligned_bilinear(x_p, x.size(3) // x_p.size(3)) |
||||
x = x + x_p |
||||
outputs = {"bases": [self.tower(x)]} |
||||
losses = {} |
||||
# auxiliary thing semantic loss |
||||
if self.training and self.loss_on: |
||||
sem_out = self.seg_head(features[self.in_features[0]]) |
||||
# resize target to reduce memory |
||||
gt_sem = targets.unsqueeze(1).float() |
||||
gt_sem = F.interpolate( |
||||
gt_sem, scale_factor=1 / self.common_stride) |
||||
seg_loss = F.cross_entropy( |
||||
sem_out, gt_sem.squeeze().long()) |
||||
losses['loss_basis_sem'] = seg_loss * self.sem_loss_weight |
||||
elif self.visualize and hasattr(self, "seg_head"): |
||||
outputs["seg_thing_out"] = self.seg_head(features[self.in_features[0]]) |
||||
return outputs, losses |
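The registry docstring above describes the `obj(cfg, input_shape)` contract; a hedged sketch of plugging an alternative basis module into `BASIS_MODULE_REGISTRY`, assuming the module lives at `adet.modeling.blendmask.basis_module` as in this commit's layout. The toy class is illustrative; only the registration pattern and the output contract mirror `ProtoNet`:

```python
from typing import Dict

from torch import nn

from detectron2.layers import ShapeSpec

from adet.modeling.blendmask.basis_module import BASIS_MODULE_REGISTRY


@BASIS_MODULE_REGISTRY.register()
class TinyBasisModule(nn.Module):
    """Toy basis module: a single 1x1 conv on the first input feature."""

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        self.in_feature = cfg.MODEL.BASIS_MODULE.IN_FEATURES[0]
        in_channels = input_shape[self.in_feature].channels
        self.proj = nn.Conv2d(in_channels, cfg.MODEL.BASIS_MODULE.NUM_BASES, 1)

    def forward(self, features, targets=None):
        # Same output contract as ProtoNet: ({"bases": [tensor]}, losses)
        return {"bases": [self.proj(features[self.in_feature])]}, {}


# Selecting it is then a config change (MODEL.BASIS_MODULE.NAME: "TinyBasisModule"),
# after which build_basis_module(cfg, input_shape) returns an instance of this class.
```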
@ -0,0 +1,110 @@ |
||||
import torch |
||||
from torch.nn import functional as F |
||||
|
||||
from detectron2.layers import cat |
||||
from detectron2.modeling.poolers import ROIPooler |
||||
|
||||
|
||||
def build_blender(cfg): |
||||
return Blender(cfg) |
||||
|
||||
|
||||
class Blender(object): |
||||
def __init__(self, cfg): |
||||
|
||||
# fmt: off |
||||
self.pooler_resolution = cfg.MODEL.BLENDMASK.BOTTOM_RESOLUTION |
||||
sampling_ratio = cfg.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO |
||||
pooler_type = cfg.MODEL.BLENDMASK.POOLER_TYPE |
||||
pooler_scales = cfg.MODEL.BLENDMASK.POOLER_SCALES |
||||
self.attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE |
||||
self.top_interp = cfg.MODEL.BLENDMASK.TOP_INTERP |
||||
num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES |
||||
# fmt: on |
||||
|
||||
self.attn_len = num_bases * self.attn_size * self.attn_size |
||||
|
||||
self.pooler = ROIPooler( |
||||
output_size=self.pooler_resolution, |
||||
scales=pooler_scales, |
||||
sampling_ratio=sampling_ratio, |
||||
pooler_type=pooler_type, |
||||
canonical_level=2) |
||||
|
||||
def __call__(self, bases, proposals, gt_instances): |
||||
if gt_instances is not None: |
||||
# training |
||||
# reshape attns |
||||
extras = proposals["extras"] |
||||
attns = proposals["top_feats"] |
||||
pos_inds = extras["pos_inds"] |
||||
if pos_inds.numel() == 0: |
||||
return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0} |
||||
|
||||
gt_inds = extras["gt_inds"] |
||||
attns = cat( |
||||
[ |
||||
# Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C) |
||||
x.permute(0, 2, 3, 1).reshape(-1, self.attn_len) |
||||
for x in attns |
||||
], dim=0,) |
||||
attns = attns[pos_inds] |
||||
|
||||
rois = self.pooler(bases, [x.gt_boxes for x in gt_instances]) |
||||
rois = rois[gt_inds] |
||||
pred_mask_logits = self.merge_bases(rois, attns) |
||||
|
||||
# gen targets |
||||
gt_masks = [] |
||||
for instances_per_image in gt_instances: |
||||
if len(instances_per_image.gt_boxes.tensor) == 0: |
||||
continue |
||||
gt_mask_per_image = instances_per_image.gt_masks.crop_and_resize( |
||||
instances_per_image.gt_boxes.tensor, self.pooler_resolution |
||||
).to(device=pred_mask_logits.device) |
||||
gt_masks.append(gt_mask_per_image) |
||||
gt_masks = cat(gt_masks, dim=0) |
||||
gt_masks = gt_masks[gt_inds] |
||||
N = gt_masks.size(0) |
||||
gt_masks = gt_masks.view(N, -1) |
||||
|
||||
gt_ctr = extras["gt_ctr"] |
||||
loss_denorm = extras["loss_denorm"] |
||||
mask_losses = F.binary_cross_entropy_with_logits( |
||||
pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none") |
||||
mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum() |
||||
/ loss_denorm) |
||||
return None, {"loss_mask": mask_loss} |
||||
else: |
||||
# no proposals |
||||
total_instances = sum([len(x) for x in proposals]) |
||||
if total_instances == 0: |
||||
# add empty pred_masks results |
||||
for box in proposals: |
||||
box.pred_masks = box.pred_classes.view( |
||||
-1, 1, self.pooler_resolution, self.pooler_resolution) |
||||
return proposals, {} |
||||
rois = self.pooler(bases, [x.pred_boxes for x in proposals]) |
||||
attns = cat([x.top_feat for x in proposals], dim=0) |
||||
pred_mask_logits = self.merge_bases(rois, attns).sigmoid() |
||||
pred_mask_logits = pred_mask_logits.view( |
||||
-1, 1, self.pooler_resolution, self.pooler_resolution) |
||||
start_ind = 0 |
||||
for box in proposals: |
||||
end_ind = start_ind + len(box) |
||||
box.pred_masks = pred_mask_logits[start_ind:end_ind] |
||||
start_ind = end_ind |
||||
return proposals, {} |
||||
|
||||
def merge_bases(self, rois, coeffs, location_to_inds=None): |
||||
# merge predictions |
||||
N = coeffs.size(0) |
||||
if location_to_inds is not None: |
||||
rois = rois[location_to_inds] |
||||
N, B, H, W = rois.size() |
||||
|
||||
coeffs = coeffs.view(N, -1, self.attn_size, self.attn_size) |
||||
coeffs = F.interpolate(coeffs, (H, W), |
||||
mode=self.top_interp).softmax(dim=1) |
||||
masks_preds = (rois * coeffs).sum(dim=1) |
||||
return masks_preds.view(N, -1) |
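The heart of `merge_bases` is a per-pixel weighted sum of the pooled bases, with the attention coefficients softmax-normalized over the bases. A self-contained numerical sketch with assumed shapes (the real values come from `NUM_BASES`, `ATTN_SIZE` and `BOTTOM_RESOLUTION`, and the interpolation mode from `TOP_INTERP`):

```python
import torch
import torch.nn.functional as F

N, K, A, R = 2, 4, 14, 56            # instances, bases, attention size, ROI size
rois = torch.randn(N, K, R, R)       # bases cropped by the ROI pooler
coeffs = torch.randn(N, K * A * A)   # flattened per-instance attention maps

coeffs = coeffs.view(N, K, A, A)
coeffs = F.interpolate(coeffs, (R, R), mode="bilinear", align_corners=False)
coeffs = coeffs.softmax(dim=1)       # normalize over the K bases at each pixel
mask_logits = (rois * coeffs).sum(dim=1)
print(mask_logits.shape)             # torch.Size([2, 56, 56])
```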
@ -0,0 +1,154 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
||||
|
||||
import torch |
||||
from torch import nn |
||||
|
||||
from detectron2.structures import ImageList |
||||
from detectron2.modeling.postprocessing import detector_postprocess, sem_seg_postprocess |
||||
from detectron2.modeling.proposal_generator import build_proposal_generator |
||||
from detectron2.modeling.backbone import build_backbone |
||||
from detectron2.modeling.meta_arch.panoptic_fpn import combine_semantic_and_instance_outputs |
||||
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY |
||||
from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head |
||||
|
||||
from .blender import build_blender |
||||
from .basis_module import build_basis_module |
||||
|
||||
__all__ = ["BlendMask"] |
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register() |
||||
class BlendMask(nn.Module): |
||||
""" |
||||
Main class for BlendMask architectures (see https://arxiv.org/abs/2001.00309). |
||||
""" |
||||
|
||||
def __init__(self, cfg): |
||||
super().__init__() |
||||
|
||||
self.device = torch.device(cfg.MODEL.DEVICE) |
||||
self.instance_loss_weight = cfg.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT |
||||
|
||||
self.backbone = build_backbone(cfg) |
||||
self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) |
||||
self.blender = build_blender(cfg) |
||||
self.basis_module = build_basis_module(cfg, self.backbone.output_shape()) |
||||
|
||||
# options when combining instance & semantic outputs |
||||
self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED |
||||
if self.combine_on: |
||||
self.panoptic_module = build_sem_seg_head(cfg, self.backbone.output_shape()) |
||||
self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH |
||||
self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT |
||||
self.combine_instances_confidence_threshold = ( |
||||
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH) |
||||
|
||||
# build top module |
||||
in_channels = cfg.MODEL.FPN.OUT_CHANNELS |
||||
num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES |
||||
attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE |
||||
attn_len = num_bases * attn_size * attn_size |
||||
self.top_layer = nn.Conv2d( |
||||
in_channels, attn_len, |
||||
kernel_size=3, stride=1, padding=1) |
||||
torch.nn.init.normal_(self.top_layer.weight, std=0.01) |
||||
torch.nn.init.constant_(self.top_layer.bias, 0) |
||||
|
||||
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) |
||||
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) |
||||
self.normalizer = lambda x: (x - pixel_mean) / pixel_std |
||||
self.to(self.device) |
||||
|
||||
def forward(self, batched_inputs): |
||||
""" |
||||
Args: |
||||
batched_inputs: a list, batched outputs of :class:`DatasetMapper`. |
||||
Each item in the list contains the inputs for one image. |
||||
|
||||
For now, each item in the list is a dict that contains: |
||||
image: Tensor, image in (C, H, W) format. |
||||
instances: Instances |
||||
sem_seg: semantic segmentation ground truth. |
||||
Other information that's included in the original dicts, such as: |
||||
"height", "width" (int): the output resolution of the model, used in inference. |
||||
See :meth:`postprocess` for details. |
||||
|
||||
Returns: |
||||
list[dict]: each dict is the results for one image. The dict |
||||
contains the following keys: |
||||
"instances": see :meth:`GeneralizedRCNN.forward` for its format. |
||||
"sem_seg": see :meth:`SemanticSegmentor.forward` for its format. |
||||
"panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`. |
||||
See the return value of |
||||
:func:`combine_semantic_and_instance_outputs` for its format. |
||||
""" |
||||
images = [x["image"].to(self.device) for x in batched_inputs] |
||||
images = [self.normalizer(x) for x in images] |
||||
images = ImageList.from_tensors(images, self.backbone.size_divisibility) |
||||
features = self.backbone(images.tensor) |
||||
|
||||
if self.combine_on: |
||||
if "sem_seg" in batched_inputs[0]: |
||||
gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs] |
||||
gt_sem = ImageList.from_tensors( |
||||
gt_sem, self.backbone.size_divisibility, self.panoptic_module.ignore_value |
||||
).tensor |
||||
else: |
||||
gt_sem = None |
||||
sem_seg_results, sem_seg_losses = self.panoptic_module(features, gt_sem) |
||||
|
||||
if "basis_sem" in batched_inputs[0]: |
||||
basis_sem = [x["basis_sem"].to(self.device) for x in batched_inputs] |
||||
basis_sem = ImageList.from_tensors( |
||||
basis_sem, self.backbone.size_divisibility, 0).tensor |
||||
else: |
||||
basis_sem = None |
||||
basis_out, basis_losses = self.basis_module(features, basis_sem) |
||||
|
||||
if "instances" in batched_inputs[0]: |
||||
gt_instances = [x["instances"].to(self.device) for x in batched_inputs] |
||||
else: |
||||
gt_instances = None |
||||
proposals, proposal_losses = self.proposal_generator( |
||||
images, features, gt_instances, self.top_layer) |
||||
detector_results, detector_losses = self.blender( |
||||
basis_out["bases"], proposals, gt_instances) |
||||
|
||||
if self.training: |
||||
losses = {} |
||||
losses.update(basis_losses) |
||||
losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()}) |
||||
losses.update(proposal_losses) |
||||
if self.combine_on: |
||||
losses.update(sem_seg_losses) |
||||
return losses |
||||
|
||||
processed_results = [] |
||||
for i, (detector_result, input_per_image, image_size) in enumerate(zip( |
||||
detector_results, batched_inputs, images.image_sizes)): |
||||
height = input_per_image.get("height", image_size[0]) |
||||
width = input_per_image.get("width", image_size[1]) |
||||
detector_r = detector_postprocess(detector_result, height, width) |
||||
processed_result = {"instances": detector_r} |
||||
if self.combine_on: |
||||
sem_seg_r = sem_seg_postprocess( |
||||
sem_seg_results[i], image_size, height, width) |
||||
processed_result["sem_seg"] = sem_seg_r |
||||
if "seg_thing_out" in basis_out: |
||||
seg_thing_r = sem_seg_postprocess( |
||||
basis_out["seg_thing_out"], image_size, height, width) |
||||
processed_result["sem_thing_seg"] = seg_thing_r |
||||
if self.basis_module.visualize: |
||||
processed_result["bases"] = basis_out["bases"] |
||||
processed_results.append(processed_result) |
||||
|
||||
if self.combine_on: |
||||
panoptic_r = combine_semantic_and_instance_outputs( |
||||
detector_r, |
||||
sem_seg_r.argmax(dim=0), |
||||
self.combine_overlap_threshold, |
||||
self.combine_stuff_area_limit, |
||||
self.combine_instances_confidence_threshold) |
||||
processed_results[-1]["panoptic_seg"] = panoptic_r |
||||
return processed_results |
@ -0,0 +1,6 @@ |
||||
_BASE_: "Base-550.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
OUTPUT_DIR: "output/blendmask/550_R_50_1x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-550.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/blendmask/550_R_50_3x" |
@ -0,0 +1,18 @@ |
||||
_BASE_: "Base-550.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
DEFORM_ON_PER_STAGE: [False, True, True, True] |
||||
DEFORM_MODULATED: True |
||||
DEFORM_INTERVAL: 3 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (440, 594) |
||||
MIN_SIZE_TRAIN_SAMPLING: "range" |
||||
MAX_SIZE_TRAIN: 990 |
||||
CROP: |
||||
ENABLED: True |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/blendmask/550_R_50_dcni3_5x" |
@ -0,0 +1,17 @@ |
||||
_BASE_: "Base-BlendMask.yaml" |
||||
MODEL: |
||||
FCOS: |
||||
TOP_LEVELS: 1 |
||||
IN_FEATURES: ["p3", "p4", "p5", "p6"] |
||||
FPN_STRIDES: [8, 16, 32, 64] |
||||
SIZES_OF_INTEREST: [64, 128, 256] |
||||
NUM_SHARE_CONVS: 3 |
||||
NUM_CLS_CONVS: 0 |
||||
NUM_BOX_CONVS: 0 |
||||
BASIS_MODULE: |
||||
NUM_CONVS: 2 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (440, 462, 484, 506, 528, 550) |
||||
MAX_SIZE_TRAIN: 916 |
||||
MIN_SIZE_TEST: 550 |
||||
MAX_SIZE_TEST: 916 |
@ -0,0 +1,29 @@ |
||||
MODEL: |
||||
META_ARCHITECTURE: "BlendMask" |
||||
MASK_ON: True |
||||
BACKBONE: |
||||
NAME: "build_fcos_resnet_fpn_backbone" |
||||
RESNETS: |
||||
OUT_FEATURES: ["res3", "res4", "res5"] |
||||
FPN: |
||||
IN_FEATURES: ["res3", "res4", "res5"] |
||||
PROPOSAL_GENERATOR: |
||||
NAME: "FCOS" |
||||
BASIS_MODULE: |
||||
LOSS_ON: True |
||||
PANOPTIC_FPN: |
||||
COMBINE: |
||||
ENABLED: False |
||||
FCOS: |
||||
THRESH_WITH_CTR: True |
||||
USE_SCALE: False |
||||
DATASETS: |
||||
TRAIN: ("coco_2017_train",) |
||||
TEST: ("coco_2017_val",) |
||||
SOLVER: |
||||
IMS_PER_BATCH: 16 |
||||
BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate |
||||
STEPS: (60000, 80000) |
||||
MAX_ITER: 90000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) |
@ -0,0 +1,16 @@ |
||||
_BASE_: "../Base-BlendMask.yaml" |
||||
MODEL: |
||||
RESNETS: |
||||
OUT_FEATURES: ["res2", "res3", "res4", "res5"] |
||||
FPN: |
||||
IN_FEATURES: ["res2", "res3", "res4", "res5"] |
||||
SEM_SEG_HEAD: |
||||
LOSS_WEIGHT: 0.5 |
||||
PANOPTIC_FPN: |
||||
COMBINE: |
||||
ENABLED: True |
||||
INSTANCES_CONFIDENCE_THRESH: 0.2 |
||||
OVERLAP_THRESH: 0.4 |
||||
DATASETS: |
||||
TRAIN: ("coco_2017_train_panoptic_separated",) |
||||
TEST: ("coco_2017_val_panoptic_separated",) |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-Panoptic.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" |
||||
RESNETS: |
||||
DEPTH: 101 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/panoptic/blendmask/R_101_3x" |
@ -0,0 +1,18 @@ |
||||
_BASE_: "Base-Panoptic.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" |
||||
RESNETS: |
||||
DEPTH: 101 |
||||
DEFORM_ON_PER_STAGE: [False, True, True, True] |
||||
DEFORM_MODULATED: True |
||||
DEFORM_INTERVAL: 3 |
||||
SOLVER: |
||||
STEPS: (280000, 360000) |
||||
MAX_ITER: 400000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (640, 864) |
||||
MIN_SIZE_TRAIN_SAMPLING: "range" |
||||
MAX_SIZE_TRAIN: 1333 |
||||
CROP: |
||||
ENABLED: True |
||||
OUTPUT_DIR: "output/panoptic/blendmask/R_101_dcni3_5x" |
@ -0,0 +1,6 @@ |
||||
_BASE_: "Base-Panoptic.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
OUTPUT_DIR: "output/panoptic/blendmask/R_50_1x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-Panoptic.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/panoptic/blendmask/R_50_3x" |
@ -0,0 +1,18 @@ |
||||
_BASE_: "Base-Panoptic.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
DEFORM_ON_PER_STAGE: [False, True, True, True] |
||||
DEFORM_MODULATED: True |
||||
DEFORM_INTERVAL: 3 |
||||
SOLVER: |
||||
STEPS: (280000, 360000) |
||||
MAX_ITER: 400000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (640, 864) |
||||
MIN_SIZE_TRAIN_SAMPLING: "range" |
||||
MAX_SIZE_TRAIN: 1440 |
||||
CROP: |
||||
ENABLED: True |
||||
OUTPUT_DIR: "output/panoptic/blendmask/R_50_dcni3_5x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "../Base-BlendMask.yaml" |
||||
MODEL: |
||||
BASIS_MODULE: |
||||
NUM_CLASSES: 1 |
||||
FCOS: |
||||
NUM_CLASSES: 1 |
||||
DATASETS: |
||||
TRAIN: ("pic_person_train",) |
||||
TEST: ("pic_person_val",) |
@ -0,0 +1,6 @@ |
||||
_BASE_: "Base-Person.yaml" |
||||
MODEL: |
||||
WEIGHTS: "https://cloudstor.aarnet.edu.au/plus/s/9u1cG2zXvEva5SM/download#R_50_3x.pth" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
OUTPUT_DIR: "output/person/blendmask/R_50_1x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-BlendMask.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" |
||||
RESNETS: |
||||
DEPTH: 101 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/blendmask/R_101_3x" |
@ -0,0 +1,20 @@ |
||||
_BASE_: "Base-BlendMask.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" |
||||
RESNETS: |
||||
DEPTH: 101 |
||||
DEFORM_ON_PER_STAGE: [False, True, True, True] |
||||
DEFORM_MODULATED: True |
||||
DEFORM_INTERVAL: 3 |
||||
SOLVER: |
||||
STEPS: (280000, 360000) |
||||
MAX_ITER: 400000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (640, 864) |
||||
MIN_SIZE_TRAIN_SAMPLING: "range" |
||||
MAX_SIZE_TRAIN: 1440 |
||||
CROP: |
||||
ENABLED: True |
||||
TEST: |
||||
EVAL_PERIOD: 20000 |
||||
OUTPUT_DIR: "output/blendmask/R_101_dcni3_5x" |
@ -0,0 +1,6 @@ |
||||
_BASE_: "Base-BlendMask.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
OUTPUT_DIR: "output/blendmask/R_50_1x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-BlendMask.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
OUTPUT_DIR: "output/blendmask/R_50_3x" |
@ -0,0 +1,15 @@ |
||||
_BASE_: "Base-RCNN.yaml" |
||||
MODEL: |
||||
WEIGHTS: "output/mask_rcnn/550_R_50_3x/model_final.pth" |
||||
MASK_ON: True |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (440, 462, 484, 506, 528, 550) |
||||
MAX_SIZE_TRAIN: 916 |
||||
MIN_SIZE_TEST: 550 |
||||
MAX_SIZE_TEST: 916 |
||||
OUTPUT_DIR: "output/mask_rcnn/550_R_50_3x" |
@ -0,0 +1,42 @@ |
||||
MODEL: |
||||
META_ARCHITECTURE: "GeneralizedRCNN" |
||||
BACKBONE: |
||||
NAME: "build_resnet_fpn_backbone" |
||||
RESNETS: |
||||
OUT_FEATURES: ["res2", "res3", "res4", "res5"] |
||||
FPN: |
||||
IN_FEATURES: ["res2", "res3", "res4", "res5"] |
||||
ANCHOR_GENERATOR: |
||||
SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map |
||||
ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) |
||||
RPN: |
||||
IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] |
||||
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level |
||||
PRE_NMS_TOPK_TEST: 1000 # Per FPN level |
||||
# Detectron1 uses 2000 proposals per-batch, |
||||
# (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) |
||||
# which is approximately 1000 proposals per-image since the default batch size for FPN is 2. |
||||
POST_NMS_TOPK_TRAIN: 1000 |
||||
POST_NMS_TOPK_TEST: 1000 |
||||
ROI_HEADS: |
||||
NAME: "StandardROIHeads" |
||||
IN_FEATURES: ["p2", "p3", "p4", "p5"] |
||||
ROI_BOX_HEAD: |
||||
NAME: "FastRCNNConvFCHead" |
||||
NUM_FC: 2 |
||||
POOLER_RESOLUTION: 7 |
||||
ROI_MASK_HEAD: |
||||
NAME: "MaskRCNNConvUpsampleHead" |
||||
NUM_CONV: 4 |
||||
POOLER_RESOLUTION: 14 |
||||
DATASETS: |
||||
TRAIN: ("coco_2017_train",) |
||||
TEST: ("coco_2017_val",) |
||||
SOLVER: |
||||
IMS_PER_BATCH: 16 |
||||
BASE_LR: 0.02 |
||||
STEPS: (60000, 80000) |
||||
MAX_ITER: 90000 |
||||
INPUT: |
||||
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) |
||||
VERSION: 2 |
@ -0,0 +1,6 @@ |
||||
_BASE_: "Base-LVIS.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" |
||||
RESNETS: |
||||
DEPTH: 50 |
||||
OUTPUT_DIR: "output/lvis/mask_rcnn/R_50_1x" |
@ -0,0 +1,9 @@ |
||||
_BASE_: "Base-RCNN.yaml" |
||||
MODEL: |
||||
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" |
||||
MASK_ON: True |
||||
RESNETS: |
||||
DEPTH: 101 |
||||
SOLVER: |
||||
STEPS: (210000, 250000) |
||||
MAX_ITER: 270000 |
@ -0,0 +1,33 @@ |
||||
## Expected dataset structure for AdelaiDet instance detection: |
||||
|
||||
``` |
||||
coco/ |
||||
thing_train2017/ |
||||
# thing class label maps for auxiliary semantic loss |
||||
lvis/ |
||||
thing_train/ |
||||
# semantic labels for LVIS |
||||
``` |
||||
|
||||
Run `python prepare_thing_sem_from_instance.py` to extract semantic labels from instance annotations. |
||||
|
||||
Run `python prepare_thing_sem_from_lvis.py` to extract semantic labels from LVIS annotations. |
||||
|
||||
## Expected dataset structure for Person In Context instance detection: |
||||
|
||||
``` |
||||
pic/ |
||||
thing_train/ |
||||
# thing class label maps for auxiliary semantic loss |
||||
annotations/ |
||||
train_person.json |
||||
val_person.json |
||||
image/ |
||||
train/ |
||||
... |
||||
|
||||
``` |
||||
|
||||
First link the PIC_2.0 dataset into this folder with `ln -s /path/to/PIC_2.0 pic`. Then run `python gen_coco_person.py` to generate the train and validation annotation JSONs. |
||||
|
||||
Run `python prepare_thing_sem_from_instance.py --dataset-name pic` to extract semantic labels from instance annotations. |
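
If the `pic_person_train`/`pic_person_val` names used by the Person configs are not already registered by `adet.data.builtin`, a hedged sketch of registering them with detectron2's generic COCO-format loader (paths are assumptions derived from the layout above):

```python
from detectron2.data.datasets import register_coco_instances

# Dataset names match DATASETS.TRAIN / DATASETS.TEST in the Person configs;
# the json/image paths are assumptions based on the directory layout shown above.
register_coco_instances(
    "pic_person_train", {},
    "datasets/pic/annotations/train_person.json", "datasets/pic/image/train",
)
register_coco_instances(
    "pic_person_val", {},
    "datasets/pic/annotations/val_person.json", "datasets/pic/image/val",
)
```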
@ -0,0 +1,101 @@ |
||||
import numpy as np |
||||
import cv2 |
||||
import os |
||||
import json |
||||
error_list = ['23382.png', '23441.png', '20714.png', '20727.png', '23300.png', '21200.png'] |
||||
|
||||
def mask2box(mask): |
||||
index = np.argwhere(mask == 1) |
||||
rows = index[:, 0] |
||||
cols = index[:, 1] |
||||
y1 = int(np.min(rows)) # y |
||||
x1 = int(np.min(cols)) # x |
||||
y2 = int(np.max(rows)) |
||||
x2 = int(np.max(cols)) |
||||
return (x1, y1, x2, y2) |
||||
|
||||
def gen_coco(phase): |
||||
result = { |
||||
"info": {"description": "PIC2.0 dataset."}, |
||||
"categories": [ |
||||
{"supercategory": "none", "id": 1, "name": "person"} |
||||
] |
||||
} |
||||
out_json = phase +'_person.json' |
||||
store_segmentation = True |
||||
|
||||
images_info = [] |
||||
labels_info = [] |
||||
img_id = 0 |
||||
files = tuple(open("pic/list5/"+phase+'_id', 'r')) |
||||
files = (_.strip() for _ in files) |
||||
|
||||
for index, image_name in enumerate(files): |
||||
image_name = image_name+".png" |
||||
print(index, image_name) |
||||
if image_name in error_list: |
||||
continue |
||||
instance = cv2.imread(os.path.join('instance', phase, image_name), flags=cv2.IMREAD_GRAYSCALE) |
||||
semantic = cv2.imread(os.path.join('semantic', phase, image_name), flags=cv2.IMREAD_GRAYSCALE) |
||||
# print(instance.shape, semantic.shape) |
||||
h = instance.shape[0] |
||||
w = instance.shape[1] |
||||
images_info.append( |
||||
{ |
||||
"file_name": image_name[:-4]+'.jpg', |
||||
"height": h, |
||||
"width": w, |
||||
"id": index |
||||
} |
||||
) |
||||
instance_max_num = instance.max() |
||||
instance_ids = np.unique(instance) |
||||
for instance_id in instance_ids: |
||||
if instance_id == 0: |
||||
continue |
||||
instance_part = instance == instance_id |
||||
object_pos = instance_part.nonzero() |
||||
# category_id_ = int(semantic[object_pos[0][0], object_pos[1][0]]) |
||||
category_id = int(np.max(semantic[object_pos[0], object_pos[1]])) |
||||
# assert category_id_ == category_id, (category_id_, category_id) |
||||
if category_id != 1: |
||||
continue |
||||
area = int(instance_part.sum()) |
||||
x1, y1, x2, y2 = mask2box(instance_part) |
||||
w = x2 - x1 + 1 |
||||
h = y2 - y1 + 1 |
||||
segmentation = [] |
||||
if store_segmentation: |
||||
contours, hierarchy = cv2.findContours((instance_part * 255).astype(np.uint8), cv2.RETR_TREE, |
||||
cv2.CHAIN_APPROX_SIMPLE) |
||||
for contour in contours: |
||||
contour = contour.flatten().tolist() |
||||
if len(contour) > 4: |
||||
segmentation.append(contour) |
||||
if len(segmentation) == 0: |
||||
print('error') |
||||
continue |
||||
labels_info.append( |
||||
{ |
||||
"segmentation": segmentation, # poly |
||||
"area": area, # segmentation area |
||||
"iscrowd": 0, |
||||
"image_id": index, |
||||
"bbox": [x1, y1, w, h], |
||||
"category_id": category_id, |
||||
"id": img_id |
||||
}, |
||||
) |
||||
img_id += 1 |
||||
# break |
||||
result["images"] = images_info |
||||
result["annotations"] = labels_info |
||||
with open('pic/annotations/' + out_json, 'w') as f: |
||||
json.dump(result, f, indent=4) |
||||
|
||||
if __name__ == "__main__": |
||||
if not os.path.exists('pic/annotations/'): |
||||
os.makedirs('pic/annotations/') |
||||
gen_coco("train") |
||||
gen_coco("val") |
||||
#gen_coco("test") |
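A tiny check of the inclusive-corner convention of `mask2box` (which is why the bbox width and height above add 1); it assumes the script can be imported as `gen_coco_person`:

```python
import numpy as np

from gen_coco_person import mask2box  # assumes this script is importable

mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:5, 3:8] = 1                     # rows 2..4 and columns 3..7 set to 1
assert mask2box(mask) == (3, 2, 7, 4)  # inclusive (x1, y1, x2, y2)
```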
@ -0,0 +1,125 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
||||
|
||||
import time |
||||
import functools |
||||
import multiprocessing as mp |
||||
import numpy as np |
||||
import os |
||||
import argparse |
||||
from pycocotools.coco import COCO |
||||
from pycocotools import mask as maskUtils |
||||
|
||||
from detectron2.data.datasets.builtin_meta import _get_coco_instances_meta |
||||
|
||||
|
||||
def annToRLE(ann, img_size): |
||||
h, w = img_size |
||||
segm = ann['segmentation'] |
||||
if type(segm) == list: |
||||
# polygon -- a single object might consist of multiple parts |
||||
# we merge all parts into one mask rle code |
||||
rles = maskUtils.frPyObjects(segm, h, w) |
||||
rle = maskUtils.merge(rles) |
||||
elif type(segm['counts']) == list: |
||||
# uncompressed RLE |
||||
rle = maskUtils.frPyObjects(segm, h, w) |
||||
else: |
||||
# rle |
||||
rle = ann['segmentation'] |
||||
return rle |
||||
|
||||
|
||||
def annToMask(ann, img_size): |
||||
rle = annToRLE(ann, img_size) |
||||
m = maskUtils.decode(rle) |
||||
return m |
||||
|
||||
|
||||
def _process_instance_to_semantic(anns, output_semantic, img, categories): |
||||
img_size = (img["height"], img["width"]) |
||||
output = np.zeros(img_size, dtype=np.uint8) |
||||
for ann in anns: |
||||
mask = annToMask(ann, img_size) |
||||
output[mask == 1] = categories[ann["category_id"]] + 1 |
||||
# save as compressed npz |
||||
np.savez_compressed(output_semantic, mask=output) |
||||
# Image.fromarray(output).save(output_semantic) |
||||
|
||||
|
||||
def create_coco_semantic_from_instance(instance_json, sem_seg_root, categories): |
||||
""" |
||||
Create semantic segmentation annotations from instance segmentation |
||||
annotations, to be used as thing semantic supervision for the basis module. |
||||
|
||||
It maps all thing categories to contiguous ids starting from 1, and maps all unlabeled pixels to class 0 |
||||
|
||||
Args: |
||||
instance_json (str): path to the instance json file, in COCO's format. |
||||
sem_seg_root (str): a directory to output semantic annotation files |
||||
categories (dict): category metadata. Each dict needs to have: |
||||
"id": corresponds to the "category_id" in the json annotations |
||||
"isthing": 0 or 1 |
||||
""" |
||||
os.makedirs(sem_seg_root, exist_ok=True) |
||||
|
||||
coco_detection = COCO(instance_json) |
||||
|
||||
def iter_annotations(): |
||||
for img_id in coco_detection.getImgIds(): |
||||
anns_ids = coco_detection.getAnnIds(img_id) |
||||
anns = coco_detection.loadAnns(anns_ids) |
||||
img = coco_detection.loadImgs(int(img_id))[0] |
||||
output = os.path.join(sem_seg_root, img["file_name"].replace('jpg', 'npz')) |
||||
yield anns, output, img |
||||
|
||||
# single process |
||||
# print("Start writing to {} ...".format(sem_seg_root)) |
||||
# start = time.time() |
||||
# for anno, oup, img in iter_annotations(): |
||||
# _process_instance_to_semantic( |
||||
# anno, oup, img, categories) |
||||
# print("Finished. time: {:.2f}s".format(time.time() - start)) |
||||
# return |
||||
|
||||
pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) |
||||
|
||||
print("Start writing to {} ...".format(sem_seg_root)) |
||||
start = time.time() |
||||
pool.starmap( |
||||
functools.partial( |
||||
_process_instance_to_semantic, |
||||
categories=categories), |
||||
iter_annotations(), |
||||
chunksize=100, |
||||
) |
||||
print("Finished. time: {:.2f}s".format(time.time() - start)) |
||||
|
||||
|
||||
def get_parser(): |
||||
parser = argparse.ArgumentParser(description="Extract thing semantic labels from instance annotations") |
||||
parser.add_argument( |
||||
"--dataset-name", |
||||
default="coco", |
||||
help="dataset to generate", |
||||
) |
||||
return parser |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
args = get_parser().parse_args() |
||||
dataset_dir = os.path.join(os.path.dirname(__file__), args.dataset_name) |
||||
if args.dataset_name == "coco": |
||||
thing_id_to_contiguous_id = _get_coco_instances_meta()["thing_dataset_id_to_contiguous_id"] |
||||
split_name = 'train2017' |
||||
annotation_name = "annotations/instances_{}.json" |
||||
else: |
||||
thing_id_to_contiguous_id = {1: 0} |
||||
split_name = 'train' |
||||
annotation_name = "annotations/{}_person.json" |
||||
for s in ["train2017"]: |
||||
create_coco_semantic_from_instance( |
||||
os.path.join(dataset_dir, "annotations/instances_{}.json".format(s)), |
||||
os.path.join(dataset_dir, "thing_{}".format(s)), |
||||
thing_id_to_contiguous_id |
||||
) |
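A quick sanity check of the generated files, matching the `np.load(...)["mask"]` access in `DatasetMapperWithBasis`; the file name here is only illustrative:

```python
import numpy as np

# Path is illustrative; pick any file produced under datasets/coco/thing_train2017/.
basis_sem = np.load("datasets/coco/thing_train2017/000000000009.npz")["mask"]
print(basis_sem.shape, basis_sem.dtype, basis_sem.max())
# -> an (H, W) uint8 map where 0 is unlabeled and thing ids start at 1
```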
@ -0,0 +1,98 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
||||
|
||||
import time |
||||
import functools |
||||
import multiprocessing as mp |
||||
import numpy as np |
||||
import os |
||||
from lvis import LVIS |
||||
from pycocotools import mask as maskUtils |
||||
|
||||
|
||||
def annToRLE(ann, img_size): |
||||
h, w = img_size |
||||
segm = ann['segmentation'] |
||||
if type(segm) == list: |
||||
# polygon -- a single object might consist of multiple parts |
||||
# we merge all parts into one mask rle code |
||||
rles = maskUtils.frPyObjects(segm, h, w) |
||||
rle = maskUtils.merge(rles) |
||||
elif type(segm['counts']) == list: |
||||
# uncompressed RLE |
||||
rle = maskUtils.frPyObjects(segm, h, w) |
||||
else: |
||||
# rle |
||||
rle = ann['segmentation'] |
||||
return rle |
||||
|
||||
|
||||
def annToMask(ann, img_size): |
||||
rle = annToRLE(ann, img_size) |
||||
m = maskUtils.decode(rle) |
||||
return m |
||||
|
||||
|
||||
def _process_instance_to_semantic(anns, output_semantic, img): |
||||
img_size = (img["height"], img["width"]) |
||||
output = np.zeros(img_size, dtype=np.uint8) |
||||
for ann in anns: |
||||
mask = annToMask(ann, img_size) |
||||
output[mask == 1] = ann["category_id"] // 5 |
||||
# save as compressed npz |
||||
np.savez_compressed(output_semantic, mask=output) |
||||
# Image.fromarray(output).save(output_semantic) |
||||
|
||||
|
||||
def create_lvis_semantic_from_instance(instance_json, sem_seg_root): |
||||
""" |
||||
Create semantic segmentation annotations from panoptic segmentation |
||||
annotations, to be used by PanopticFPN. |
||||
|
||||
It maps all thing categories to contiguous ids starting from 1, and maps all unlabeled pixels to class 0 |
||||
|
||||
Args: |
||||
instance_json (str): path to the instance json file, in COCO's format. |
||||
sem_seg_root (str): a directory to output semantic annotation files |
||||
""" |
||||
os.makedirs(sem_seg_root, exist_ok=True) |
||||
|
||||
lvis_detection = LVIS(instance_json) |
||||
|
||||
def iter_annotations(): |
||||
for img_id in lvis_detection.get_img_ids(): |
||||
anns_ids = lvis_detection.get_ann_ids([img_id]) |
||||
anns = lvis_detection.load_anns(anns_ids) |
||||
img = lvis_detection.load_imgs([img_id])[0] |
||||
output = os.path.join(sem_seg_root, img["file_name"].replace('jpg', 'npz')) |
||||
yield anns, output, img |
||||
|
||||
# # single process |
||||
# print("Start writing to {} ...".format(sem_seg_root)) |
||||
# start = time.time() |
||||
# for anno, oup, img in iter_annotations(): |
||||
# _process_instance_to_semantic( |
||||
# anno, oup, img) |
||||
# print("Finished. time: {:.2f}s".format(time.time() - start)) |
||||
# return |
||||
|
||||
pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) |
||||
|
||||
print("Start writing to {} ...".format(sem_seg_root)) |
||||
start = time.time() |
||||
pool.starmap( |
||||
functools.partial( |
||||
_process_instance_to_semantic), |
||||
iter_annotations(), |
||||
chunksize=100, |
||||
) |
||||
print("Finished. time: {:.2f}s".format(time.time() - start)) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
dataset_dir = os.path.join(os.path.dirname(__file__), "lvis") |
||||
for s in ["train"]: |
||||
create_lvis_semantic_from_instance( |
||||
os.path.join(dataset_dir, "lvis_v0.5_{}.json".format(s)), |
||||
os.path.join(dataset_dir, "thing_{}".format(s)), |
||||
) |