You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
5.0 KiB
135 lines
5.0 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import os |
|
import numpy as np |
|
|
|
from paddlers.models.ppseg.datasets import Dataset |
|
from paddlers.models.ppseg.cvlibs import manager |
|
from paddlers.models.ppseg.transforms import Compose |
|
|
|
|
|
@manager.DATASETS.add_component
class PSSLDataset(Dataset):
    """
    The PSSL dataset for segmentation. PSSL is short for Pseudo Semantic Segmentation Labels, where the pseudo label
    is computed by the Consensus explanation algorithm.

    The PSSL refers to "Distilling Ensemble of Explanations for Weakly-Supervised Pre-Training of Image Segmentation
    Models" (https://arxiv.org/abs/2207.03335).

    The Consensus explanation refers to "Cross-Model Consensus of Explanations and Beyond for Image Classification
    Models: An Empirical Study" (https://arxiv.org/abs/2109.00707).

    To use this dataset, we need to additionally prepare the original ImageNet dataset, which has the folder structure
    as follows:

        imagenet_root
        |
        |--train
        |  |--n01440764
        |  |  |--n01440764_10026.JPEG
        |  |  |--...
        |  |--nxxxxxxxx
        |  |--...

    where only the "train" set is needed.

    The PSSL dataset has the folder structure as follows:

        pssl_root
        |
        |--train
        |  |--n01440764
        |  |  |--n01440764_10026.JPEG_eiseg.npz
        |  |  |--...
        |  |--nxxxxxxxx
        |  |--...
        |
        |--imagenet_lsvrc_2015_synsets.txt
        |--train.txt

    where "train.txt" and "imagenet_lsvrc_2015_synsets.txt" are included in the PSSL dataset.

    Each line of "train.txt" is a label path relative to `pssl_root` (e.g.
    "train/n04118776/n04118776_45912.JPEG_eiseg.npz"); the matching image path relative to `imagenet_root`
    is that path truncated right after ".JPEG". The class id of a synset is its 0-based line index in
    "imagenet_lsvrc_2015_synsets.txt".

    Args:
        transforms (list): Transforms for image.
        imagenet_root (str): The path to the original ImageNet dataset.
        pssl_root (str): The path to the PSSL dataset.
        mode (str, optional): Which part of dataset to use. Only 'train' is supported (case-insensitive).
            Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False.
            NOTE(review): the flag is stored on `self.edge`, but this class's `__getitem__` does not use it.

    Raises:
        ValueError: If `mode` is not 'train', `transforms` is None, either root is None or not a directory,
            or "train.txt" / "imagenet_lsvrc_2015_synsets.txt" is missing from `pssl_root`.
    """
    ignore_index = 1001  # 0~999 is target class, 1000 is bg
    NUM_CLASSES = 1001  # consider target class and bg

    def __init__(self,
                 transforms,
                 imagenet_root,
                 pssl_root,
                 mode='train',
                 edge=False):
        mode = mode.lower()
        if mode not in ['train']:
            raise ValueError("mode should be 'train', but got {}.".format(mode))
        if transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        self.transforms = Compose(transforms)
        self.mode = mode
        self.edge = edge

        self.num_classes = self.NUM_CLASSES
        # One past the largest label produced by __getitem__ (1000 = background).
        self.ignore_index = self.num_classes  # 1001
        self.file_list = []
        self.class_id_dict = {}

        # Validate BOTH roots up front. Checking only `imagenet_root is None`
        # and `isdir(pssl_root)` let pssl_root=None escape as a TypeError from
        # os.path.isdir and let a bad imagenet_root pass until item loading.
        if (imagenet_root is None or pssl_root is None
                or not os.path.isdir(imagenet_root)
                or not os.path.isdir(pssl_root)):
            raise ValueError(
                "The dataset is not Found or the folder structure is nonconfoumance."
            )

        train_list_file = os.path.join(pssl_root, "train.txt")
        if not os.path.exists(train_list_file):
            raise ValueError("Train list file isn't exists.")
        self.file_list = self._read_file_list(train_list_file, imagenet_root,
                                              pssl_root)

        # mapping class name to class id.
        class_id_file = os.path.join(pssl_root,
                                     "imagenet_lsvrc_2015_synsets.txt")
        if not os.path.exists(class_id_file):
            raise ValueError("Class id file isn't exists.")
        self.class_id_dict = self._read_class_ids(class_id_file)

    @staticmethod
    def _read_file_list(train_list_file, imagenet_root, pssl_root):
        """Parse "train.txt" into `[image_path, label_path]` pairs, skipping blank lines."""
        file_list = []
        with open(train_list_file, encoding='utf-8') as f:
            for line in f:
                # line: train/n04118776/n04118776_45912.JPEG_eiseg.npz
                label_path = line.strip()
                if not label_path:
                    continue
                # The image path is the label path cut right after ".JPEG".
                img_path = label_path.split('.JPEG')[0] + '.JPEG'
                file_list.append([
                    os.path.join(imagenet_root, img_path),
                    os.path.join(pssl_root, label_path),
                ])
        return file_list

    @staticmethod
    def _read_class_ids(class_id_file):
        """Map each synset name to its 0-based line index in `class_id_file`.

        Blank lines are not stored but still count toward the line index, so the
        ids of non-blank names match their raw line numbers.
        """
        class_id_dict = {}
        with open(class_id_file, encoding='utf-8') as f:
            for idx, line in enumerate(f):
                class_name = line.strip()
                if class_name:
                    class_id_dict[class_name] = idx
        return class_id_dict

    def __getitem__(self, idx):
        """Load sample `idx` and return the output of `self.transforms` for it.

        The pseudo label is the `arr_0` array of the sample's `.npz` file. Pixels
        equal to 1 are set to the image's class id (looked up from the synset
        prefix of the image file name); every other pixel is set to 1000
        (background).

        Raises:
            KeyError: If the image's synset is not listed in
                "imagenet_lsvrc_2015_synsets.txt".
        """
        image_path, label_path = self.file_list[idx]

        # transform label; e.g. ".../n01440764_10026.JPEG" -> "n01440764".
        # basename (rather than split('/')) also handles Windows separators.
        class_name = os.path.basename(image_path).split('_')[0]
        class_id = self.class_id_dict[class_name]

        # NpzFile holds the archive open until closed; the context manager
        # releases it once the array has been read into memory.
        with np.load(label_path) as npz:
            pssl_seg = npz['arr_0']
        gt_semantic_seg = np.full_like(pssl_seg, 1000, dtype=np.int64)
        # [0, 999] for imagenet classes, 1000 for background.
        # NOTE(review): any pseudo-label value other than 1 (e.g. -1) is
        # treated as background here, not mapped to self.ignore_index; if such
        # pixels are meant to be ignored during training, map them explicitly.
        gt_semantic_seg[pssl_seg == 1] = class_id

        im, label = self.transforms(im=image_path, label=gt_semantic_seg)

        return im, label
|
|
|