@@ -0,0 +1,12 @@
.vscode/
*.pyc
*.DS_Store
*.swp
*.pth
tmp.*
*/.ipynb_checkpoints/*

logs/
weights/
dump/
src/loftr/utils/superglue.py
@@ -0,0 +1,5 @@
0022_0.1_0.3.npz
0015_0.1_0.3.npz
0015_0.3_0.5.npz
0022_0.3_0.5.npz
0022_0.5_0.7.npz
[43 binary image assets added (142–632 KiB each); image contents not shown]
@@ -0,0 +1 @@
test.npz
@@ -0,0 +1,102 @@
{
    "scene0707_00": 15,
    "scene0708_00": 15,
    "scene0709_00": 15,
    "scene0710_00": 15,
    "scene0711_00": 15,
    "scene0712_00": 15,
    "scene0713_00": 15,
    "scene0714_00": 15,
    "scene0715_00": 15,
    "scene0716_00": 15,
    "scene0717_00": 15,
    "scene0718_00": 15,
    "scene0719_00": 15,
    "scene0720_00": 15,
    "scene0721_00": 15,
    "scene0722_00": 15,
    "scene0723_00": 15,
    "scene0724_00": 15,
    "scene0725_00": 15,
    "scene0726_00": 15,
    "scene0727_00": 15,
    "scene0728_00": 15,
    "scene0729_00": 15,
    "scene0730_00": 15,
    "scene0731_00": 15,
    "scene0732_00": 15,
    "scene0733_00": 15,
    "scene0734_00": 15,
    "scene0735_00": 15,
    "scene0736_00": 15,
    "scene0737_00": 15,
    "scene0738_00": 15,
    "scene0739_00": 15,
    "scene0740_00": 15,
    "scene0741_00": 15,
    "scene0742_00": 15,
    "scene0743_00": 15,
    "scene0744_00": 15,
    "scene0745_00": 15,
    "scene0746_00": 15,
    "scene0747_00": 15,
    "scene0748_00": 15,
    "scene0749_00": 15,
    "scene0750_00": 15,
    "scene0751_00": 15,
    "scene0752_00": 15,
    "scene0753_00": 15,
    "scene0754_00": 15,
    "scene0755_00": 15,
    "scene0756_00": 15,
    "scene0757_00": 15,
    "scene0758_00": 15,
    "scene0759_00": 15,
    "scene0760_00": 15,
    "scene0761_00": 15,
    "scene0762_00": 15,
    "scene0763_00": 15,
    "scene0764_00": 15,
    "scene0765_00": 15,
    "scene0766_00": 15,
    "scene0767_00": 15,
    "scene0768_00": 15,
    "scene0769_00": 15,
    "scene0770_00": 15,
    "scene0771_00": 15,
    "scene0772_00": 15,
    "scene0773_00": 15,
    "scene0774_00": 15,
    "scene0775_00": 15,
    "scene0776_00": 15,
    "scene0777_00": 15,
    "scene0778_00": 15,
    "scene0779_00": 15,
    "scene0780_00": 15,
    "scene0781_00": 15,
    "scene0782_00": 15,
    "scene0783_00": 15,
    "scene0784_00": 15,
    "scene0785_00": 15,
    "scene0786_00": 15,
    "scene0787_00": 15,
    "scene0788_00": 15,
    "scene0789_00": 15,
    "scene0790_00": 15,
    "scene0791_00": 15,
    "scene0792_00": 15,
    "scene0793_00": 15,
    "scene0794_00": 15,
    "scene0795_00": 15,
    "scene0796_00": 15,
    "scene0797_00": 15,
    "scene0798_00": 15,
    "scene0799_00": 15,
    "scene0800_00": 15,
    "scene0801_00": 15,
    "scene0802_00": 15,
    "scene0803_00": 15,
    "scene0804_00": 15,
    "scene0805_00": 15,
    "scene0806_00": 15
}
@@ -0,0 +1,31 @@
"""
The data config will be the last one merged into the main config.
Settings in data configs will override all existing settings!
"""

from yacs.config import CfgNode as CN
_CN = CN()
_CN.DATASET = CN()
_CN.TRAINER = CN()

# training data config
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
# validation set config
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None
_CN.DATASET.VAL_INTRINSIC_PATH = None

# testing data config
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None
_CN.DATASET.TEST_INTRINSIC_PATH = None

# dataset config
_CN.DATASET.MIN_OVERLAP_SCORE = 0.4

cfg = _CN
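
Since the data config is merged last, its values override the defaults. A minimal sketch of that merge order using the standard yacs `merge_from_other_cfg` API (the exact merge call used by the entry scripts is not shown in this diff):

    from src.config.default import get_cfg_defaults
    from configs.data.scannet_test_1500 import cfg as data_cfg

    config = get_cfg_defaults()              # main config with default values
    config.merge_from_other_cfg(data_cfg)    # data config merged last, overrides defaults
    print(config.DATASET.MIN_OVERLAP_SCORE)  # 0.0 from the data config, not the 0.4 default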
@@ -0,0 +1,11 @@
from configs.data.base import cfg


TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"

cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
cfg.DATASET.TEST_DATA_ROOT = "/data/MegaDepth/megadepth_test_1500"
cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"

cfg.DATASET.MGDPT_IMG_RESIZE = 840
cfg.DATASET.MIN_OVERLAP_SCORE = 0.0
@@ -0,0 +1,11 @@
from configs.data.base import cfg


TEST_BASE_PATH = "assets/scannet_test_1500"

cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
cfg.DATASET.TEST_DATA_ROOT = "/data/scannet/scannet_test_1500"
cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"

cfg.DATASET.MIN_OVERLAP_SCORE = 0.0
@@ -0,0 +1,3 @@
from src.config.default import _CN as cfg

cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
@@ -0,0 +1,3 @@
from src.config.default import _CN as cfg

cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
@@ -0,0 +1,14 @@
name: loftr
channels:
  # - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - python=3.8
  - cudatoolkit=10.2
  - pytorch=1.8.0
  - pytorch-lightning<=1.1.8  # https://github.com/PyTorchLightning/pytorch-lightning/issues/6318
  - pip
  - pip:
    - -r file:requirements.txt
@@ -0,0 +1,14 @@
opencv_python==4.4.0.46
albumentations==0.5.1 --no-binary=imgaug,albumentations
ray>=1.0.1
einops==0.3.0
kornia==0.4.1
loguru==0.5.3
yacs>=0.1.8
tqdm
autopep8
pylint
ipython
jupyterlab
matplotlib
h5py==3.1.0
@@ -0,0 +1,29 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/scannet_test_1500.py"
main_cfg_path="configs/loftr/loftr_ds.py"
ckpt_path="weights/indoor_ds.ckpt"
dump_dir="dump/loftr_ds_indoor"
profiler_name="inference"
n_nodes=1  # manually keep this the same as --num_nodes
n_gpus_per_node=-1
torch_num_workers=4
batch_size=1  # per gpu

python -u ./test.py \
    ${data_cfg_path} \
    ${main_cfg_path} \
    --ckpt_path=${ckpt_path} \
    --dump_dir=${dump_dir} \
    --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
    --batch_size=${batch_size} --num_workers=${torch_num_workers} \
    --profiler_name=${profiler_name} \
    --benchmark

@@ -0,0 +1,29 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/scannet_test_1500.py"
main_cfg_path="configs/loftr/loftr_ot.py"
ckpt_path="weights/indoor_ot.ckpt"
dump_dir="dump/loftr_ot_indoor"
profiler_name="inference"
n_nodes=1  # manually keep this the same as --num_nodes
n_gpus_per_node=-1
torch_num_workers=4
batch_size=1  # per gpu

python -u ./test.py \
    ${data_cfg_path} \
    ${main_cfg_path} \
    --ckpt_path=${ckpt_path} \
    --dump_dir=${dump_dir} \
    --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
    --batch_size=${batch_size} --num_workers=${torch_num_workers} \
    --profiler_name=${profiler_name} \
    --benchmark

@@ -0,0 +1,29 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/megadepth_test_1500.py"
main_cfg_path="configs/loftr/loftr_ds.py"
ckpt_path="weights/outdoor_ds.ckpt"
dump_dir="dump/loftr_ds_outdoor"
profiler_name="inference"
n_nodes=1  # manually keep this the same as --num_nodes
n_gpus_per_node=-1
torch_num_workers=4
batch_size=1  # per gpu

python -u ./test.py \
    ${data_cfg_path} \
    ${main_cfg_path} \
    --ckpt_path=${ckpt_path} \
    --dump_dir=${dump_dir} \
    --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
    --batch_size=${batch_size} --num_workers=${torch_num_workers} \
    --profiler_name=${profiler_name} \
    --benchmark

@@ -0,0 +1,29 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR

data_cfg_path="configs/data/megadepth_test_1500.py"
main_cfg_path="configs/loftr/loftr_ot.py"
ckpt_path="weights/outdoor_ot.ckpt"
dump_dir="dump/loftr_ot_outdoor"
profiler_name="inference"
n_nodes=1  # manually keep this the same as --num_nodes
n_gpus_per_node=-1
torch_num_workers=4
batch_size=1  # per gpu

python -u ./test.py \
    ${data_cfg_path} \
    ${main_cfg_path} \
    --ckpt_path=${ckpt_path} \
    --dump_dir=${dump_dir} \
    --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
    --batch_size=${batch_size} --num_workers=${torch_num_workers} \
    --profiler_name=${profiler_name} \
    --benchmark

@@ -0,0 +1,120 @@
from yacs.config import CfgNode as CN
_CN = CN()

############## ↓ LoFTR Pipeline ↓ ##############
_CN.LOFTR = CN()
_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
_CN.LOFTR.RESOLUTION = (8, 2)  # options: [(8, 2), (16, 4)]
_CN.LOFTR.FINE_WINDOW_SIZE = 5  # window_size in fine_level, must be odd
_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True

# 1. LoFTR-backbone (local feature CNN) config
_CN.LOFTR.RESNETFPN = CN()
_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128
_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256]  # s1, s2, s3

# 2. LoFTR-coarse module config
_CN.LOFTR.COARSE = CN()
_CN.LOFTR.COARSE.D_MODEL = 256
_CN.LOFTR.COARSE.D_FFN = 256
_CN.LOFTR.COARSE.NHEAD = 8
_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
_CN.LOFTR.COARSE.ATTENTION = 'linear'  # options: ['linear', 'full']

# 3. Coarse-Matching config
_CN.LOFTR.MATCH_COARSE = CN()
_CN.LOFTR.MATCH_COARSE.THR = 0.2
_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax', 'sinkhorn']
_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4  # training trick: save GPU memory
_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200  # training trick: avoid DDP deadlock

# 4. LoFTR-fine module config
_CN.LOFTR.FINE = CN()
_CN.LOFTR.FINE.D_MODEL = 128
_CN.LOFTR.FINE.D_FFN = 128
_CN.LOFTR.FINE.NHEAD = 8
_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
_CN.LOFTR.FINE.ATTENTION = 'linear'


############## Dataset ##############
_CN.DATASET = CN()
# 1. data config
# training and validating
_CN.DATASET.TRAINVAL_DATA_SOURCE = None  # options: ['ScanNet', 'MegaDepth']
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None  # None if val data from all scenes are bundled into a single npz file
_CN.DATASET.VAL_INTRINSIC_PATH = None
# testing
_CN.DATASET.TEST_DATA_SOURCE = None
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None  # None if test data from all scenes are bundled into a single npz file
_CN.DATASET.TEST_INTRINSIC_PATH = None

# 2. dataset config
# general options
_CN.DATASET.MIN_OVERLAP_SCORE = 0.4  # discard data with overlap_score < min_overlap_score
_CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']

# MegaDepth options
_CN.DATASET.MGDPT_IMG_RESIZE = 640  # resize the longer side, zero-pad bottom-right to square.
_CN.DATASET.MGDPT_IMG_PAD = True  # pad img to square with size = MGDPT_IMG_RESIZE
_CN.DATASET.MGDPT_DEPTH_PAD = True  # pad depthmap to square with size = 2000
_CN.DATASET.MGDPT_DF = 8

############## Trainer ##############
_CN.TRAINER = CN()

# plotting related
_CN.TRAINER.ENABLE_PLOTTING = True
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting

# geometric metrics and pose solver
_CN.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
_CN.TRAINER.POSE_GEO_MODEL = 'E'  # ['E', 'F', 'H']
_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC'  # [RANSAC, DEGENSAC, MAGSAC]
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
_CN.TRAINER.RANSAC_CONF = 0.99999
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
_CN.TRAINER.USE_MAGSACPP = False

# data sampler for train_dataloader
_CN.TRAINER.DATA_SAMPLER = 'scene_balance'  # options: ['scene_balance', 'random', 'normal']
# 'scene_balance' config
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True  # whether to sample each scene with replacement or not
_CN.TRAINER.SB_SUBSET_SHUFFLE = True  # after sampling from scenes, whether to shuffle within the epoch or not
_CN.TRAINER.SB_REPEAT = 1  # repeat the sampled data N times for training
# 'random' config
_CN.TRAINER.RDM_REPLACEMENT = True
_CN.TRAINER.RDM_NUM_SAMPLES = None

# gradient clipping
_CN.TRAINER.GRADIENT_CLIPPING = 0.5

# reproducibility
# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
# to be the same. When resuming training from a checkpoint, it's better to use a different
# seed; otherwise the sampled data will be exactly the same as before resuming, which
# reduces the number of unique data items seen during the entire training run.
# Using a different seed value might affect the final training result, since not all data items
# are used during training on ScanNet. (60M image pairs are sampled during training from 230M pairs in total.)
_CN.TRAINER.SEED = 66


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _CN.clone()
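
As the docstring notes, `get_cfg_defaults()` returns a clone, so callers can mutate their copy without touching the module-level defaults. A quick sketch:

    from src.config.default import get_cfg_defaults, _CN

    cfg = get_cfg_defaults()
    cfg.TRAINER.SEED = 0           # modify the local clone only
    assert _CN.TRAINER.SEED == 66  # module-level defaults stay untouched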
@@ -0,0 +1,124 @@
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from loguru import logger

from src.utils.dataset import read_megadepth_gray, read_megadepth_depth


class MegaDepthDataset(Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 mode='train',
                 min_overlap_score=0.4,
                 img_resize=None,
                 df=None,
                 img_padding=False,
                 depth_padding=False,
                 augment_fn=None,
                 **kwargs):
        """
        Manage one scene (npz_path) of the MegaDepth dataset.
        Args:
            root_dir (str): MegaDepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            mode (str): options are ['train', 'val', 'test']
            min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
            img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
                This is useful during training with batches and testing with memory-intensive algorithms.
            df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
            img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
            depth_padding (bool): If set to 'True', zero-pad the depthmap to (2000, 2000). This is useful during training.
            augment_fn (callable, optional): augments images with pre-defined visual effects.
        """
        super().__init__()
        self.root_dir = root_dir
        self.mode = mode
        self.scene_id = npz_path.split('.')[0]

        # prepare scene_info and pair_info
        if mode == 'test' and min_overlap_score != 0:
            logger.warning("You are using `min_overlap_score`!=0 in test mode. Setting it to 0.")
            min_overlap_score = 0
        self.scene_info = dict(np.load(npz_path, allow_pickle=True))  # cast to dict so entries can be deleted
        self.pair_infos = self.scene_info['pair_infos'].copy()
        del self.scene_info['pair_infos']
        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]

        # parameters for image resizing, padding and depthmap padding
        if mode == 'train':
            assert img_resize is not None and img_padding and depth_padding
        self.img_resize = img_resize
        self.df = df
        self.img_padding = img_padding
        self.depth_max_size = 2000 if depth_padding else None  # the upper bound of depthmap size in MegaDepth.

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None
        self.coarse_scale = kwargs.get('coarse_scale', 0.125)  # `getattr` does not work on dicts; use dict.get

    def __len__(self):
        return len(self.pair_infos)

    def __getitem__(self, idx):
        (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]

        # read grayscale image and mask. (1, h, w) and (h, w)
        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
        image0, mask0, scale0 = read_megadepth_gray(
            img_name0, self.img_resize, self.df, self.img_padding,
            np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
        image1, mask1, scale1 = read_megadepth_gray(
            img_name1, self.img_resize, self.df, self.img_padding,
            np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))

        # read depth. shape: (h, w)
        if self.mode in ['train', 'val']:
            depth0 = read_megadepth_depth(
                osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
            depth1 = read_megadepth_depth(
                osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
        else:
            depth0 = depth1 = torch.tensor([])

        # read intrinsics of original size
        K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
        K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
        T0 = self.scene_info['poses'][idx0]
        T1 = self.scene_info['poses'][idx1]
        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
        T_1to0 = T_0to1.inverse()

        data = {
            'image0': image0,  # (1, h, w)
            'depth0': depth0,  # (h, w)
            'image1': image1,
            'depth1': depth1,
            'T_0to1': T_0to1,  # (4, 4)
            'T_1to0': T_1to0,
            'K0': K_0,  # (3, 3)
            'K1': K_1,
            'scale0': scale0,  # [scale_w, scale_h]
            'scale1': scale1,
            'dataset_name': 'MegaDepth',
            'scene_id': self.scene_id,
            'pair_id': idx,
            'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
        }

        # for LoFTR training
        if mask0 is not None:  # img_padding is True
            if self.coarse_scale:
                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
                                                       scale_factor=self.coarse_scale,
                                                       mode='nearest',
                                                       recompute_scale_factor=False)[0].bool()
                data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})

        return data
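
A minimal usage sketch (the root and npz paths below are hypothetical placeholders; the npz name mirrors the asset list earlier in this diff):

    dataset = MegaDepthDataset(
        root_dir='/data/MegaDepth',                                # directory that has `phoenix`
        npz_path='assets/megadepth_scene_info/0022_0.1_0.3.npz',   # hypothetical location
        mode='test', min_overlap_score=0.0,
        img_resize=840, df=8, img_padding=True, depth_padding=True)
    sample = dataset[0]
    print(sample['image0'].shape, sample['K0'].shape)              # (1, h, w), (3, 3)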
@@ -0,0 +1,90 @@
from os import path as osp
import numpy as np
import torch
import torch.utils as utils
from src.utils.dataset import read_scannet_gray, read_scannet_depth


class ScanNetDataset(utils.data.Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 intrinsic_path,
                 mode='train',
                 min_overlap_score=0.4,
                 augment_fn=None,
                 **kwargs):
        """Manage one scene of the ScanNet dataset.
        Args:
            root_dir (str): ScanNet root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            intrinsic_path (str): path to depth-camera intrinsic file.
            mode (str): options are ['train', 'val', 'test'].
            augment_fn (callable, optional): augments images with pre-defined visual effects.
        """
        super().__init__()
        self.root_dir = root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics (T)
        with np.load(npz_path) as data:
            self.data_names = data['name']
            self.T_1to2s = data['rel_pose']
            # min_overlap_score criterion
            if 'score' in data.keys() and mode not in ['val', 'test']:
                kept_mask = data['score'] > min_overlap_score
                self.data_names = self.data_names[kept_mask]
                self.T_1to2s = self.T_1to2s[kept_mask]
        self.intrinsics = dict(np.load(intrinsic_path))

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None

    def __len__(self):
        return len(self.data_names)

    def __getitem__(self, idx):
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'

        # read the grayscale image which will be resized to (1, 480, 640)
        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
        image0 = read_scannet_gray(img_name0, resize=(640, 480),
                                   augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
        image1 = read_scannet_gray(img_name1, resize=(640, 480),
                                   augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))

        # read the depthmap which is stored as (480, 640)
        if self.mode in ['train', 'val']:
            depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
            depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
        else:
            depth0 = depth1 = torch.tensor([])

        # read the intrinsic of the depthmap
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
        T_0to1 = torch.tensor(self.T_1to2s[idx].copy(), dtype=torch.float).reshape(3, 4)
        T_0to1 = torch.cat([T_0to1, torch.tensor([[0., 0., 0., 1.]])], dim=0).reshape(4, 4)
        T_1to0 = T_0to1.inverse()

        data = {
            'image0': image0,  # (1, h, w)
            'depth0': depth0,  # (h, w)
            'image1': image1,
            'depth1': depth1,
            'T_0to1': T_0to1,  # (4, 4)
            'T_1to0': T_1to0,
            'K0': K_0,  # (3, 3)
            'K1': K_1,
            'dataset_name': 'scannet',
            'scene_id': scene_name,
            'pair_id': idx,
            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
        }

        return data
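
A matching sketch for ScanNet, reusing the paths from configs/data/scannet_test_1500.py above (assuming the data has been downloaded to those locations):

    dataset = ScanNetDataset(
        root_dir='/data/scannet/scannet_test_1500',
        npz_path='assets/scannet_test_1500/test.npz',
        intrinsic_path='assets/scannet_test_1500/intrinsics.npz',
        mode='test', min_overlap_score=0.0)
    print(len(dataset), dataset[0]['image0'].shape)  # number of pairs, torch.Size([1, 480, 640])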
@@ -0,0 +1,137 @@
from loguru import logger
from tqdm import tqdm
from os import path as osp

import pytorch_lightning as pl
from torch import distributed as dist
from torch.utils.data import DataLoader, ConcatDataset, DistributedSampler

from src.utils.augment import build_augmentor
from src.utils.dataloader import get_local_split
from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset


class MultiSceneDataModule(pl.LightningDataModule):
    """
    For distributed training, each training process is assigned
    only a part of the training scenes to reduce memory overhead.
    """

    def __init__(self, args, config):
        super().__init__()

        # 1. data config
        # Train and Val should come from the same data source
        self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        # training and validating
        self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
        self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
        self.train_list_path = config.DATASET.TRAIN_LIST_PATH
        self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
        self.val_data_root = config.DATASET.VAL_DATA_ROOT
        self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
        self.val_list_path = config.DATASET.VAL_LIST_PATH
        self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
        # testing
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
        self.min_overlap_score = config.DATASET.MIN_OVERLAP_SCORE  # 0.4, omit data with overlap_score < min_overlap_score
        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']

        # MegaDepth options
        self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 840
        self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD  # True
        self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD  # True
        self.mgdpt_df = config.DATASET.MGDPT_DF  # 8
        self.coarse_scale = 1 / config.LOFTR.RESOLUTION[0]  # 0.125, for training LoFTR.

        # 3. loader parameters
        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': True
        }

        self.seed = config.TRAINER.SEED  # 66

    def setup(self, stage=None):
        """
        Setup train / val / test dataset. This method will be called by PL automatically.
        Args:
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """

        assert stage == 'test', "only testing is supported for now"

        try:
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
        except AssertionError as ae:
            self.world_size = 1
            self.rank = 0
            logger.warning(str(ae) + " (set world_size=1 and rank=0)")

        self.test_dataset = self._setup_dataset(self.test_data_root,
                                                self.test_npz_root,
                                                self.test_list_path,
                                                self.test_intrinsic_path,
                                                mode='test')
        logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')

    def _setup_dataset(self, data_root, split_npz_root, scene_list_path, intri_path, mode='train'):
        """ Setup train / val / test set"""
        with open(scene_list_path, 'r') as f:
            npz_names = [name.split()[0] for name in f.readlines()]

        if mode == 'train':
            local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
        else:
            local_npz_names = npz_names
        logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')

        return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path, mode=mode)

    def _build_concat_dataset(self, data_root, npz_names, npz_dir, intrinsic_path, mode):
        datasets = []
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
        for npz_name in tqdm(npz_names, desc=f'[rank:{self.rank}], loading {mode} datasets', disable=int(self.rank) != 0):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if data_source == 'ScanNet':
                datasets.append(
                    ScanNetDataset(data_root,
                                   npz_path,
                                   intrinsic_path,
                                   mode=mode,
                                   min_overlap_score=self.min_overlap_score,
                                   augment_fn=augment_fn))
            elif data_source == 'MegaDepth':
                datasets.append(
                    MegaDepthDataset(data_root,
                                     npz_path,
                                     mode=mode,
                                     min_overlap_score=self.min_overlap_score,
                                     img_resize=self.mgdpt_img_resize,
                                     df=self.mgdpt_df,
                                     img_padding=self.mgdpt_img_pad,
                                     depth_padding=self.mgdpt_depth_pad,
                                     augment_fn=augment_fn,
                                     coarse_scale=self.coarse_scale))
            else:
                raise NotImplementedError()
        return ConcatDataset(datasets)

    def test_dataloader(self, *args, **kwargs):
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
        sampler = DistributedSampler(self.test_dataset, shuffle=False)
        return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
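
For reference, a hypothetical sketch of what `get_local_split` (defined in src/utils/dataloader, not shown in this diff) needs to do for the scene assignment described in the class docstring: pad the scene list to a multiple of world_size, then hand each rank its own contiguous chunk:

    import numpy as np

    def get_local_split_sketch(items, world_size, rank, seed):
        # pad with seeded random repeats so the list divides evenly across ranks
        n_pad = (world_size - len(items) % world_size) % world_size
        rng = np.random.RandomState(seed)
        padded = list(items) + list(rng.choice(items, n_pad))
        chunk = len(padded) // world_size
        return padded[rank * chunk:(rank + 1) * chunk]  # this rank's scenes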
@@ -0,0 +1,92 @@
import pprint
from loguru import logger
from pathlib import Path
import numpy as np

import torch
import pytorch_lightning as pl

from src.loftr import LoFTR
from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors, aggregate_metrics

from src.utils.comm import gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler


class PL_LoFTR(pl.LightningModule):
    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):

        super().__init__()
        # Misc
        self.config = config  # full config
        self.loftr_cfg = lower_config(self.config.LOFTR)
        self.profiler = profiler or PassThroughProfiler()
        self.dump_dir = dump_dir

        # Matcher: LoFTR
        self.matcher = LoFTR(config=self.loftr_cfg)

        # Pretrained weights
        if pretrained_ckpt:
            self.matcher.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu')['state_dict'])
            logger.info(f"Loaded '{pretrained_ckpt}' as pretrained checkpoint")

    def test_step(self, batch, batch_idx):
        with self.profiler.profile("LoFTR"):
            self.matcher(batch)

        with self.profiler.profile("Compute metrics"):
            compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair

        rel_pair_names = list(zip(*batch['pair_names']))
        bs = batch['image0'].size(0)
        metrics = {
            # to filter duplicate pairs caused by DistributedSampler
            'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
            'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
            'R_errs': batch['R_errs'],
            't_errs': batch['t_errs'],
            'inliers': batch['inliers']}
        ret_dict = {'metrics': metrics}

        with self.profiler.profile("dump_results"):
            if self.dump_dir is not None:
                # dump results for further analysis
                keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
                pair_names = list(zip(*batch['pair_names']))
                bs = batch['image0'].shape[0]
                dumps = []
                for b_id in range(bs):
                    item = {}
                    mask = batch['m_bids'] == b_id
                    item['pair_names'] = pair_names[b_id]
                    item['identifier'] = '#'.join(rel_pair_names[b_id])
                    for key in keys_to_save:
                        item[key] = batch[key][mask].cpu().numpy()
                    for key in ['R_errs', 't_errs', 'inliers']:
                        item[key] = batch[key][b_id]
                    dumps.append(item)
                ret_dict['dumps'] = dumps

        return ret_dict

    def test_epoch_end(self, outputs):
        # metrics: dict of list, numpy
        _metrics = [o['metrics'] for o in outputs]
        metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}

        # [{key: [{...}, *#bs]}, *#batch]
        if self.dump_dir is not None:
            Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
            _dumps = flattenList([o['dumps'] for o in outputs])  # [{...}, #bs*#batch]
            dumps = flattenList(gather(_dumps))  # [{...}, #proc*#bs*#batch]
            logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')

        if self.trainer.global_rank == 0:
            print(self.profiler.summary())
            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
            logger.info('\n' + pprint.pformat(val_metrics_4tb))
            if self.dump_dir is not None:
                np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
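
The dumped file can be inspected afterwards with numpy (a sketch; np.save appends the .npy suffix, and the dump_dir below mirrors the reproduce scripts):

    import numpy as np

    dumps = np.load('dump/loftr_ds_indoor/LoFTR_pred_eval.npy', allow_pickle=True)
    print(dumps[0].keys())  # pair_names, identifier, mkpts0_f, mkpts1_f, mconf, epi_errs, R_errs, t_errs, inliers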
@@ -0,0 +1,2 @@
from .loftr import LoFTR
from .utils.cvpr_ds_config import default_cfg
@@ -0,0 +1,11 @@
from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4


def build_backbone(config):
    if config['backbone_type'] == 'ResNetFPN':
        if config['resolution'] == (8, 2):
            return ResNetFPN_8_2(config['resnetfpn'])
        elif config['resolution'] == (16, 4):
            return ResNetFPN_16_4(config['resnetfpn'])
    else:
        raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
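
A small wiring sketch, assuming the lower-cased default config (the same `lower_config` helper used by PL_LoFTR above):

    from src.config.default import get_cfg_defaults
    from src.utils.misc import lower_config

    loftr_cfg = lower_config(get_cfg_defaults().LOFTR)
    backbone = build_backbone(loftr_cfg)  # ResNetFPN_8_2 for RESOLUTION = (8, 2)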
@@ -0,0 +1,199 @@
import torch.nn as nn
import torch.nn.functional as F


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class ResNetFPN_8_2(nn.Module):
    """
    ResNet+FPN, output resolutions are 1/8 and 1/2.
    Each block has 2 layers.
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim

        # Networks
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )
        self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
        self.layer1_outconv2 = nn.Sequential(
            conv3x3(block_dims[1], block_dims[1]),
            nn.BatchNorm2d(block_dims[1]),
            nn.LeakyReLU(),
            conv3x3(block_dims[1], block_dims[0]),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # FPN
        x3_out = self.layer3_outconv(x3)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        return [x3_out, x1_out]


class ResNetFPN_16_4(nn.Module):
    """
    ResNet+FPN, output resolutions are 1/16 and 1/4.
    Each block has 2 layers.
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim

        # Networks
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8
        self.layer4 = self._make_layer(block, block_dims[3], stride=2)  # 1/16

        # 3. FPN upsample
        self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
        self.layer3_outconv2 = nn.Sequential(
            conv3x3(block_dims[3], block_dims[3]),
            nn.BatchNorm2d(block_dims[3]),
            nn.LeakyReLU(),
            conv3x3(block_dims[3], block_dims[2]),
        )

        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8
        x4 = self.layer4(x3)  # 1/16

        # FPN
        x4_out = self.layer4_outconv(x4)

        x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
        x3_out = self.layer3_outconv(x3)
        x3_out = self.layer3_outconv2(x3_out+x4_out_2x)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        return [x4_out, x2_out]
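
Quick shape check (a sketch, using the default RESNETFPN config): ResNetFPN_8_2 maps a grayscale image to 1/8-resolution coarse features and 1/2-resolution fine features:

    import torch

    net = ResNetFPN_8_2({'initial_dim': 128, 'block_dims': [128, 196, 256]})
    coarse, fine = net(torch.randn(1, 1, 480, 640))
    print(coarse.shape, fine.shape)  # torch.Size([1, 256, 60, 80]) torch.Size([1, 128, 240, 320])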
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from einops.einops import rearrange

from .backbone import build_backbone
from .utils.position_encoding import PositionEncodingSine
from .loftr_module import LocalFeatureTransformer, FinePreprocess
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching


class LoFTR(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Misc
        self.config = config

        # Modules
        self.backbone = build_backbone(config)
        self.pos_encoding = PositionEncodingSine(config['coarse']['d_model'])
        self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
        self.coarse_matching = CoarseMatching(config['match_coarse'])
        self.fine_preprocess = FinePreprocess(config)
        self.loftr_fine = LocalFeatureTransformer(config["fine"])
        self.fine_matching = FineMatching()

    def forward(self, data):
        """
        Update:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0' (optional): (torch.Tensor): (N, H, W) '0' indicates a padded position
                'mask1' (optional): (torch.Tensor): (N, H, W)
            }
        """
        # 1. Local Feature CNN
        data.update({
            'bs': data['image0'].size(0),
            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
        })

        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
            feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
            (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
        else:  # handle different input shapes
            (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])

        data.update({
            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
        })

        # 2. coarse-level loftr module
        # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
        feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
        feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')

        mask_c0 = mask_c1 = None  # mask is useful in training
        if 'mask0' in data:
            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)

        # 3. match coarse-level
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)

        # 4. fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
        if feat_f0_unfold.size(0) != 0:  # at least one coarse-level match predicted
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)

        # 5. match fine-level
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
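
An end-to-end matching sketch (the checkpoint path mirrors the reproduce scripts; random tensors stand in for grayscale images just to show the shapes and output keys):

    import torch
    from src.loftr import LoFTR, default_cfg

    matcher = LoFTR(config=default_cfg)
    matcher.load_state_dict(torch.load('weights/indoor_ds.ckpt')['state_dict'])
    matcher.eval()

    batch = {'image0': torch.randn(1, 1, 480, 640), 'image1': torch.randn(1, 1, 480, 640)}
    with torch.no_grad():
        matcher(batch)  # results are written back into `batch` in place
    mkpts0, mkpts1, mconf = batch['mkpts0_f'], batch['mkpts1_f'], batch['mconf']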
@@ -0,0 +1,2 @@
from .transformer import LocalFeatureTransformer
from .fine_preprocess import FinePreprocess
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange, repeat


class FinePreprocess(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.cat_c_feat = config['fine_concat_coarse_feat']
        self.W = self.config['fine_window_size']

        d_model_c = self.config['coarse']['d_model']
        d_model_f = self.config['fine']['d_model']
        self.d_model_f = d_model_f
        if self.cat_c_feat:
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
            self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
        W = self.W
        stride = data['hw0_f'][0] // data['hw0_c'][0]

        data.update({'W': W})
        if data['b_ids'].shape[0] == 0:
            feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
            feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
            return feat0, feat1

        # 1. unfold (crop) all local windows
        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)

        # 2. select only the predicted matches
        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cf]
        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]

        # option: use coarse-level loftr feature as context: concat and linear
        if self.cat_c_feat:
            feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
                                                   feat_c1[data['b_ids'], data['j_ids']]], 0))  # [2n, c]
            feat_cf_win = self.merge_feat(torch.cat([
                torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
                repeat(feat_c_win, 'n c -> n ww c', ww=W**2),  # [2n, ww, cf]
            ], -1))
            feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

        return feat_f0_unfold, feat_f1_unfold
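
Unfold sanity check (a sketch): with kernel W, stride equal to the fine/coarse resolution ratio, and padding W//2, F.unfold extracts one WxW fine-feature window centered on every coarse cell, so the window count equals the number of coarse positions:

    import torch
    import torch.nn.functional as F

    feat_f = torch.randn(1, 128, 240, 320)  # fine features at 1/2 resolution
    W, stride = 5, 4                        # fine_window_size, hw_f // hw_c
    windows = F.unfold(feat_f, kernel_size=(W, W), stride=stride, padding=W // 2)
    print(windows.shape)                    # torch.Size([1, 3200, 4800]) = [N, C*W*W, 60*80]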
@@ -0,0 +1,81 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""

import torch
from torch.nn import Module, Dropout


def elu_feature_map(x):
    return torch.nn.functional.elu(x) + 1


class LinearAttention(Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)

        # set padded positions to zero
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length

        return queried_values.contiguous()


class FullAttention(Module):
    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a. full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()
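
Shape sketch: linear attention contracts K and V into a DxD summary per head instead of materializing the LxS score matrix, so memory and compute grow linearly with sequence length:

    import torch

    attn = LinearAttention()
    q = torch.randn(2, 4800, 8, 32)  # [N, L, H, D], e.g. 60*80 coarse positions
    k = torch.randn(2, 4800, 8, 32)
    v = torch.randn(2, 4800, 8, 32)
    print(attn(q, k, v).shape)       # torch.Size([2, 4800, 8, 32])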
@@ -0,0 +1,101 @@
import copy
import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention


class LoFTREncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear'):
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, config):
        super(LocalFeatureTransformer, self).__init__()

        self.config = config
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        """

        assert self.d_model == feat0.size(2), "the feature dimension of the input and the transformer must be equal"

        for layer, name in zip(self.layers, self.layer_names):
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError

        return feat0, feat1
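
Usage sketch with the default coarse config (four interleaved self/cross rounds):

    import torch

    config = {'d_model': 256, 'nhead': 8, 'layer_names': ['self', 'cross'] * 4, 'attention': 'linear'}
    loftr_coarse = LocalFeatureTransformer(config)
    feat0, feat1 = torch.randn(1, 4800, 256), torch.randn(1, 4800, 256)
    feat0, feat1 = loftr_coarse(feat0, feat1)  # both features updated jointly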
@@ -0,0 +1,177 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange


def mask_border(m, b: int, v):
    """ Mask borders with value
    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]
        b (int)
        v (m.dtype)
    """
    m[:, :b] = v
    m[:, :, :b] = v
    m[:, :, :, :b] = v
    m[:, :, :, :, :b] = v
    m[:, -b:] = v
    m[:, :, -b:] = v
    m[:, :, :, -b:] = v
    m[:, :, :, :, -b:] = v


def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    m[:, :bd] = v
    m[:, :, :bd] = v
    m[:, :, :, :bd] = v
    m[:, :, :, :, :bd] = v

    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        m[b_idx, h0-bd:] = v
        m[b_idx, :, w0-bd:] = v
        m[b_idx, :, :, h1-bd:] = v
        m[b_idx, :, :, :, w1-bd:] = v


class CoarseMatching(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # general config
        self.thr = config['thr']
        self.border_rm = config['border_rm']

        # we provide 2 options for differentiable matching
        self.match_type = config['match_type']
        if self.match_type == 'dual_softmax':
            self.temperature = config['dsmax_temperature']
        elif self.match_type == 'sinkhorn':
            try:
                from .superglue import log_optimal_transport
            except ImportError:
                raise ImportError("download superglue.py first!")
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(torch.tensor(config['skh_init_bin_score'], requires_grad=True))
            self.skh_iters = config['skh_iters']
            self.skh_prefilter = config['skh_prefilter']
        else:
            raise NotImplementedError()

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional)
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)

        # normalize
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, [feat_c0, feat_c1])

        if self.match_type == 'dual_softmax':
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / self.temperature
            if mask_c0 is not None:
                valid_sim_mask = mask_c0[..., None] * mask_c1[:, None]
                _inf = torch.zeros_like(sim_matrix)
                _inf[~valid_sim_mask.bool()] = -1e9
                del valid_sim_mask
                sim_matrix += _inf
            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        elif self.match_type == 'sinkhorn':
            # sinkhorn, dustbin included
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
            if mask_c0 is not None:
                sim_matrix[:, :L, :S].masked_fill_(~(mask_c0[..., None] * mask_c1[:, None]).bool(), float('-inf'))

            # build uniform prior & use sinkhorn
            log_assign_matrix = self.log_optimal_transport(sim_matrix, self.bin_score, self.skh_iters)
            assign_matrix = log_assign_matrix.exp()
            conf_matrix = assign_matrix[:, :-1, :-1]

            # filter prediction with dustbin score (only in evaluation mode)
            if not self.training and self.skh_prefilter:
                filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1]  # [N, L]
                filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1]  # [N, S]
                conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
                conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0

        data.update({'conf_matrix': conf_matrix})
||||
|
||||
# predict coarse matches from conf_matrix |
||||
data.update(**self.get_coarse_match(conf_matrix, data)) |
||||
|
||||
@torch.no_grad() |
||||
def get_coarse_match(self, conf_matrix, data): |
||||
""" |
||||
Args: |
||||
conf_matrix (torch.Tensor): [N, L, S] |
||||
data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] |
||||
Returns: |
||||
coarse_matches (dict): { |
||||
'b_ids' (torch.Tensor): [M'], |
||||
'i_ids' (torch.Tensor): [M'], |
||||
'j_ids' (torch.Tensor): [M'], |
||||
'gt_mask' (torch.Tensor): [M'], |
||||
'm_bids' (torch.Tensor): [M], |
||||
'mkpts0_c' (torch.Tensor): [M, 2], |
||||
'mkpts1_c' (torch.Tensor): [M, 2], |
||||
'mconf' (torch.Tensor): [M]} |
||||
""" |
||||
axes_lengths = {'h0c': data['hw0_c'][0], 'w0c': data['hw0_c'][1], |
||||
'h1c': data['hw1_c'][0], 'w1c': data['hw1_c'][1]} |
||||
# 1. confidence thresholding |
||||
mask = conf_matrix > self.thr |
||||
mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', **axes_lengths) |
||||
if 'mask0' not in data: |
||||
mask_border(mask, self.border_rm, False) |
||||
else: |
||||
mask_border_with_padding(mask, self.border_rm, False, data['mask0'], data['mask1']) |
||||
mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', **axes_lengths) |
||||
|
||||
# 2. mutual nearest |
||||
mask = mask \ |
||||
* (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ |
||||
* (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) |
||||
|
||||
# 3. find all valid coarse matches |
||||
# this only works when at most one `True` in each row |
||||
mask_v, all_j_ids = mask.max(dim=2) |
||||
b_ids, i_ids = torch.where(mask_v) |
||||
j_ids = all_j_ids[b_ids, i_ids] |
||||
mconf = conf_matrix[b_ids, i_ids, j_ids] |
||||
|
||||
# These matches select the patches that will be fed into the fine-level network
||||
coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} |
||||
|
||||
# 4. Update with matches in original image resolution |
||||
scale = data['hw0_i'][0] / data['hw0_c'][0] |
||||
scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale |
||||
scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale |
||||
mkpts0_c = torch.stack([i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], dim=1) * scale0 |
||||
mkpts1_c = torch.stack([j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], dim=1) * scale1 |
||||
|
||||
# These matches are the current predictions (used for visualization)
||||
coarse_matches.update({'gt_mask': mconf == 0, |
||||
'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches |
||||
'mkpts0_c': mkpts0_c[mconf != 0], |
||||
'mkpts1_c': mkpts1_c[mconf != 0], |
||||
'mconf': mconf[mconf != 0]}) |
||||
|
||||
return coarse_matches |
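
# Self-contained sketch of the dual-softmax + mutual-nearest rule used above,
# on random features (illustrative only: shapes, temperature and the 0.2
# threshold are arbitrary here, not the trained configuration).
if __name__ == '__main__':
    feat0, feat1 = torch.rand(1, 64, 32) / 32**.5, torch.rand(1, 80, 32) / 32**.5
    sim = torch.einsum("nlc,nsc->nls", feat0, feat1) / 0.1
    conf = F.softmax(sim, 1) * F.softmax(sim, 2)
    mutual = (conf == conf.max(dim=2, keepdim=True)[0]) \
        & (conf == conf.max(dim=1, keepdim=True)[0])
    b_ids, i_ids, j_ids = torch.where(mutual & (conf > 0.2))
    print(f"{len(i_ids)} coarse matches")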
@ -0,0 +1,49 @@ |
||||
from yacs.config import CfgNode as CN |
||||
|
||||
|
||||
def lower_config(yacs_cfg): |
||||
if not isinstance(yacs_cfg, CN): |
||||
return yacs_cfg |
||||
return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} |
||||
|
||||
|
||||
_CN = CN() |
||||
_CN.BACKBONE_TYPE = 'ResNetFPN' |
||||
_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] |
||||
_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd |
||||
_CN.FINE_CONCAT_COARSE_FEAT = True |
||||
|
||||
# 1. LoFTR-backbone (local feature CNN) config |
||||
_CN.RESNETFPN = CN() |
||||
_CN.RESNETFPN.INITIAL_DIM = 128 |
||||
_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 |
||||
|
||||
# 2. LoFTR-coarse module config |
||||
_CN.COARSE = CN() |
||||
_CN.COARSE.D_MODEL = 256 |
||||
_CN.COARSE.D_FFN = 256 |
||||
_CN.COARSE.NHEAD = 8 |
||||
_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 |
||||
_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] |
||||
|
||||
# 3. Coarse-Matching config |
||||
_CN.MATCH_COARSE = CN() |
||||
_CN.MATCH_COARSE.THR = 0.2 |
||||
_CN.MATCH_COARSE.BORDER_RM = 2 |
||||
_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax', 'sinkhorn']
||||
_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 |
||||
_CN.MATCH_COARSE.SKH_ITERS = 3 |
||||
_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 |
||||
_CN.MATCH_COARSE.SKH_PREFILTER = True |
||||
_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory |
||||
_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock |
||||
|
||||
# 4. LoFTR-fine module config |
||||
_CN.FINE = CN() |
||||
_CN.FINE.D_MODEL = 128 |
||||
_CN.FINE.D_FFN = 128 |
||||
_CN.FINE.NHEAD = 8 |
||||
_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 |
||||
_CN.FINE.ATTENTION = 'linear' |
||||
|
||||
default_cfg = lower_config(_CN) |
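
# Usage sketch: lower_config recursively turns the CfgNode into a plain dict
# with lowercase keys, which is what the modules index into, e.g.:
#
#   default_cfg['match_coarse']['thr']      # -> 0.2
#   default_cfg['coarse']['layer_names']    # -> ['self', 'cross', ...] (8 entries)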
@ -0,0 +1,71 @@ |
||||
import math |
||||
import torch |
||||
import torch.nn as nn |
||||
|
||||
from kornia.geometry.subpix import dsnt |
||||
from kornia.utils.grid import create_meshgrid |
||||
|
||||
|
||||
class FineMatching(nn.Module): |
||||
"""FineMatching with s2d paradigm""" |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
|
||||
def forward(self, feat_f0, feat_f1, data): |
||||
""" |
||||
Args: |
||||
feat0 (torch.Tensor): [M, WW, C] |
||||
feat1 (torch.Tensor): [M, WW, C] |
||||
data (dict) |
||||
Update: |
||||
data (dict):{ |
||||
'expec_f' (torch.Tensor): [M, 3], |
||||
'mkpts0_f' (torch.Tensor): [M, 2], |
||||
'mkpts1_f' (torch.Tensor): [M, 2]} |
||||
""" |
||||
M, WW, C = feat_f0.shape |
||||
W = int(math.sqrt(WW)) |
||||
scale = data['hw0_i'][0] / data['hw0_f'][0] |
||||
self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale |
||||
|
||||
# corner case: if no coarse matches found |
||||
if M == 0: |
||||
assert not self.training, "M is always > 0 during training; see coarse_matching.py"
||||
# logger.warning('No matches found in coarse-level.') |
||||
data.update({ |
||||
'expec_f': torch.empty(0, 3, device=feat_f0.device), |
||||
'mkpts0_f': data['mkpts0_c'], |
||||
'mkpts1_f': data['mkpts1_c'], |
||||
}) |
||||
return |
||||
|
||||
feat_f0_picked = feat_f0[:, WW//2, :]
||||
sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) |
||||
softmax_temp = 1. / C**.5 |
||||
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) |
||||
|
||||
# compute coordinates from heatmap |
||||
coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] |
||||
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] |
||||
|
||||
# compute std over <x, y> |
||||
var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] |
||||
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability

# keep std as a confidence measure (yields the [M, 3] 'expec_f' promised in the docstring)
data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

# compute absolute kpt coords
self.get_fine_match(coords_normalized, data)
||||
|
||||
@torch.no_grad() |
||||
def get_fine_match(self, coords_normed, data): |
||||
W, WW, C, scale = self.W, self.WW, self.C, self.scale |
||||
|
||||
# mkpts0_f and mkpts1_f |
||||
mkpts0_f = data['mkpts0_c'] |
||||
scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale
||||
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] |
||||
|
||||
data.update({ |
||||
"mkpts0_f": mkpts0_f, |
||||
"mkpts1_f": mkpts1_f |
||||
}) |
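
# A hedged sketch of the heatmap -> sub-pixel coordinate step in forward(),
# using kornia's spatial expectation (soft-argmax) on random 5x5 windows.
if __name__ == '__main__':
    W = 5
    heatmap = torch.softmax(torch.rand(3, W * W), dim=1).view(-1, W, W)
    coords = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [3, 2] in [-1, 1]
    print(coords)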
@ -0,0 +1,35 @@ |
||||
import math |
||||
import torch |
||||
from torch import nn |
||||
|
||||
|
||||
class PositionEncodingSine(nn.Module): |
||||
""" |
||||
This is a sinusoidal position encoding generalized to 2-dimensional images
||||
""" |
||||
|
||||
def __init__(self, d_model, max_shape=(256, 256)): |
||||
""" |
||||
Args: |
||||
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels |
||||
""" |
||||
super().__init__() |
||||
|
||||
pe = torch.zeros((d_model, *max_shape)) |
||||
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) |
||||
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) |
||||
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))  # note: (d_model//2) must be parenthesized, otherwise the whole quotient is floor-divided by 2
||||
div_term = div_term[:, None, None] # [C//4, 1, 1] |
||||
pe[0::4, :, :] = torch.sin(x_position * div_term) |
||||
pe[1::4, :, :] = torch.cos(x_position * div_term) |
||||
pe[2::4, :, :] = torch.sin(y_position * div_term) |
||||
pe[3::4, :, :] = torch.cos(y_position * div_term) |
||||
|
||||
self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] |
||||
|
||||
def forward(self, x): |
||||
""" |
||||
Args: |
||||
x: [N, C, H, W] |
||||
""" |
||||
return x + self.pe[:, :, :x.size(2), :x.size(3)] |
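
# Shape check (illustrative): the precomputed encoding is sliced to the
# input's spatial size, so any H, W up to max_shape works.
if __name__ == '__main__':
    pos_enc = PositionEncodingSine(d_model=256)
    x = torch.rand(1, 256, 60, 80)
    assert pos_enc(x).shape == x.shape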
@ -0,0 +1,53 @@ |
||||
import albumentations as A |
||||
|
||||
|
||||
class DarkAug(object): |
||||
""" |
||||
Extreme dark augmentation aiming at Aachen Day-Night |
||||
""" |
||||
|
||||
def __init__(self) -> None: |
||||
self.augmentor = A.Compose([ |
||||
A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), |
||||
A.Blur(p=0.1, blur_limit=(3, 9)), |
||||
A.MotionBlur(p=0.2, blur_limit=(3, 25)), |
||||
A.RandomGamma(p=0.1, gamma_limit=(15, 65)), |
||||
A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) |
||||
], p=0.75) |
||||
|
||||
def __call__(self, x): |
||||
return self.augmentor(image=x)['image'] |
||||
|
||||
|
||||
class MobileAug(object): |
||||
""" |
||||
Random augmentations aiming at images from mobile/handheld devices.
||||
""" |
||||
|
||||
def __init__(self): |
||||
self.augmentor = A.Compose([ |
||||
A.MotionBlur(p=0.25), |
||||
A.ColorJitter(p=0.5), |
||||
A.RandomRain(p=0.1), # random occlusion |
||||
A.RandomSunFlare(p=0.1), |
||||
A.JpegCompression(p=0.25), |
||||
A.ISONoise(p=0.25) |
||||
], p=1.0) |
||||
|
||||
def __call__(self, x): |
||||
return self.augmentor(image=x)['image'] |
||||
|
||||
|
||||
def build_augmentor(method=None, **kwargs): |
||||
if method == 'dark': |
||||
return DarkAug() |
||||
elif method == 'mobile': |
||||
return MobileAug() |
||||
elif method is None: |
||||
return None |
||||
else: |
||||
raise ValueError(f'Invalid augmentation method: {method}') |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
augmentor = build_augmentor('FDA')  # NOTE: 'FDA' is not an implemented method, so this raises ValueError
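# Valid usage (illustrative; `image` would be an HxWx3 uint8 RGB array):
#   augmentor = build_augmentor('dark')
#   augmented = augmentor(image)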
@ -0,0 +1,265 @@ |
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
||||
""" |
||||
[Copied from detectron2] |
||||
This file contains primitives for multi-gpu communication. |
||||
This is useful when doing distributed training. |
||||
""" |
||||
|
||||
import functools |
||||
import logging |
||||
import numpy as np |
||||
import pickle |
||||
import torch |
||||
import torch.distributed as dist |
||||
|
||||
_LOCAL_PROCESS_GROUP = None |
||||
""" |
||||
A torch process group which only includes processes that are on the same machine as the current process.
||||
This variable is set when processes are spawned by `launch()` in "engine/launch.py". |
||||
""" |
||||
|
||||
|
||||
def get_world_size() -> int: |
||||
if not dist.is_available(): |
||||
return 1 |
||||
if not dist.is_initialized(): |
||||
return 1 |
||||
return dist.get_world_size() |
||||
|
||||
|
||||
def get_rank() -> int: |
||||
if not dist.is_available(): |
||||
return 0 |
||||
if not dist.is_initialized(): |
||||
return 0 |
||||
return dist.get_rank() |
||||
|
||||
|
||||
def get_local_rank() -> int: |
||||
""" |
||||
Returns: |
||||
The rank of the current process within the local (per-machine) process group. |
||||
""" |
||||
if not dist.is_available(): |
||||
return 0 |
||||
if not dist.is_initialized(): |
||||
return 0 |
||||
assert _LOCAL_PROCESS_GROUP is not None |
||||
return dist.get_rank(group=_LOCAL_PROCESS_GROUP) |
||||
|
||||
|
||||
def get_local_size() -> int: |
||||
""" |
||||
Returns: |
||||
The size of the per-machine process group, |
||||
i.e. the number of processes per machine. |
||||
""" |
||||
if not dist.is_available(): |
||||
return 1 |
||||
if not dist.is_initialized(): |
||||
return 1 |
||||
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) |
||||
|
||||
|
||||
def is_main_process() -> bool: |
||||
return get_rank() == 0 |
||||
|
||||
|
||||
def synchronize(): |
||||
""" |
||||
Helper function to synchronize (barrier) among all processes when |
||||
using distributed training |
||||
""" |
||||
if not dist.is_available(): |
||||
return |
||||
if not dist.is_initialized(): |
||||
return |
||||
world_size = dist.get_world_size() |
||||
if world_size == 1: |
||||
return |
||||
dist.barrier() |
||||
|
||||
|
||||
@functools.lru_cache() |
||||
def _get_global_gloo_group(): |
||||
""" |
||||
Return a process group based on gloo backend, containing all the ranks |
||||
The result is cached. |
||||
""" |
||||
if dist.get_backend() == "nccl": |
||||
return dist.new_group(backend="gloo") |
||||
else: |
||||
return dist.group.WORLD |
||||
|
||||
|
||||
def _serialize_to_tensor(data, group): |
||||
backend = dist.get_backend(group) |
||||
assert backend in ["gloo", "nccl"] |
||||
device = torch.device("cpu" if backend == "gloo" else "cuda") |
||||
|
||||
buffer = pickle.dumps(data) |
||||
if len(buffer) > 1024 ** 3: |
||||
logger = logging.getLogger(__name__) |
||||
logger.warning( |
||||
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format( |
||||
get_rank(), len(buffer) / (1024 ** 3), device |
||||
) |
||||
) |
||||
storage = torch.ByteStorage.from_buffer(buffer) |
||||
tensor = torch.ByteTensor(storage).to(device=device) |
||||
return tensor |
||||
|
||||
|
||||
def _pad_to_largest_tensor(tensor, group): |
||||
""" |
||||
Returns: |
||||
list[int]: size of the tensor, on each rank |
||||
Tensor: padded tensor that has the max size |
||||
""" |
||||
world_size = dist.get_world_size(group=group) |
||||
assert ( |
||||
world_size >= 1 |
||||
), "comm.gather/all_gather must be called from ranks within the given group!" |
||||
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) |
||||
size_list = [ |
||||
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) |
||||
] |
||||
dist.all_gather(size_list, local_size, group=group) |
||||
|
||||
size_list = [int(size.item()) for size in size_list] |
||||
|
||||
max_size = max(size_list) |
||||
|
||||
# we pad the tensor because torch all_gather does not support |
||||
# gathering tensors of different shapes |
||||
if local_size != max_size: |
||||
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) |
||||
tensor = torch.cat((tensor, padding), dim=0) |
||||
return size_list, tensor |
||||
|
||||
|
||||
def all_gather(data, group=None): |
||||
""" |
||||
Run all_gather on arbitrary picklable data (not necessarily tensors). |
||||
|
||||
Args: |
||||
data: any picklable object |
||||
group: a torch process group. By default, will use a group which |
||||
contains all ranks on gloo backend. |
||||
|
||||
Returns: |
||||
list[data]: list of data gathered from each rank |
||||
""" |
||||
if get_world_size() == 1: |
||||
return [data] |
||||
if group is None: |
||||
group = _get_global_gloo_group() |
||||
if dist.get_world_size(group) == 1: |
||||
return [data] |
||||
|
||||
tensor = _serialize_to_tensor(data, group) |
||||
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group) |
||||
max_size = max(size_list) |
||||
|
||||
# receiving Tensor from all ranks |
||||
tensor_list = [ |
||||
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list |
||||
] |
||||
dist.all_gather(tensor_list, tensor, group=group) |
||||
|
||||
data_list = [] |
||||
for size, tensor in zip(size_list, tensor_list): |
||||
buffer = tensor.cpu().numpy().tobytes()[:size] |
||||
data_list.append(pickle.loads(buffer)) |
||||
|
||||
return data_list |
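
# Usage sketch (hedged): requires an initialized process group (e.g. via
# torchrun), and every rank must call it or the collective deadlocks:
#
#   stats = {'n_matches': 123, 'rank': get_rank()}
#   all_stats = all_gather(stats)   # list with one entry per rank, on every rank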
||||
|
||||
|
||||
def gather(data, dst=0, group=None): |
||||
""" |
||||
Run gather on arbitrary picklable data (not necessarily tensors). |
||||
|
||||
Args: |
||||
data: any picklable object |
||||
dst (int): destination rank |
||||
group: a torch process group. By default, will use a group which |
||||
contains all ranks on gloo backend. |
||||
|
||||
Returns: |
||||
list[data]: on dst, a list of data gathered from each rank. Otherwise, |
||||
an empty list. |
||||
""" |
||||
if get_world_size() == 1: |
||||
return [data] |
||||
if group is None: |
||||
group = _get_global_gloo_group() |
||||
if dist.get_world_size(group=group) == 1: |
||||
return [data] |
||||
rank = dist.get_rank(group=group) |
||||
|
||||
tensor = _serialize_to_tensor(data, group) |
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group) |
||||
|
||||
# receiving Tensor from all ranks |
||||
if rank == dst: |
||||
max_size = max(size_list) |
||||
tensor_list = [ |
||||
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list |
||||
] |
||||
dist.gather(tensor, tensor_list, dst=dst, group=group) |
||||
|
||||
data_list = [] |
||||
for size, tensor in zip(size_list, tensor_list): |
||||
buffer = tensor.cpu().numpy().tobytes()[:size] |
||||
data_list.append(pickle.loads(buffer)) |
||||
return data_list |
||||
else: |
||||
dist.gather(tensor, [], dst=dst, group=group) |
||||
return [] |
||||
|
||||
|
||||
def shared_random_seed(): |
||||
""" |
||||
Returns: |
||||
int: a random number that is the same across all workers. |
||||
If workers need a shared RNG, they can use this shared seed to |
||||
create one. |
||||
|
||||
All workers must call this function, otherwise it will deadlock. |
||||
""" |
||||
ints = np.random.randint(2 ** 31) |
||||
all_ints = all_gather(ints) |
||||
return all_ints[0] |
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True): |
||||
""" |
||||
Reduce the values in the dictionary from all processes so that process with rank |
||||
0 has the reduced results. |
||||
|
||||
Args: |
||||
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. |
||||
average (bool): whether to do average or sum |
||||
|
||||
Returns: |
||||
a dict with the same keys as input_dict, after reduction. |
||||
""" |
||||
world_size = get_world_size() |
||||
if world_size < 2: |
||||
return input_dict |
||||
with torch.no_grad(): |
||||
names = [] |
||||
values = [] |
||||
# sort the keys so that they are consistent across processes |
||||
for k in sorted(input_dict.keys()): |
||||
names.append(k) |
||||
values.append(input_dict[k]) |
||||
values = torch.stack(values, dim=0) |
||||
dist.reduce(values, dst=0) |
||||
if dist.get_rank() == 0 and average: |
||||
# only main process gets accumulated, so only divide by |
||||
# world_size in this case |
||||
values /= world_size |
||||
reduced_dict = {k: v for k, v in zip(names, values)} |
||||
return reduced_dict |
@ -0,0 +1,22 @@ |
||||
import numpy as np |
||||
|
||||
|
||||
# --- PL-DATAMODULE --- |
||||
|
||||
def get_local_split(items: list, world_size: int, rank: int, seed: int): |
||||
""" The local rank only loads a split of dataset. """ |
||||
n_items = len(items) |
||||
items_permute = np.random.RandomState(seed).permutation(items) |
||||
if n_items % world_size == 0: |
||||
padded_items = items_permute |
||||
else: |
||||
padding = np.random.RandomState(seed).choice(items, |
||||
world_size - (n_items % world_size), |
||||
replace=True) |
||||
padded_items = np.concatenate([items_permute, padding]) |
||||
assert len(padded_items) % world_size == 0, \ |
||||
f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' |
||||
n_per_rank = len(padded_items) // world_size |
||||
local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] |
||||
|
||||
return local_items |
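
# Sanity check (illustrative): 10 items over 4 ranks are padded to 12 and
# split into 3 per rank; all ranks see the same seeded permutation.
if __name__ == '__main__':
    splits = [get_local_split(list(range(10)), world_size=4, rank=r, seed=66)
              for r in range(4)]
    assert all(len(s) == 3 for s in splits)
    print(splits)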
@ -0,0 +1,125 @@ |
||||
import cv2 |
||||
import numpy as np |
||||
import h5py |
||||
import torch |
||||
|
||||
|
||||
# --- DATA IO --- |
||||
|
||||
def imread_gray(path, augment_fn=None): |
||||
if augment_fn is None: |
||||
image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) |
||||
else: |
||||
image = cv2.imread(str(path), cv2.IMREAD_COLOR) |
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||
image = augment_fn(image) |
||||
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) |
||||
return image # (h, w) |
||||
|
||||
|
||||
def get_resized_wh(w, h, resize=None): |
||||
if resize is not None: # resize the longer edge |
||||
scale = resize / max(h, w) |
||||
w_new, h_new = int(round(w*scale)), int(round(h*scale)) |
||||
else: |
||||
w_new, h_new = w, h |
||||
return w_new, h_new |
||||
|
||||
|
||||
def get_divisible_wh(w, h, df=None): |
||||
if df is not None: |
||||
w_new, h_new = map(lambda x: int(x // df * df), [w, h]) |
||||
else: |
||||
w_new, h_new = w, h |
||||
return w_new, h_new |
||||
|
||||
|
||||
def pad_bottom_right(inp, pad_size, ret_mask=False): |
||||
assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" |
||||
mask = None |
||||
if inp.ndim == 2: |
||||
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) |
||||
padded[:inp.shape[0], :inp.shape[1]] = inp |
||||
if ret_mask: |
||||
mask = np.zeros((pad_size, pad_size), dtype=bool) |
||||
mask[:inp.shape[0], :inp.shape[1]] = True |
||||
elif inp.ndim == 3: |
||||
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) |
||||
padded[:, :inp.shape[1], :inp.shape[2]] = inp |
||||
if ret_mask: |
||||
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) |
||||
mask[:, :inp.shape[1], :inp.shape[2]] = True |
||||
else: |
||||
raise NotImplementedError() |
||||
return padded, mask |
||||
|
||||
|
||||
# --- MEGADEPTH --- |
||||
|
||||
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): |
||||
""" |
||||
Args: |
||||
resize (int, optional): the longer edge of resized images. None for no resize. |
||||
padding (bool): If True, zero-pad the resized image to a square.
||||
augment_fn (callable, optional): augments images with pre-defined visual effects |
||||
Returns: |
||||
image (torch.tensor): (1, h, w) |
||||
mask (torch.tensor): (h, w) |
||||
scale (torch.tensor): [w/w_new, h/h_new] |
||||
""" |
||||
# read image |
||||
image = imread_gray(path, augment_fn) |
||||
|
||||
# resize image |
||||
w, h = image.shape[1], image.shape[0] |
||||
w_new, h_new = get_resized_wh(w, h, resize) |
||||
w_new, h_new = get_divisible_wh(w_new, h_new, df) |
||||
|
||||
image = cv2.resize(image, (w_new, h_new)) |
||||
scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) |
||||
|
||||
if padding: # padding |
||||
pad_to = max(h_new, w_new) |
||||
image, mask = pad_bottom_right(image, pad_to, ret_mask=True) |
||||
else: |
||||
mask = None |
||||
|
||||
image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized |
||||
mask = torch.from_numpy(mask) if mask is not None else None  # mask is None when padding is disabled
||||
|
||||
return image, mask, scale |
||||
|
||||
|
||||
def read_megadepth_depth(path, pad_to=None): |
||||
depth = np.array(h5py.File(path, 'r')['depth']) |
||||
if pad_to is not None: |
||||
depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) |
||||
depth = torch.from_numpy(depth).float() # (h, w) |
||||
return depth |
||||
|
||||
|
||||
# --- ScanNet --- |
||||
|
||||
def read_scannet_gray(path, resize=(640, 480), augment_fn=None): |
||||
""" |
||||
Args: |
||||
resize (tuple): align image to depthmap, in (w, h). |
||||
augment_fn (callable, optional): augments images with pre-defined visual effects |
||||
Returns: |
||||
image (torch.tensor): (1, h, w) |
||||
mask (torch.tensor): (h, w) |
||||
scale (torch.tensor): [w/w_new, h/h_new] |
||||
""" |
||||
# read and resize image |
||||
image = imread_gray(path, augment_fn) |
||||
image = cv2.resize(image, resize) |
||||
|
||||
# (h, w) -> (1, h, w) and normalized |
||||
image = torch.from_numpy(image).float()[None] / 255 |
||||
return image |
||||
|
||||
|
||||
def read_scannet_depth(path): |
||||
depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) / 1000 |
||||
depth = torch.from_numpy(depth).float() # (h, w) |
||||
return depth |
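
# Usage sketch (paths are hypothetical):
#
#   image, mask, scale = read_megadepth_gray('0001.jpg', resize=840, df=8, padding=True)
#   # image: (1, 840, 840) after square padding; scale maps back to the original resolution
#   depth = read_scannet_depth('depth/000000.png')  # (480, 640), in metres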
@ -0,0 +1,193 @@ |
||||
import torch |
||||
import cv2 |
||||
import numpy as np |
||||
from collections import OrderedDict |
||||
from loguru import logger |
||||
from kornia.geometry.epipolar import numeric |
||||
from kornia.geometry.conversions import convert_points_to_homogeneous |
||||
|
||||
|
||||
# --- METRICS --- |
||||
|
||||
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): |
||||
# angle error between 2 vectors |
||||
t_gt = T_0to1[:3, 3] |
||||
n = np.linalg.norm(t) * np.linalg.norm(t_gt) |
||||
t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) |
||||
t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity |
||||
if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging |
||||
t_err = 0 |
||||
|
||||
# angle error between 2 rotation matrices |
||||
R_gt = T_0to1[:3, :3] |
||||
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 |
||||
cos = np.clip(cos, -1., 1.)  # handle numerical errors
||||
R_err = np.rad2deg(np.abs(np.arccos(cos))) |
||||
|
||||
return t_err, R_err |
||||
|
||||
|
||||
def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): |
||||
"""Squared symmetric epipolar distance. |
||||
This can be seen as a biased estimation of the reprojection error. |
||||
Args: |
||||
pts0 (torch.Tensor): [N, 2] |
||||
E (torch.Tensor): [3, 3] |
||||
""" |
||||
pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] |
||||
pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] |
||||
pts0 = convert_points_to_homogeneous(pts0) |
||||
pts1 = convert_points_to_homogeneous(pts1) |
||||
|
||||
Ep0 = pts0 @ E.T # [N, 3] |
||||
p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] |
||||
Etp1 = pts1 @ E # [N, 3] |
||||
|
||||
d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N |
||||
return d |
||||
|
||||
|
||||
def compute_symmetrical_epipolar_errors(data): |
||||
""" |
||||
Update: |
||||
data (dict):{"epi_errs": [M]} |
||||
""" |
||||
Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) |
||||
E_mat = Tx @ data['T_0to1'][:, :3, :3] |
||||
|
||||
m_bids = data['m_bids'] |
||||
pts0 = data['mkpts0_f'] |
||||
pts1 = data['mkpts1_f'] |
||||
|
||||
epi_errs = [] |
||||
for bs in range(Tx.size(0)): |
||||
mask = m_bids == bs |
||||
epi_errs.append( |
||||
symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) |
||||
epi_errs = torch.cat(epi_errs, dim=0) |
||||
|
||||
data.update({'epi_errs': epi_errs}) |
||||
|
||||
|
||||
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): |
||||
if len(kpts0) < 5: |
||||
return None |
||||
# normalize keypoints |
||||
kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] |
||||
kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] |
||||
|
||||
# normalize ransac threshold |
||||
ransac_thr = thresh / np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])  # mean of the four focal lengths (fx, fy of both cameras)
||||
|
||||
# compute pose with cv2 |
||||
E, mask = cv2.findEssentialMat( |
||||
kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) |
||||
if E is None: |
||||
print("\nE is None while trying to recover pose.\n") |
||||
return None |
||||
|
||||
# recover pose from E |
||||
best_num_inliers = 0 |
||||
ret = None |
||||
for _E in np.split(E, len(E) // 3):  # E may stack multiple 3x3 candidates
||||
n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) |
||||
if n > best_num_inliers: |
||||
ret = (R, t[:, 0], mask.ravel() > 0) |
||||
best_num_inliers = n |
||||
|
||||
return ret |
||||
|
||||
|
||||
def compute_pose_errors(data, config): |
||||
""" |
||||
Update: |
||||
data (dict):{ |
||||
"R_errs" List[float]: [N] |
||||
"t_errs" List[float]: [N] |
||||
"inliers" List[np.ndarray]: [N] |
||||
} |
||||
""" |
||||
pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 |
||||
conf = config.TRAINER.RANSAC_CONF # 0.99999 |
||||
data.update({'R_errs': [], 't_errs': [], 'inliers': []}) |
||||
|
||||
m_bids = data['m_bids'].cpu().numpy() |
||||
pts0 = data['mkpts0_f'].cpu().numpy() |
||||
pts1 = data['mkpts1_f'].cpu().numpy() |
||||
K0 = data['K0'].cpu().numpy() |
||||
K1 = data['K1'].cpu().numpy() |
||||
T_0to1 = data['T_0to1'].cpu().numpy() |
||||
|
||||
for bs in range(K0.shape[0]): |
||||
mask = m_bids == bs |
||||
ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) |
||||
|
||||
if ret is None: |
||||
data['R_errs'].append(np.inf) |
||||
data['t_errs'].append(np.inf) |
||||
data['inliers'].append(np.array([]).astype(bool))  # np.bool is deprecated
||||
else: |
||||
R, t, inliers = ret |
||||
t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) |
||||
data['R_errs'].append(R_err) |
||||
data['t_errs'].append(t_err) |
||||
data['inliers'].append(inliers) |
||||
|
||||
|
||||
# --- METRIC AGGREGATION --- |
||||
|
||||
def error_auc(errors, thresholds): |
||||
""" |
||||
Args: |
||||
errors (list): [N,] |
||||
thresholds (list) |
||||
""" |
||||
errors = [0] + sorted(list(errors)) |
||||
recall = list(np.linspace(0, 1, len(errors))) |
||||
|
||||
aucs = []
||||
for thr in thresholds: |
||||
last_index = np.searchsorted(errors, thr) |
||||
y = recall[:last_index] + [recall[last_index-1]] |
||||
x = errors[:last_index] + [thr] |
||||
aucs.append(np.trapz(y, x) / thr) |
||||
|
||||
return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} |
||||
|
||||
|
||||
def epidist_prec(errors, thresholds, ret_dict=False): |
||||
precs = [] |
||||
for thr in thresholds: |
||||
prec_ = [] |
||||
for errs in errors: |
||||
correct_mask = errs < thr |
||||
prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) |
||||
precs.append(np.mean(prec_) if len(prec_) > 0 else 0) |
||||
if ret_dict: |
||||
return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} |
||||
else: |
||||
return precs |
||||
|
||||
|
||||
def aggregate_metrics(metrics, epi_err_thr=5e-4): |
||||
""" Aggregate metrics for the whole dataset: |
||||
(This method should be called once per dataset) |
||||
1. AUC of the pose error (angular) at the threshold [5, 10, 20] |
||||
2. Mean matching precision at the threshold 5e-4 (ScanNet), 1e-4 (MegaDepth)
||||
""" |
||||
# filter duplicates |
||||
unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) |
||||
unq_ids = list(unq_ids.values()) |
||||
logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') |
||||
|
||||
# pose auc |
||||
angular_thresholds = [5, 10, 20] |
||||
pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] |
||||
aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) |
||||
|
||||
# matching precision |
||||
dist_thresholds = [epi_err_thr] |
||||
precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) |
||||
|
||||
return {**aucs, **precs} |
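
# Toy check of the AUC aggregation (hedged: random errors in degrees):
if __name__ == '__main__':
    errs = np.random.rand(100) * 30
    print(error_auc(errs, thresholds=[5, 10, 20]))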
@ -0,0 +1,41 @@ |
||||
from loguru import logger |
||||
from yacs.config import CfgNode as CN |
||||
from itertools import chain |
||||
|
||||
|
||||
def lower_config(yacs_cfg): |
||||
if not isinstance(yacs_cfg, CN): |
||||
return yacs_cfg |
||||
return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} |
||||
|
||||
|
||||
def upper_config(dict_cfg): |
||||
if not isinstance(dict_cfg, dict): |
||||
return dict_cfg |
||||
return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} |
||||
|
||||
|
||||
def log_on(condition, message, level): |
||||
if condition: |
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] |
||||
logger.log(level, message) |
||||
|
||||
|
||||
def flattenList(x): |
||||
return list(chain(*x)) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
_CN = CN() |
||||
_CN.A = CN() |
||||
_CN.A.AA = CN() |
||||
_CN.A.AA.AAA = CN() |
||||
_CN.A.AA.AAA.AAAA = "AAAAA" |
||||
|
||||
_CN.B = CN() |
||||
_CN.B.BB = CN() |
||||
_CN.B.BB.BBB = CN() |
||||
_CN.B.BB.BBB.BBBB = "BBBBB" |
||||
|
||||
print(lower_config(_CN)) |
||||
print(lower_config(_CN.A)) |
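
# flattenList example: [[1, 2], [3]] -> [1, 2, 3]
print(flattenList([[1, 2], [3]]))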
@ -0,0 +1,50 @@ |
||||
import numpy as np |
||||
import matplotlib.pyplot as plt |
||||
import matplotlib |
||||
|
||||
|
||||
# --- VISUALIZATION --- |
||||
|
||||
def make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=[], path=None): |
||||
# draw image pair |
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=75) |
||||
axes[0].imshow(img0, cmap='gray') |
||||
axes[1].imshow(img1, cmap='gray') |
||||
for i in range(2): # clear all frames |
||||
axes[i].get_yaxis().set_ticks([]) |
||||
axes[i].get_xaxis().set_ticks([]) |
||||
for spine in axes[i].spines.values(): |
||||
spine.set_visible(False) |
||||
plt.tight_layout(pad=1) |
||||
|
||||
# draw matches |
||||
fig.canvas.draw() |
||||
transFigure = fig.transFigure.inverted() |
||||
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) |
||||
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) |
||||
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), |
||||
transform=fig.transFigure, c=color[i], linewidth=1) for i in range(len(mkpts0))] |
||||
|
||||
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) |
||||
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) |
||||
|
||||
# put txts |
||||
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' |
||||
fig.text( |
||||
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, |
||||
fontsize=15, va='top', ha='left', color=txt_color) |
||||
plt.tight_layout(pad=1) |
||||
|
||||
# save or return figure |
||||
if path: |
||||
plt.savefig(str(path), bbox_inches='tight', pad_inches=0) |
||||
plt.close() |
||||
else: |
||||
return fig |
||||
|
||||
|
||||
def error_colormap(err, thr, alpha=1.0): |
||||
assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
||||
x = 1 - np.clip(err / (thr * 2), 0, 1) |
||||
return np.clip( |
||||
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) |
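
# Usage sketch (hedged): colour matches by epipolar error, then draw the pair.
#
#   color = error_colormap(epi_errs, thr=5e-4, alpha=0.6)  # green -> red with growing error
#   fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color,
#                              text=['LoFTR', f'#matches: {len(mkpts0)}'])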
@ -0,0 +1,40 @@ |
||||
import torch |
||||
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler |
||||
from contextlib import contextmanager |
||||
from pytorch_lightning.utilities import rank_zero_only |
||||
|
||||
|
||||
class InferenceProfiler(SimpleProfiler): |
||||
""" |
||||
This profiler records duration of actions with cuda.synchronize() |
||||
Use this in test time. |
||||
""" |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
self.start = rank_zero_only(self.start) |
||||
self.stop = rank_zero_only(self.stop) |
||||
self.summary = rank_zero_only(self.summary) |
||||
|
||||
@contextmanager |
||||
def profile(self, action_name: str) -> None: |
||||
try: |
||||
torch.cuda.synchronize() |
||||
self.start(action_name) |
||||
yield action_name |
||||
finally: |
||||
torch.cuda.synchronize() |
||||
self.stop(action_name) |
||||
|
||||
|
||||
def build_profiler(name): |
||||
if name == 'inference': |
||||
return InferenceProfiler() |
||||
elif name == 'pytorch': |
||||
from pytorch_lightning.profiler import PyTorchProfiler |
||||
# TODO: this profiler will be introduced after upgrading pl dependency to 1.3.0 @zehong |
||||
return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) |
||||
elif name is None: |
||||
return PassThroughProfiler() |
||||
else: |
||||
raise ValueError(f'Invalid profiler: {name}') |
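
# Usage sketch: wrap timed sections with profile(); InferenceProfiler adds a
# cuda.synchronize() around start/stop so GPU time is attributed correctly.
#
#   profiler = build_profiler('inference')
#   with profiler.profile('coarse matching'):
#       ...  # run the timed step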
@ -0,0 +1,68 @@ |
||||
import pytorch_lightning as pl |
||||
import argparse |
||||
import pprint |
||||
from loguru import logger as loguru_logger |
||||
|
||||
from src.config.default import get_cfg_defaults |
||||
from src.utils.profiler import build_profiler |
||||
|
||||
from src.lightning.data import MultiSceneDataModule |
||||
from src.lightning.lightning_loftr import PL_LoFTR |
||||
|
||||
|
||||
def parse_args(): |
||||
# init a custom parser which will be merged into the pl.Trainer parser
||||
# check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags |
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
||||
parser.add_argument( |
||||
'data_cfg_path', type=str, help='data config path') |
||||
parser.add_argument( |
||||
'main_cfg_path', type=str, help='main config path') |
||||
parser.add_argument( |
||||
'--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') |
||||
parser.add_argument( |
||||
'--dump_dir', type=str, default=None, help="if set, the matching results will be dumped to dump_dir")
||||
parser.add_argument( |
||||
'--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') |
||||
parser.add_argument( |
||||
'--batch_size', type=int, default=1, help='batch_size per gpu') |
||||
parser.add_argument( |
||||
'--num_workers', type=int, default=2) |
||||
parser.add_argument( |
||||
'--thr', type=float, default=None, help='modify the coarse-level matching threshold.') |
||||
|
||||
parser = pl.Trainer.add_argparse_args(parser) |
||||
return parser.parse_args() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
# parse arguments |
||||
args = parse_args() |
||||
pprint.pprint(vars(args)) |
||||
|
||||
# init default-cfg and merge it with the main- and data-cfg |
||||
config = get_cfg_defaults() |
||||
config.merge_from_file(args.main_cfg_path) |
||||
config.merge_from_file(args.data_cfg_path) |
||||
pl.seed_everything(config.TRAINER.SEED) # reproducibility |
||||
|
||||
# tune when testing |
||||
if args.thr is not None: |
||||
config.LOFTR.MATCH_COARSE.THR = args.thr |
||||
|
||||
loguru_logger.info(f"Args and config initialized!") |
||||
|
||||
# lightning module |
||||
profiler = build_profiler(args.profiler_name) |
||||
model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) |
||||
loguru_logger.info(f"LoFTR-lightning initialized!") |
||||
|
||||
# lightning data |
||||
data_module = MultiSceneDataModule(args, config) |
||||
loguru_logger.info(f"DataModule initialized!") |
||||
|
||||
# lightning trainer |
||||
trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) |
||||
|
||||
loguru_logger.info(f"Start testing!") |
||||
trainer.test(model, datamodule=data_module, verbose=False) |
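
# Example invocation (config paths are illustrative; pl.Trainer flags such as
# --gpus are parsed via add_argparse_args above):
#   python test.py <data_cfg_path> <main_cfg_path> \
#       --ckpt_path weights/indoor_ds.ckpt --profiler_name inference --gpus 1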