@@ -2,12 +2,9 @@ import argparse
import copy
import os
import os.path as osp
import shutil
import tempfile

import mmcv
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
@@ -16,8 +13,8 @@ from pycocotools.cocoeval import COCOeval
from robustness_eval import get_results
from mmdet import datasets
from mmdet.apis import set_random_seed
from mmdet.core import encode_mask_results, eval_map
from mmdet.apis import multi_gpu_test, set_random_seed, single_gpu_test
from mmdet.core import eval_map
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector
@@ -91,99 +88,6 @@ def voc_eval_with_return(result_file,
    return mean_ap, eval_results
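
# Single-GPU evaluation helper: runs inference batch by batch with gradients
# disabled, optionally visualises each prediction, and converts raw mask
# arrays to RLE with encode_mask_results to keep the result list compact.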
def single_gpu_test(model, data_loader, show=False):
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=not show, **data)
        if show:
            model.module.show_result(data, result, dataset.img_norm_cfg)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results
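
# Distributed evaluation helper: each rank evaluates its own shard of the
# dataset, only rank 0 drives the progress bar, and the per-rank partial
# results are merged by collect_results at the end.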
def multi_gpu_test(model, data_loader, tmpdir=None):
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    results = collect_results(results, len(dataset), tmpdir)
    return results
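
# Result-gathering helper for distributed testing: when no tmpdir is given,
# rank 0 creates one and broadcasts its name to the other ranks as a
# fixed-length uint8 CUDA tensor (padded with spaces, hence the rstrip on
# decode). Every rank then dumps its partial results as part_<rank>.pkl;
# rank 0 reloads all parts, interleaves them back into dataset order, and
# trims the samples that the sampler padded to even out the shards.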
def collect_results(result_part, size, tmpdir=None):
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def parse_args():
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
@@ -228,6 +132,13 @@ def parse_args():
    parser.add_argument(
        '--workers', type=int, default=32, help='workers per gpu')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
@@ -377,7 +288,14 @@ def main():
        if not distributed:
            model = MMDataParallel(model, device_ids=[0])
            outputs = single_gpu_test(model, data_loader, args.show)
            show_dir = args.show_dir
            if show_dir is not None:
                show_dir = osp.join(show_dir, corruption)
                show_dir = osp.join(show_dir, str(corruption_severity))
                if not osp.exists(show_dir):
                    os.makedirs(show_dir)
            outputs = single_gpu_test(model, data_loader, args.show,
                                      show_dir, args.show_score_thr)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),