Add LoFTR training.

YuAng authored and committed 3 years ago
parent 348ad897d4
commit 4feac496c1
54 changed files (number of changed lines shown first; BIN = binary):

10    .gitignore
3     .gitmodules
6     README.md
6     configs/data/base.py
3     configs/data/debug/.gitignore
22    configs/data/megadepth_trainval_640.py
22    configs/data/megadepth_trainval_840.py
17    configs/data/scannet_trainval.py
3     configs/loftr/indoor/debug/.gitignore
2     configs/loftr/indoor/loftr_ds.py
7     configs/loftr/indoor/loftr_ds_dense.py
2     configs/loftr/indoor/loftr_ot.py
7     configs/loftr/indoor/loftr_ot_dense.py
3     configs/loftr/outdoor/debug/.gitignore
15    configs/loftr/outdoor/loftr_ds.py
16    configs/loftr/outdoor/loftr_ds_dense.py
15    configs/loftr/outdoor/loftr_ot.py
16    configs/loftr/outdoor/loftr_ot_dense.py
4     data/megadepth/index/.gitignore
4     data/megadepth/test/.gitignore
4     data/megadepth/train/.gitignore
3     data/scannet/index/.gitignore
BIN   data/scannet/intrinsics.npz
1     data/scannet/test
1     data/scannet/train
73    docs/TRAINING.md
3     environment.yaml
197   notebooks/demo_single_pair.ipynb
2     requirements.txt
3     scripts/reproduce_train/debug/.gitignore
33    scripts/reproduce_train/indoor_ds.sh
33    scripts/reproduce_train/indoor_ot.sh
35    scripts/reproduce_train/outdoor_ds.sh
35    scripts/reproduce_train/outdoor_ot.sh
55    src/config/default.py
11    src/datasets/megadepth.py
77    src/datasets/sampler.py
46    src/datasets/scannet.py
209   src/lightning/data.py
174   src/lightning/lightning_loftr.py
123   src/loftr/utils/coarse_matching.py
3     src/loftr/utils/fine_matching.py
54    src/loftr/utils/geometry.py
151   src/loftr/utils/supervision.py
192   src/losses/loftr_loss.py
42    src/optimizers/__init__.py
2     src/utils/augment.py
5     src/utils/dataloader.py
66    src/utils/dataset.py
88    src/utils/misc.py
122   src/utils/plotting.py
1     src/utils/profiler.py
1     third_party/SuperGluePretrainedNetwork
120   train.py

.gitignore

@@ -14,3 +14,13 @@ demo/*.mp4
demo/demo_images/
src/loftr/utils/superglue.py
demo/utils.py
notebooks/QccDayNight.ipynb
notebooks/westlake.ipynb
assets/westlake
assets/qcc_pairs.txt
configs/.petrel*
tools/draw_QccDayNights.py
scripts/slurm/
scripts/sbatch_submit.sh

.gitmodules

@@ -0,0 +1,3 @@
[submodule "third_party/SuperGluePretrainedNetwork"]
path = third_party/SuperGluePretrainedNetwork
url = git@github.com:magicleap/SuperGluePretrainedNetwork.git

README.md

@@ -12,7 +12,7 @@
- [x] Inference code and pretrained models (DS and OT) (2021-4-7)
- [x] Code for reproducing the test-set results (2021-4-7)
- [x] Webcam demo to reproduce the result shown in the GIF above (2021-4-13)
- - [ ] Training code and training data preparation (expected 2021-6-10)
+ - [x] Training code and training data preparation (expected 2021-6-10)
The entire codebase for data pre-processing, training and validation is under major refactoring and will be released around June.
Please subscribe to [this discussion thread](https://github.com/zju3dv/LoFTR/discussions/2) if you wish to be notified of the code release.
@@ -177,6 +177,10 @@ Out[19]: 1684276
`data['score']` is the overlapping score defined in [SuperGlue](https://arxiv.org/pdf/1911.11763) (Page 12).
</details>
## Training
See [Training LoFTR](./docs/TRAINING.md) for more details.
## Citation
If you find this code useful for your research, please use the following BibTeX entry.

configs/data/base.py

@@ -10,22 +10,26 @@ _CN.TRAINER = CN()
# training data config
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_POSE_ROOT = None
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
# validation set config
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_POSE_ROOT = None
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None
_CN.DATASET.VAL_INTRINSIC_PATH = None
# testing data config
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_POSE_ROOT = None
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None
_CN.DATASET.TEST_INTRINSIC_PATH = None
# dataset config
- _CN.DATASET.MIN_OVERLAP_SCORE = 0.4
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0  # for both test and val
cfg = _CN

configs/data/debug/.gitignore

@@ -0,0 +1,3 @@
*
*/
!.gitignore

configs/data/megadepth_trainval_640.py

@@ -0,0 +1,22 @@
from configs.data.base import cfg
TRAIN_BASE_PATH = "data/megadepth/index"
cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
TEST_BASE_PATH = "data/megadepth/index"
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
# 368 scenes in total for MegaDepth
# (with difficulty balanced (further split each scene to 3 sub-scenes))
cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
cfg.DATASET.MGDPT_IMG_RESIZE = 640 # for training on 11GB mem GPUs

configs/data/megadepth_trainval_840.py

@@ -0,0 +1,22 @@
from configs.data.base import cfg
TRAIN_BASE_PATH = "data/megadepth/index"
cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
TEST_BASE_PATH = "data/megadepth/index"
cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
# 368 scenes in total for MegaDepth
# (with difficulty balanced (further split each scene to 3 sub-scenes))
cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
cfg.DATASET.MGDPT_IMG_RESIZE = 840  # for training on 32GB mem GPUs

configs/data/scannet_trainval.py

@@ -0,0 +1,17 @@
from configs.data.base import cfg
TRAIN_BASE_PATH = "data/scannet/index"
cfg.DATASET.TRAINVAL_DATA_SOURCE = "ScanNet"
cfg.DATASET.TRAIN_DATA_ROOT = "data/scannet/train"
cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_data/train"
cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/scene_data/train_list/scannet_all.txt"
cfg.DATASET.TRAIN_INTRINSIC_PATH = f"{TRAIN_BASE_PATH}/intrinsics.npz"
TEST_BASE_PATH = "assets/scannet_test_1500"
cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val

configs/loftr/indoor/debug/.gitignore

@@ -0,0 +1,3 @@
*
*/
!.gitignore

configs/loftr/indoor/loftr_ds.py

@@ -1,3 +1,5 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]

configs/loftr/indoor/loftr_ds_dense.py

@@ -0,0 +1,7 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]

configs/loftr/indoor/loftr_ot.py

@@ -1,3 +1,5 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]

configs/loftr/indoor/loftr_ot_dense.py

@@ -0,0 +1,7 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]

configs/loftr/outdoor/debug/.gitignore

@@ -0,0 +1,3 @@
*
*/
!.gitignore

configs/loftr/outdoor/loftr_ds.py

@@ -0,0 +1,15 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
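The `WARMUP_STEP`/`WARMUP_RATIO` options above define a step-based linear ramp toward the full learning rate; the ramp itself is implemented in `PL_LoFTR.optimizer_step` (see the `src/lightning/lightning_loftr.py` diff below). A minimal sketch of the resulting schedule, where `true_lr` stands in for the runtime-computed `TRAINER.TRUE_LR` (an assumption; the canonical 8e-3 is reused here purely for illustration):

```python
# Sketch of the linear warm-up from PL_LoFTR.optimizer_step (see the lightning diff below).
# `true_lr` stands in for _CN.TRAINER.TRUE_LR, which is derived at runtime.
def warmup_lr(global_step, true_lr=8e-3, warmup_step=1875, warmup_ratio=0.1):
    if global_step >= warmup_step:
        return true_lr  # the MultiStepLR scheduler takes over from here
    base_lr = warmup_ratio * true_lr
    return base_lr + (global_step / warmup_step) * abs(true_lr - base_lr)

print(warmup_lr(0))     # 0.0008 -- starts at 10% of the target lr
print(warmup_lr(1875))  # 0.008  -- full lr after ~3 epochs
```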

configs/loftr/outdoor/loftr_ds_dense.py

@@ -0,0 +1,16 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3

configs/loftr/outdoor/loftr_ot.py

@@ -0,0 +1,15 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3

configs/loftr/outdoor/loftr_ot_dense.py

@@ -0,0 +1,16 @@
from src.config.default import _CN as cfg
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1
cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3

data/megadepth/index/.gitignore

@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

data/megadepth/test/.gitignore

@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

data/megadepth/train/.gitignore

@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

data/scannet/index/.gitignore

@@ -0,0 +1,3 @@
*
*/
!.gitignore

data/scannet/intrinsics.npz

Binary file not shown.

data/scannet/test

@@ -0,0 +1 @@
/mnt/lustre/share/3dv/dataset/scannet/scannet_1500_testset

data/scannet/train

@@ -0,0 +1 @@
/mnt/lustre/share/3dv/dataset/scannet/out/output

docs/TRAINING.md

@@ -0,0 +1,73 @@
# Training LoFTR
## Dataset setup
Generally, two parts of data are needed for training LoFTR: the original datasets, i.e., ScanNet and MegaDepth, and the offline-generated dataset indices. The dataset indices store the scenes, image pairs, and other metadata within each dataset used for training/validation/testing. For the MegaDepth dataset, the relative poses between images used for training are cached directly in the indexing files. However, the relative poses of ScanNet image pairs are not stored, due to the enormous resulting file size.
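To see what one index file stores, it can be opened directly with numpy; the path and key names below are illustrative assumptions inferred from how `MegaDepthDataset` consumes `npz_path` (see the `src/datasets/megadepth.py` diff further down), not a documented schema:

```python
# Peek inside a dataset-index npz (the file name and keys are illustrative only).
import numpy as np

scene_info = np.load("data/megadepth/index/scene_info_0.1_0.7/some_scene.npz",
                     allow_pickle=True)
print(list(scene_info.keys()))  # expect image paths, pair infos, and other per-scene metadata
```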
**Download the dataset indices**
You can download the required dataset indices from the [following link](https://drive.google.com/drive/folders/1DOcOPZb3-5cWxLqn256AhwUVjBPifhuf).
After downloading, unzip the required files.
```shell
unzip downloaded-file.zip
# extract dataset indices
tar xf train-data/megadepth_indices.tar
tar xf train-data/scannet_indices.tar
# extract testing data (optional)
tar xf testdata/megadepth_test_1500.tar
tar xf testdata/scannet_test_1500.tar
```
**Build the dataset symlinks**
We symlink the datasets to the `data` directory under the main LoFTR project directory.
> NOTE: For the ScanNet dataset, we use the [python exported data](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python),
> instead of the [c++ exported one](https://github.com/ScanNet/ScanNet/tree/master/SensReader/c%2B%2B).
```shell
# scannet
# -- # train and test dataset
ln -s /path/to/scannet_train/* /path/to/LoFTR/data/scannet/train
ln -s /path/to/scannet_test/* /path/to/LoFTR/data/scannet/test
# -- # dataset indices
ln -s /path/to/scannet_indices/* /path/to/LoFTR/data/scannet/index
# megadepth
# -- # train and test dataset (train and test share the same dataset)
ln -s /path/to/megadepth/Undistorted_SfM/* /path/to/LoFTR/data/megadepth/train
ln -s /path/to/megadepth/Undistorted_SfM/* /path/to/LoFTR/data/megadepth/test
# -- # dataset indices
ln -s /path/to/megadepth_indices/* /path/to/LoFTR/data/megadepth/index
```
## Training
We provide training scripts for ScanNet and MegaDepth. The results in the LoFTR paper can be reproduced with 32/64 GPUs with at least 11GB of RAM for ScanNet, and 8/16 GPUs with at least 24GB of RAM for MegaDepth. For a different setup (e.g., training with 4 GPUs on ScanNet), we scale the learning rate and its warm-up linearly, but the final evaluation results might vary due to the different batch size and learning rate used; thus, the reproduction of the results in our paper is not guaranteed.
Training scripts for the optimal-transport matcher end with `_ot`, and those for the dual-softmax matcher end with `_ds`.
The released training scripts use smaller setups compared to the ones used for training the released models. You can manually scale the setup (e.g., using 32 GPUs instead of 4) to reproduce our results; the linear scaling rule is sketched below.
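The linear scaling follows the canonical settings in `src/config/default.py` (`CANONICAL_BS = 64`, `CANONICAL_LR = 6e-3`; `TRUE_LR` is noted there as calculated automatically). A minimal sketch, assuming the effective batch size is `n_gpus * n_nodes * batch_size`:

```python
# Linear lr scaling sketch (cf. _CN.TRAINER.CANONICAL_BS / CANONICAL_LR in src/config/default.py).
# Presumably this is how TRAINER.TRUE_LR is derived at runtime.
def scaled_lr(n_gpus, n_nodes, batch_size, canonical_bs=64, canonical_lr=6e-3):
    effective_bs = n_gpus * n_nodes * batch_size
    return canonical_lr * effective_bs / canonical_bs

print(scaled_lr(4, 1, 1))   # 0.000375 for the released 4-GPU scripts
print(scaled_lr(64, 1, 1))  # 0.006 at the canonical 64-GPU setup
```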
### Training on ScanNet
```shell
scripts/reproduce_train/indoor_ds.sh
```
> NOTE: This script uses 4 GPUs only. Reproduction of the paper results is not guaranteed under this setup.
### Training on MegaDepth
```shell
scripts/reproduce_train/outdoor_ds.sh
```
> NOTE: This script uses 4 GPUs only, with a smaller image size of 640x640. Reproduction of the paper results is not guaranteed under this setup.
## Updated Training Strategy
In the released training code, we use a slightly modified version of the coarse-level training supervision compared to the one described in our paper.
For example, as described in the paper, we only supervise the ground-truth positive matches when training the dual-softmax model. In the released code, however, the entire confidence matrix produced by the dual-softmax matcher is supervised by default, regardless of the use of the softmax operators. This implementation is counter-intuitive and unusual, but leads to better evaluation results on estimating relative camera poses. The same phenomenon applies to the optimal-transport matcher as well. Note that we don't supervise the dustbin rows and columns under the dense supervision setup.
> NOTE: To use the sparse supervision described in our paper, set `_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True` (the released `*_dense` configs set it to `False`); a minimal sketch follows.
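Concretely, the dense configs released with this commit (`loftr_*_dense.py`) set the flag explicitly; a minimal sketch of toggling it in a main config:

```python
# Sketch: choosing the coarse-level supervision mode in a main config.
from src.config.default import _CN as cfg

cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True    # sparse supervision as described in the paper
# cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False  # dense supervision, used by the released scripts
```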

environment.yaml

@@ -7,8 +7,7 @@ channels:
dependencies:
  - python=3.8
  - cudatoolkit=10.2
-   - pytorch=1.8.0
+   - pytorch=1.8.1
-   - pytorch-lightning<=1.1.8  # https://github.com/PyTorchLightning/pytorch-lightning/issues/6318
  - pip
  - pip:
    - -r file:requirements.txt

notebooks/demo_single_pair.ipynb

File diff suppressed because one or more lines are too long

requirements.txt

@@ -12,3 +12,5 @@ ipython
jupyterlab
matplotlib
h5py==3.1.0
pytorch-lightning==1.3.5
joblib>=1.0.1

scripts/reproduce_train/debug/.gitignore

@@ -0,0 +1,3 @@
*
*/
!.gitignore

scripts/reproduce_train/indoor_ds.sh

@@ -0,0 +1,33 @@
#!/bin/bash -l
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"
# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR
data_cfg_path="configs/data/scannet_trainval.py"
main_cfg_path="configs/loftr/indoor/loftr_ds_dense.py"
n_nodes=1
n_gpus_per_node=4
torch_num_workers=4
batch_size=1
pin_memory=true
exp_name="indoor-ds-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
python -u ./train.py \
${data_cfg_path} \
${main_cfg_path} \
--exp_name=${exp_name} \
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
--check_val_every_n_epoch=1 \
--log_every_n_steps=100 \
--flush_logs_every_n_steps=100 \
--limit_val_batches=1. \
--num_sanity_val_steps=10 \
--benchmark=True \
--max_epochs=30 \
--parallel_load_data

scripts/reproduce_train/indoor_ot.sh

@@ -0,0 +1,33 @@
#!/bin/bash -l
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"
# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR
data_cfg_path="configs/data/scannet_trainval.py"
main_cfg_path="configs/loftr/indoor/loftr_ot_dense.py"
n_nodes=1
n_gpus_per_node=4
torch_num_workers=4
batch_size=1
pin_memory=true
exp_name="indoor-ot-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
python -u ./train.py \
${data_cfg_path} \
${main_cfg_path} \
--exp_name=${exp_name} \
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
--check_val_every_n_epoch=1 \
--log_every_n_steps=100 \
--flush_logs_every_n_steps=100 \
--limit_val_batches=1. \
--num_sanity_val_steps=10 \
--benchmark=True \
--max_epochs=30 \
--parallel_load_data

scripts/reproduce_train/outdoor_ds.sh

@@ -0,0 +1,35 @@
#!/bin/bash -l
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"
# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR
TRAIN_IMG_SIZE=640
# to reproduce the results in our paper, please use:
# TRAIN_IMG_SIZE=840
data_cfg_path="configs/data/megadepth_trainval_${TRAIN_IMG_SIZE}.py"
main_cfg_path="configs/loftr/outdoor/loftr_ds_dense.py"
n_nodes=1
n_gpus_per_node=4
torch_num_workers=4
batch_size=1
pin_memory=true
exp_name="outdoor-ds-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
python -u ./train.py \
${data_cfg_path} \
${main_cfg_path} \
--exp_name=${exp_name} \
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
--check_val_every_n_epoch=1 \
--log_every_n_steps=1 \
--flush_logs_every_n_steps=1 \
--limit_val_batches=1. \
--num_sanity_val_steps=10 \
--benchmark=True \
--max_epochs=30

scripts/reproduce_train/outdoor_ot.sh

@@ -0,0 +1,35 @@
#!/bin/bash -l
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"
# conda activate loftr
export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
cd $PROJECT_DIR
TRAIN_IMG_SIZE=640
# to reproduce the results in our paper, please use:
# TRAIN_IMG_SIZE=840
data_cfg_path="configs/data/megadepth_trainval_${TRAIN_IMG_SIZE}.py"
main_cfg_path="configs/loftr/outdoor/loftr_ot_dense.py"
n_nodes=1
n_gpus_per_node=4
torch_num_workers=4
batch_size=1
pin_memory=true
exp_name="outdoor-ot-${TRAIN_IMG_SIZE}-bs=$(($n_gpus_per_node * $n_nodes * $batch_size))"
python -u ./train.py \
${data_cfg_path} \
${main_cfg_path} \
--exp_name=${exp_name} \
--gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
--batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
--check_val_every_n_epoch=1 \
--log_every_n_steps=1 \
--flush_logs_every_n_steps=1 \
--limit_val_batches=1. \
--num_sanity_val_steps=10 \
--benchmark=True \
--max_epochs=30

src/config/default.py

@@ -30,8 +30,9 @@ _CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
- _CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4  # training tricks: save GPU memory
+ _CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2  # training tricks: save GPU memory
_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200  # training tricks: avoid DDP deadlock
_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True

# 4. LoFTR-fine module config
_CN.LOFTR.FINE = CN()
@@ -41,6 +42,25 @@ _CN.LOFTR.FINE.NHEAD = 8
_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
_CN.LOFTR.FINE.ATTENTION = 'linear'
# 5. LoFTR Losses
# -- # coarse-level
_CN.LOFTR.LOSS = CN()
_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
# _CN.LOFTR.LOSS.SPARSE_SPVS = False
# -- - -- # focal loss (coarse)
_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
_CN.LOFTR.LOSS.POS_WEIGHT = 1.0
_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not.
# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE`
# -- # fine-level
_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
############## Dataset ##############
_CN.DATASET = CN()
@@ -48,23 +68,27 @@ _CN.DATASET = CN()
# training and validating
_CN.DATASET.TRAINVAL_DATA_SOURCE = None  # options: ['ScanNet', 'MegaDepth']
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None  # None if val data from all scenes are bundled into a single npz file
_CN.DATASET.VAL_INTRINSIC_PATH = None
# testing
_CN.DATASET.TEST_DATA_SOURCE = None
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None  # None if test data from all scenes are bundled into a single npz file
_CN.DATASET.TEST_INTRINSIC_PATH = None

# 2. dataset config
# general options
- _CN.DATASET.MIN_OVERLAP_SCORE = 0.4  # discard data with overlap_score < min_overlap_score
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4  # discard data with overlap_score < min_overlap_score
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
_CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']

# MegaDepth options
@@ -75,10 +99,35 @@ _CN.DATASET.MGDPT_DF = 8
############## Trainer ##############
_CN.TRAINER = CN()
_CN.TRAINER.CANONICAL_BS = 64
_CN.TRAINER.CANONICAL_LR = 6e-3
_CN.TRAINER.SCALING = None # this will be calculated automatically
_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
# optimizer
_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
_CN.TRAINER.ADAMW_DECAY = 0.1
# step-based warm-up
_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
_CN.TRAINER.WARMUP_RATIO = 0.
_CN.TRAINER.WARMUP_STEP = 4800
# learning rate scheduler
_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
_CN.TRAINER.MSLR_GAMMA = 0.5
_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
# plotting related
_CN.TRAINER.ENABLE_PLOTTING = True
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting
_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
# geometric metrics and pose solver
_CN.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
@@ -108,7 +157,7 @@ _CN.TRAINER.GRADIENT_CLIPPING = 0.5
# to be the same. When resuming training from a checkpoint, it's better to use a different
# seed, otherwise the sampled data will be exactly the same as before resuming, which will
# cause less unique data items sampled during the entire training.
- # Use of a different seed value might affect the final training result, since not all data items
+ # Use of different seed values might affect the final training result, since not all data items
# are used during training on ScanNet. (60M pairs of images sampled during training from 230M pairs in total.)
_CN.TRAINER.SEED = 66
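For reference, the `MultiStepLR` settings above halve the learning rate at each milestone epoch (`MSLR_GAMMA = 0.5`); a minimal sketch with the stock PyTorch scheduler:

```python
# Sketch: behavior of the MultiStepLR settings above (milestones [3, 6, 9, 12], gamma 0.5).
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=6e-3)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[3, 6, 9, 12], gamma=0.5)
for epoch in range(13):
    opt.step()
    sched.step()  # lr: 6e-3 (epochs 0-2), 3e-3 (3-5), 1.5e-3 (6-8), 7.5e-4 (9-11), ...
```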

src/datasets/megadepth.py

@@ -22,6 +22,7 @@ class MegaDepthDataset(Dataset):
                 **kwargs):
        """
        Manage one scene(npz_path) of MegaDepth dataset.

        Args:
            root_dir (str): megadepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
@@ -69,12 +70,14 @@ class MegaDepthDataset(Dataset):
        # read grayscale image and mask. (1, h, w) and (h, w)
        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
+       # TODO: Support augmentation & handle seeds for each worker correctly.
        image0, mask0, scale0 = read_megadepth_gray(
-           img_name0, self.img_resize, self.df, self.img_padding,
-           np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+           img_name0, self.img_resize, self.df, self.img_padding, None)
+           # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
        image1, mask1, scale1 = read_megadepth_gray(
-           img_name1, self.img_resize, self.df, self.img_padding,
-           np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+           img_name1, self.img_resize, self.df, self.img_padding, None)
+           # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))

        # read depth. shape: (h, w)
        if self.mode in ['train', 'val']:

@ -0,0 +1,77 @@
import torch
from torch.utils.data import Sampler, ConcatDataset
class RandomConcatSampler(Sampler):
""" Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
Args:
shuffle (bool): shuffle the random sampled indices across all sub-datsets.
repeat (int): repeatedly use the sampled indices multiple times for training.
[arXiv:1902.05509, arXiv:1901.09335]
NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
NOTE: This sampler behaves differently with DistributedSampler.
It assume the dataset is splitted across ranks instead of replicated.
TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
"""
def __init__(self,
data_source: ConcatDataset,
n_samples_per_subset: int,
subset_replacement: bool=True,
shuffle: bool=True,
repeat: int=1,
seed: int=None):
if not isinstance(data_source, ConcatDataset):
raise TypeError("data_source should be torch.utils.data.ConcatDataset")
self.data_source = data_source
self.n_subset = len(self.data_source.datasets)
self.n_samples_per_subset = n_samples_per_subset
self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
self.subset_replacement = subset_replacement
self.repeat = repeat
self.shuffle = shuffle
self.generator = torch.manual_seed(seed)
assert self.repeat >= 1
def __len__(self):
return self.n_samples
def __iter__(self):
indices = []
# sample from each sub-dataset
for d_idx in range(self.n_subset):
low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
high = self.data_source.cumulative_sizes[d_idx]
if self.subset_replacement:
rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
generator=self.generator, dtype=torch.int64)
else: # sample without replacement
len_subset = len(self.data_source.datasets[d_idx])
rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
if len_subset >= self.n_samples_per_subset:
rand_tensor = rand_tensor[:self.n_samples_per_subset]
else: # padding with replacement
rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
generator=self.generator, dtype=torch.int64)
rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
indices.append(rand_tensor)
indices = torch.cat(indices)
if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
rand_tensor = torch.randperm(len(indices), generator=self.generator)
indices = indices[rand_tensor]
# repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
if self.repeat > 1:
repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
if self.shuffle:
_choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
repeat_indices = map(_choice, repeat_indices)
indices = torch.cat([indices, *repeat_indices], 0)
assert indices.shape[0] == self.n_samples
return iter(indices.tolist())
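A minimal usage sketch for `RandomConcatSampler`, with two toy subsets standing in for per-scene datasets:

```python
# Toy usage of RandomConcatSampler: scene-balanced sampling over a ConcatDataset.
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from src.datasets.sampler import RandomConcatSampler

subsets = [TensorDataset(torch.arange(10).float()),   # small "scene"
           TensorDataset(torch.arange(50).float())]   # large "scene"
concat = ConcatDataset(subsets)

sampler = RandomConcatSampler(concat, n_samples_per_subset=4,
                              subset_replacement=True, shuffle=True,
                              repeat=1, seed=66)
loader = DataLoader(concat, sampler=sampler, batch_size=4)
print(len(sampler))  # 8: four samples per subset per epoch, regardless of subset size
```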

src/datasets/scannet.py

@@ -1,8 +1,17 @@
from os import path as osp
from typing import Dict
from unicodedata import name

import numpy as np
import torch
import torch.utils as utils
- from src.utils.dataset import read_scannet_gray, read_scannet_depth
+ from numpy.linalg import inv
+ from src.utils.dataset import (
+     read_scannet_gray,
+     read_scannet_depth,
+     read_scannet_pose,
+     read_scannet_intrinsic
+ )


class ScanNetDataset(utils.data.Dataset):
@@ -13,6 +22,7 @@ class ScanNetDataset(utils.data.Dataset):
                 mode='train',
                 min_overlap_score=0.4,
                 augment_fn=None,
                 pose_dir=None,
                 **kwargs):
        """Manage one scene of ScanNet Dataset.

        Args:
@@ -21,20 +31,20 @@ class ScanNetDataset(utils.data.Dataset):
            intrinsic_path (str): path to depth-camera intrinsic file.
            mode (str): options are ['train', 'val', 'test'].
            augment_fn (callable, optional): augments images with pre-defined visual effects.
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images and poses separately.)
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics(T)
        with np.load(npz_path) as data:
            self.data_names = data['name']
-           self.T_1to2s = data['rel_pose']
-           # min_overlap_score criterion
            if 'score' in data.keys() and mode not in ['val', 'test']:
                kept_mask = data['score'] > min_overlap_score
                self.data_names = self.data_names[kept_mask]
-               self.T_1to2s = self.T_1to2s[kept_mask]
        self.intrinsics = dict(np.load(intrinsic_path))

        # for training LoFTR
@@ -43,6 +53,18 @@ class ScanNetDataset(utils.data.Dataset):
    def __len__(self):
        return len(self.data_names)

    def _read_abs_pose(self, scene_name, name):
        pth = osp.join(self.pose_dir,
                       scene_name,
                       'pose', f'{name}.txt')
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)
        return np.matmul(pose1, inv(pose0))  # (4, 4)

    def __getitem__(self, idx):
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
@@ -51,10 +73,12 @@ class ScanNetDataset(utils.data.Dataset):
        # read the grayscale image which will be resized to (1, 480, 640)
        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
-       image0 = read_scannet_gray(img_name0, resize=(640, 480),
-                                  augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
-       image1 = read_scannet_gray(img_name1, resize=(640, 480),
-                                  augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+       # TODO: Support augmentation & handle seeds for each worker correctly.
+       image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
+                                  # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+       image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
+                                  # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))

        # read the depthmap which is stored as (480, 640)
        if self.mode in ['train', 'val']:
@@ -67,8 +91,8 @@ class ScanNetDataset(utils.data.Dataset):
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
-       T_0to1 = torch.tensor(self.T_1to2s[idx].copy(), dtype=torch.float).reshape(3, 4)
-       T_0to1 = torch.cat([T_0to1, torch.tensor([[0., 0., 0., 1.]])], dim=0).reshape(4, 4)
+       T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+                             dtype=torch.float32)
        T_1to0 = T_0to1.inverse()

        data = {
@@ -80,7 +104,7 @@ class ScanNetDataset(utils.data.Dataset):
            'T_1to0': T_1to0,
            'K0': K_0,  # (3, 3)
            'K1': K_1,
-           'dataset_name': 'scannet',
+           'dataset_name': 'ScanNet',
            'scene_id': scene_name,
            'pair_id': idx,
            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
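As a side note on `_compute_rel_pose` above: assuming `read_scannet_pose` returns 4x4 world-to-camera matrices (an assumption, since that helper is not part of this diff), `pose1 @ inv(pose0)` maps camera-0 coordinates into camera-1 coordinates. A tiny numpy sanity check:

```python
# Hedged sanity check: composing world-to-camera poses yields the camera-0 -> camera-1 transform.
import numpy as np
from numpy.linalg import inv

def make_pose(t):  # 4x4 world-to-camera matrix with identity rotation
    T = np.eye(4)
    T[:3, 3] = t
    return T

pose0, pose1 = make_pose([1., 0., 0.]), make_pose([0., 1., 0.])
T_0to1 = pose1 @ inv(pose0)

x_w = np.array([0., 0., 5., 1.])  # a homogeneous world point
assert np.allclose(T_0to1 @ (pose0 @ x_w), pose1 @ x_w)
```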

src/lightning/data.py

@@ -1,15 +1,31 @@
import os
import math
from collections import abc
from loguru import logger
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from os import path as osp
from pathlib import Path
from joblib import Parallel, delayed

import pytorch_lightning as pl
from torch import distributed as dist
- from torch.utils.data import DataLoader, ConcatDataset, DistributedSampler
+ from torch.utils.data import (
+     Dataset,
+     DataLoader,
+     ConcatDataset,
+     DistributedSampler,
+     RandomSampler,
+     dataloader
+ )

from src.utils.augment import build_augmentor
from src.utils.dataloader import get_local_split
from src.utils.misc import tqdm_joblib
from src.utils import comm
from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.sampler import RandomConcatSampler


class MultiSceneDataModule(pl.LightningDataModule):
@@ -17,7 +33,6 @@ class MultiSceneDataModule(pl.LightningDataModule):
    For distributed training, each training process is assigned
    only a part of the training scenes to reduce memory overhead.
    """
    def __init__(self, args, config):
        super().__init__()
@@ -27,22 +42,26 @@ class MultiSceneDataModule(pl.LightningDataModule):
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        # training and validating
        self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
        self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT  # (optional)
        self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
        self.train_list_path = config.DATASET.TRAIN_LIST_PATH
        self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
        self.val_data_root = config.DATASET.VAL_DATA_ROOT
        self.val_pose_root = config.DATASET.VAL_POSE_ROOT  # (optional)
        self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
        self.val_list_path = config.DATASET.VAL_LIST_PATH
        self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
        # testing
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
-       self.min_overlap_score = config.DATASET.MIN_OVERLAP_SCORE  # 0.4, omit data with overlap_score < min_overlap_score
+       self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST  # 0.4, omit data with overlap_score < min_overlap_score
        self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
        self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE)  # None, options: [None, 'dark', 'mobile']

        # MegaDepth options
@@ -53,6 +72,17 @@ class MultiSceneDataModule(pl.LightningDataModule):
        self.coarse_scale = 1 / config.LOFTR.RESOLUTION[0]  # 0.125. for training loftr.

        # 3. loader parameters
        self.train_loader_params = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': args.pin_memory,
        }
        self.val_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': args.pin_memory,
        }
        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
@@ -60,6 +90,17 @@ class MultiSceneDataModule(pl.LightningDataModule):
            'pin_memory': True
        }

        # 4. sampler
        self.data_sampler = config.TRAINER.DATA_SAMPLER
        self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
        self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
        self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
        self.repeat = config.TRAINER.SB_REPEAT
        # (optional) RandomSampler for debugging

        # misc configurations
        self.parallel_load_data = getattr(args, 'parallel_load_data', False)
        self.seed = config.TRAINER.SEED  # 66

    def setup(self, stage=None):
@@ -69,7 +110,7 @@ class MultiSceneDataModule(pl.LightningDataModule):
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """
-       assert stage == 'test', "only support testing yet"
+       assert stage in ['fit', 'test'], "stage must be either fit or test"

        try:
            self.world_size = dist.get_world_size()
@@ -80,14 +121,58 @@ class MultiSceneDataModule(pl.LightningDataModule):
            self.rank = 0
            logger.warning(str(ae) + " (set world_size=1 and rank=0)")
        if stage == 'fit':
            self.train_dataset = self._setup_dataset(
                self.train_data_root,
                self.train_npz_root,
                self.train_list_path,
                self.train_intrinsic_path,
                mode='train',
                min_overlap_score=self.min_overlap_score_train,
                pose_dir=self.train_pose_root)
            # setup multiple (optional) validation subsets
            if isinstance(self.val_list_path, (list, tuple)):
                self.val_dataset = []
                if not isinstance(self.val_npz_root, (list, tuple)):
                    self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
                for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
                    self.val_dataset.append(self._setup_dataset(
                        self.val_data_root,
                        npz_root,
                        npz_list,
                        self.val_intrinsic_path,
                        mode='val',
                        min_overlap_score=self.min_overlap_score_test,
                        pose_dir=self.val_pose_root))
            else:
                self.val_dataset = self._setup_dataset(
                    self.val_data_root,
                    self.val_npz_root,
                    self.val_list_path,
                    self.val_intrinsic_path,
                    mode='val',
                    min_overlap_score=self.min_overlap_score_test,
                    pose_dir=self.val_pose_root)
            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
        else:  # stage == 'test'
-           self.test_dataset = self._setup_dataset(self.test_data_root,
-                                                   self.test_npz_root,
-                                                   self.test_list_path,
-                                                   self.test_intrinsic_path,
-                                                   mode='test')
+           self.test_dataset = self._setup_dataset(
+               self.test_data_root,
+               self.test_npz_root,
+               self.test_list_path,
+               self.test_intrinsic_path,
+               mode='test',
+               min_overlap_score=self.min_overlap_score_test,
+               pose_dir=self.test_pose_root)
            logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')

-   def _setup_dataset(self, data_root, split_npz_root, scene_list_path, intri_path, mode='train'):
+   def _setup_dataset(self,
+                      data_root,
+                      split_npz_root,
+                      scene_list_path,
+                      intri_path,
+                      mode='train',
+                      min_overlap_score=0.,
+                      pose_dir=None):
        """ Setup train / val / test set"""
        with open(scene_list_path, 'r') as f:
            npz_names = [name.split()[0] for name in f.readlines()]
@@ -98,13 +183,30 @@ class MultiSceneDataModule(pl.LightningDataModule):
            local_npz_names = npz_names
        logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')

-       return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path, mode=mode)
+       dataset_builder = self._build_concat_dataset_parallel \
+           if self.parallel_load_data \
+           else self._build_concat_dataset
+       return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
+                              mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)

-   def _build_concat_dataset(self, data_root, npz_names, npz_dir, intrinsic_path, mode):
+   def _build_concat_dataset(
+       self,
+       data_root,
+       npz_names,
+       npz_dir,
+       intrinsic_path,
+       mode,
+       min_overlap_score=0.,
+       pose_dir=None
+   ):
        datasets = []
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
-       for npz_name in tqdm(npz_names, desc=f'[rank:{self.rank}], loading {mode} datasets', disable=int(self.rank) != 0):
+       if str(data_source).lower() == 'megadepth':
+           npz_names = [f'{n}.npz' for n in npz_names]
+       for npz_name in tqdm(npz_names,
+                            desc=f'[rank:{self.rank}] loading {mode} datasets',
+                            disable=int(self.rank) != 0):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if data_source == 'ScanNet':
@@ -113,14 +215,15 @@ class MultiSceneDataModule(pl.LightningDataModule):
                    npz_path,
                    intrinsic_path,
                    mode=mode,
-                   min_overlap_score=self.min_overlap_score,
-                   augment_fn=augment_fn))
+                   min_overlap_score=min_overlap_score,
+                   augment_fn=augment_fn,
+                   pose_dir=pose_dir))
            elif data_source == 'MegaDepth':
                datasets.append(
                    MegaDepthDataset(data_root,
                                     npz_path,
                                     mode=mode,
-                                    min_overlap_score=self.min_overlap_score,
+                                    min_overlap_score=min_overlap_score,
                                     img_resize=self.mgdpt_img_resize,
                                     df=self.mgdpt_df,
                                     img_padding=self.mgdpt_img_pad,
@@ -131,7 +234,87 @@ class MultiSceneDataModule(pl.LightningDataModule):
            else:
                raise NotImplementedError()
        return ConcatDataset(datasets)
    def _build_concat_dataset_parallel(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.,
        pose_dir=None,
    ):
        augment_fn = self.augment_fn if mode == 'train' else None
        data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
        if str(data_source).lower() == 'megadepth':
            npz_names = [f'{n}.npz' for n in npz_names]
        with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
                              total=len(npz_names), disable=int(self.rank) != 0)):
            if data_source == 'ScanNet':
                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
                    delayed(lambda x: _build_dataset(
                        ScanNetDataset,
                        data_root,
                        osp.join(npz_dir, x),
                        intrinsic_path,
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        augment_fn=augment_fn,
                        pose_dir=pose_dir))(name)
                    for name in npz_names)
            elif data_source == 'MegaDepth':
                # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
                raise NotImplementedError()
                datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
                    delayed(lambda x: _build_dataset(
                        MegaDepthDataset,
                        data_root,
                        osp.join(npz_dir, x),
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        img_resize=self.mgdpt_img_resize,
                        df=self.mgdpt_df,
                        img_padding=self.mgdpt_img_pad,
                        depth_padding=self.mgdpt_depth_pad,
                        augment_fn=augment_fn,
                        coarse_scale=self.coarse_scale))(name)
                    for name in npz_names)
            else:
                raise ValueError(f'Unknown dataset: {data_source}')
        return ConcatDataset(datasets)

    def train_dataloader(self):
        """ Build training dataloader for ScanNet / MegaDepth. """
        assert self.data_sampler in ['scene_balance']
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
        if self.data_sampler == 'scene_balance':
            sampler = RandomConcatSampler(self.train_dataset,
                                          self.n_samples_per_subset,
                                          self.subset_replacement,
                                          self.shuffle, self.repeat, self.seed)
        else:
            sampler = None
        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
        return dataloader

    def val_dataloader(self):
        """ Build validation dataloader for ScanNet / MegaDepth. """
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
        if not isinstance(self.val_dataset, abc.Sequence):
            sampler = DistributedSampler(self.val_dataset, shuffle=False)
            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
        else:
            dataloaders = []
            for dataset in self.val_dataset:
                sampler = DistributedSampler(dataset, shuffle=False)
                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
            return dataloaders

    def test_dataloader(self, *args, **kwargs):
        logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
        sampler = DistributedSampler(self.test_dataset, shuffle=False)
        return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)


def _build_dataset(dataset: Dataset, *args, **kwargs):
    return dataset(*args, **kwargs)
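For context, a hedged sketch of how `train.py` (part of this commit, but its diff is not reproduced here) presumably wires the data module and the Lightning module together:

```python
# Hedged sketch of the training entrypoint; names follow the classes above, but the
# real train.py also derives runtime fields (e.g. TRAINER.WORLD_SIZE, TRUE_LR) first.
import pytorch_lightning as pl
from src.config.default import _CN
from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_loftr import PL_LoFTR

def main(args):
    config = _CN.clone()
    config.merge_from_file(args.data_cfg_path)  # e.g. configs/data/scannet_trainval.py
    config.merge_from_file(args.main_cfg_path)  # e.g. configs/loftr/indoor/loftr_ds_dense.py

    model = PL_LoFTR(config)
    datamodule = MultiSceneDataModule(args, config)
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
```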

src/lightning/lightning_loftr.py

@@ -1,41 +1,97 @@
from collections import defaultdict
import pprint
from loguru import logger
from pathlib import Path

- import numpy as np
import torch
+ import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt

from src.loftr import LoFTR
- from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors, aggregate_metrics
+ from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+ from src.losses.loftr_loss import LoFTRLoss
- from src.utils.comm import gather
+ from src.optimizers import build_optimizer, build_scheduler
+ from src.utils.metrics import (
+     compute_symmetrical_epipolar_errors,
+     compute_pose_errors,
+     aggregate_metrics
+ )
+ from src.utils.plotting import make_matching_figures
+ from src.utils.comm import gather, all_gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler


class PL_LoFTR(pl.LightningModule):
    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
        """
        TODO:
            - use the new version of PL logging API.
        """
        super().__init__()
        # Misc
        self.config = config  # full config
-       self.loftr_cfg = lower_config(self.config.LOFTR)
+       _config = lower_config(self.config)
+       self.loftr_cfg = lower_config(_config['loftr'])
        self.profiler = profiler or PassThroughProfiler()
-       self.dump_dir = dump_dir
+       self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)

        # Matcher: LoFTR
-       self.matcher = LoFTR(config=self.loftr_cfg)
+       self.matcher = LoFTR(config=_config['loftr'])
+       self.loss = LoFTRLoss(_config)

        # Pretrained weights
        if pretrained_ckpt:
            self.matcher.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu')['state_dict'])
            logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")

-   def test_step(self, batch, batch_idx):
+       # Testing
+       self.dump_dir = dump_dir

    def configure_optimizers(self):
        # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
        optimizer = build_optimizer(self, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]

    def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx,
            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        # learning rate warm up
        warmup_step = self.config.TRAINER.WARMUP_STEP
        if self.trainer.global_step < warmup_step:
            if self.config.TRAINER.WARMUP_TYPE == 'linear':
                base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + \
                    (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
                    abs(self.config.TRAINER.TRUE_LR - base_lr)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
                pass
            else:
                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def _trainval_inference(self, batch):
        with self.profiler.profile("Compute coarse supervision"):
            compute_supervision_coarse(batch, self.config)

        with self.profiler.profile("LoFTR"):
            self.matcher(batch)

        with self.profiler.profile("Compute fine supervision"):
            compute_supervision_fine(batch, self.config)

        with self.profiler.profile("Compute losses"):
            self.loss(batch)

    def _compute_metrics(self, batch):
        with self.profiler.profile("Compute metrics"):
            compute_symmetrical_epipolar_errors(batch)  # compute epi_errs for each match
            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair
@@ -50,6 +106,106 @@ class PL_LoFTR(pl.LightningModule):
                't_errs': batch['t_errs'],
                'inliers': batch['inliers']}
            ret_dict = {'metrics': metrics}
return ret_dict, rel_pair_names
def training_step(self, batch, batch_idx):
self._trainval_inference(batch)
# logging
if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
# scalars
for k, v in batch['loss_scalars'].items():
self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
# net-params
if self.config.LOFTR.MATCH_COARSE.MATCH_TYPE == 'sinkhorn':
self.logger.experiment.add_scalar(
f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
# figures
if self.config.TRAINER.ENABLE_PLOTTING:
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
for k, v in figures.items():
self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
return {'loss': batch['loss']}
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
if self.trainer.global_rank == 0:
self.logger.experiment.add_scalar(
'train/avg_loss_on_epoch', avg_loss,
global_step=self.current_epoch)
def validation_step(self, batch, batch_idx):
self._trainval_inference(batch)
ret_dict, _ = self._compute_metrics(batch)
val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
figures = {self.config.TRAINER.PLOT_MODE: []}
if batch_idx % val_plot_interval == 0:
figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
return {
**ret_dict,
'loss_scalars': batch['loss_scalars'],
'figures': figures,
}
def validation_epoch_end(self, outputs):
# handle multiple validation sets
multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
multi_val_metrics = defaultdict(list)
for valset_idx, outputs in enumerate(multi_outputs):
# since pl performs sanity_check at the very begining of the training
cur_epoch = self.trainer.current_epoch
if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
cur_epoch = -1
# 1. loss_scalars: dict of list, on cpu
_loss_scalars = [o['loss_scalars'] for o in outputs]
loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
# 2. val metrics: dict of list, numpy
_metrics = [o['metrics'] for o in outputs]
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
# NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
for thr in [5, 10, 20]:
multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
# 3. figures
_figures = [o['figures'] for o in outputs]
figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
# tensorboard records only on rank 0
if self.trainer.global_rank == 0:
for k, v in loss_scalars.items():
mean_v = torch.stack(v).mean()
self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
for k, v in val_metrics_4tb.items():
self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
for k, v in figures.items():
if self.trainer.global_rank == 0:
for plot_idx, fig in enumerate(v):
self.logger.experiment.add_figure(
f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
plt.close('all')
for thr in [5, 10, 20]:
# log on all ranks for ModelCheckpoint callback to work properly
self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
def test_step(self, batch, batch_idx):
with self.profiler.profile("LoFTR"):
self.matcher(batch)
ret_dict, rel_pair_names = self._compute_metrics(batch)
with self.profiler.profile("dump_results"): with self.profiler.profile("dump_results"):
if self.dump_dir is not None: if self.dump_dir is not None:

@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange

INF = 1e9
def mask_border(m, b: int, v):
    """ Mask borders with value

@@ -36,10 +37,23 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        m[b_idx, h0 - bd:] = v
        m[b_idx, :, w0 - bd:] = v
        m[b_idx, :, :, h1 - bd:] = v
        m[b_idx, :, :, :, w1 - bd:] = v


def compute_max_candidates(p_m0, p_m1):
    """Compute the max candidates of all pairs within a batch

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks
    """
    h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
    h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
    max_cand = torch.sum(
        torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
    return max_cand
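
A quick sanity check of `compute_max_candidates` on made-up padded masks: per pair, the candidate budget is the smaller of the two valid areas, summed over the batch.

```python
import torch
from src.loftr.utils.coarse_matching import compute_max_candidates

# toy padded masks: pair 0 is fully valid (4x4 vs 4x4), pair 1 has a
# 2x4 valid region vs a 4x3 one -> min(16, 16) + min(8, 12) = 24
p_m0 = torch.zeros(2, 4, 4)
p_m1 = torch.zeros(2, 4, 4)
p_m0[0] = 1
p_m1[0] = 1
p_m0[1, :2, :] = 1
p_m1[1, :, :3] = 1

print(compute_max_candidates(p_m0, p_m1))  # tensor(24.)
```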
class CoarseMatching(nn.Module):
@@ -49,6 +63,9 @@ class CoarseMatching(nn.Module):
        # general config
        self.thr = config['thr']
        self.border_rm = config['border_rm']
        # for training fine-level LoFTR
        self.train_coarse_percent = config['train_coarse_percent']
        self.train_pad_num_gt_min = config['train_pad_num_gt_min']

        # we provide 2 options for differentiable matching
        self.match_type = config['match_type']
@@ -60,7 +77,8 @@ class CoarseMatching(nn.Module):
            except ImportError:
                raise ImportError("download superglue.py first!")
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(
                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
            self.skh_iters = config['skh_iters']
            self.skh_prefilter = config['skh_prefilter']
        else:
@@ -88,26 +106,29 @@ class CoarseMatching(nn.Module):
        N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)

        # normalize
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_c0, feat_c1])

        if self.match_type == 'dual_softmax':
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
                                      feat_c1) / self.temperature
            if mask_c0 is not None:
                sim_matrix.masked_fill_(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF)
            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        elif self.match_type == 'sinkhorn':
            # sinkhorn, dustbin included
            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
            if mask_c0 is not None:
                sim_matrix[:, :L, :S].masked_fill_(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF)

            # build uniform prior & use sinkhorn
            log_assign_matrix = self.log_optimal_transport(
                sim_matrix, self.bin_score, self.skh_iters)
            assign_matrix = log_assign_matrix.exp()
            conf_matrix = assign_matrix[:, :-1, :-1]

@@ -118,6 +139,9 @@ class CoarseMatching(nn.Module):
                conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
                conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0

            if self.config['sparse_spvs']:
                data.update({'conf_matrix_with_bin': assign_matrix.clone()})

        data.update({'conf_matrix': conf_matrix})

        # predict coarse matches from conf_matrix
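
A minimal sketch of the dual-softmax branch above, on random features; the temperature `0.1` is an assumed stand-in for `self.temperature`, which LoFTR reads from config.

```python
import torch
import torch.nn.functional as F

feat_c0 = torch.randn(1, 6, 32)   # [N, L, C] coarse features, image 0
feat_c1 = torch.randn(1, 8, 32)   # [N, S, C] coarse features, image 1
feat_c0, feat_c1 = [f / f.shape[-1]**.5 for f in (feat_c0, feat_c1)]

sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / 0.1
# mutual soft-assignment: softmax over both axes, entries in [0, 1]
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
print(conf_matrix.shape)  # torch.Size([1, 6, 8])
```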
@@ -140,16 +164,24 @@ class CoarseMatching(nn.Module):
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }
        _device = conf_matrix.device
        # 1. confidence thresholding
        mask = conf_matrix > self.thr
        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                         **axes_lengths)
        if 'mask0' not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(mask, self.border_rm, False,
                                     data['mask0'], data['mask1'])
        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

        # 2. mutual nearest
        mask = mask \
@@ -163,6 +195,45 @@ class CoarseMatching(nn.Module):
        j_ids = all_j_ids[b_ids, i_ids]
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        # 4. Random sampling of training samples for fine-level LoFTR
        # (optional) pad samples with gt coarse-level matches
        # NOTE:
        #   The sampling is performed across all pairs in a batch without manually balancing;
        #   #samples for fine-level increases w.r.t. batch_size
        if 'mask0' not in data:
            num_candidates_max = mask.size(0) * max(
                mask.size(1), mask.size(2))
        else:
            num_candidates_max = compute_max_candidates(
                data['mask0'], data['mask1'])
        num_matches_train = int(num_candidates_max *
                                self.train_coarse_percent)
        num_matches_pred = len(b_ids)
        assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

        # pred_indices is to select from prediction
        if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
            pred_indices = torch.arange(num_matches_pred, device=_device)
        else:
            pred_indices = torch.randint(
                num_matches_pred,
                (num_matches_train - self.train_pad_num_gt_min, ),
                device=_device)

        # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
        gt_pad_indices = torch.randint(
            len(data['spv_b_ids']),
            (max(num_matches_train - num_matches_pred,
                 self.train_pad_num_gt_min), ),
            device=_device)
        mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero

        b_ids, i_ids, j_ids, mconf = map(
            lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
                                   dim=0),
            *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
                 [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))

        # These matches select patches that feed into the fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

@@ -170,14 +241,20 @@ class CoarseMatching(nn.Module):
        scale = data['hw0_i'][0] / data['hw0_c'][0]
        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
        mkpts0_c = torch.stack(
            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
            dim=1) * scale0
        mkpts1_c = torch.stack(
            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
            dim=1) * scale1

        # These matches are the current prediction (for visualization)
        coarse_matches.update({
            'gt_mask': mconf == 0,
            'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
            'mkpts0_c': mkpts0_c[mconf != 0],
            'mkpts1_c': mkpts1_c[mconf != 0],
            'mconf': mconf[mconf != 0]
        })

        return coarse_matches

@@ -53,6 +53,9 @@ class FineMatching(nn.Module):
        var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2  # [M, 2]
        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability

        # for fine-level supervision
        data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

        # compute absolute kpt coords
        self.get_fine_match(coords_normalized, data)

@@ -0,0 +1,54 @@
import torch


@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
    """ Warp kpts0 from I0 to I1 with depth, K and Rt.
    Also check covisibility and depth consistency.
    Depth is consistent if the relative error < 0.2 (hard-coded).

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
    Returns:
        calculable_mask (torch.Tensor): [N, L]
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
    """
    kpts0_long = kpts0.round().long()

    # Sample depth, get calculable_mask on depth != 0
    kpts0_depth = torch.stack(
        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
    )  # (N, L)
    nonzero_mask = kpts0_depth != 0

    # Unproject
    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h, w = depth1.shape[1:3]
    covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
        (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
    w_kpts0_long = w_kpts0.long()
    w_kpts0_long[~covisible_mask, :] = 0

    w_kpts0_depth = torch.stack(
        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
    )  # (N, L)
    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
    valid_mask = nonzero_mask * covisible_mask * consistent_mask

    return valid_mask, w_kpts0
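
A small sanity check for `warp_kpts` (toy tensors, not real data): with an identity pose, identity intrinsics and unit depth, keypoints should warp onto themselves up to the `+1e-4` epsilon in the projection.

```python
import torch
from src.loftr.utils.geometry import warp_kpts

N, H, W = 1, 8, 8
depth = torch.ones(N, H, W)
T_0to1 = torch.eye(4)[:3].unsqueeze(0)        # [N, 3, 4] identity transform
K = torch.eye(3).unsqueeze(0)                 # [N, 3, 3]
kpts0 = torch.tensor([[[2., 3.], [5., 4.]]])  # [N, L, 2] <x, y>

valid_mask, w_kpts0 = warp_kpts(kpts0, depth, depth, T_0to1, K, K)
print(valid_mask)                                 # tensor([[True, True]])
print(torch.allclose(w_kpts0, kpts0, atol=1e-2))  # True
```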

@@ -0,0 +1,151 @@
from math import log
from loguru import logger

import torch
from einops import repeat
from kornia.utils import create_meshgrid

from .geometry import warp_kpts

##############  ↓  Coarse-Level supervision  ↓  ##############


@torch.no_grad()
def mask_pts_at_padded_regions(grid_pt, mask):
    """For the megadepth dataset, zero-padding exists in images."""
    mask = repeat(mask, 'n h w -> n (h w) c', c=2)
    grid_pt[~mask.bool()] = 0
    return grid_pt


@torch.no_grad()
def spvs_coarse(data, config):
    """
    Update:
        data (dict): {
            "conf_matrix_gt": [N, hw0, hw1],
            'spv_b_ids': [M]
            'spv_i_ids': [M]
            'spv_j_ids': [M]
            'spv_w_pt0_i': [N, hw0, 2], in original image resolution
            'spv_pt1_i': [N, hw1, 2], in original image resolution
        }

    NOTE:
        - for the scannet dataset, there are 3 kinds of resolution {i, c, f}
        - for the megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
    """
    # 1. misc
    device = data['image0'].device
    N, _, H0, W0 = data['image0'].shape
    _, _, H1, W1 = data['image1'].shape
    scale = config['LOFTR']['RESOLUTION'][0]
    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
    h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])

    # 2. warp grids
    # create kpts in meshgrid and resize them to image resolution
    grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1)    # [N, hw, 2]
    grid_pt0_i = scale0 * grid_pt0_c
    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
    grid_pt1_i = scale1 * grid_pt1_c

    # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
    if 'mask0' in data:
        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])

    # warp kpts bi-directionally and resize them to coarse-level resolution
    # (no depth consistency check, since it leads to worse results experimentally)
    # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
    w_pt0_c = w_pt0_i / scale1
    w_pt1_c = w_pt1_i / scale0

    # 3. check if mutual nearest neighbor
    w_pt0_c_round = w_pt0_c[:, :, :].round().long()
    nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
    w_pt1_c_round = w_pt1_c[:, :, :].round().long()
    nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0

    # corner case: out of boundary
    def out_bound_mask(pt, w, h):
        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
    nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
    nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0

    loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
    correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
    correct_0to1[:, 0] = False  # ignore the top-left corner

    # 4. construct a gt conf_matrix
    conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
    b_ids, i_ids = torch.where(correct_0to1 != 0)
    j_ids = nearest_index1[b_ids, i_ids]

    conf_matrix_gt[b_ids, i_ids, j_ids] = 1
    data.update({'conf_matrix_gt': conf_matrix_gt})

    # 5. save coarse matches (gt) for training the fine level
    if len(b_ids) == 0:
        logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
        # this won't affect fine-level loss calculation
        b_ids = torch.tensor([0], device=device)
        i_ids = torch.tensor([0], device=device)
        j_ids = torch.tensor([0], device=device)

    data.update({
        'spv_b_ids': b_ids,
        'spv_i_ids': i_ids,
        'spv_j_ids': j_ids
    })

    # 6. save intermediate results (for fast fine-level computation)
    data.update({
        'spv_w_pt0_i': w_pt0_i,
        'spv_pt1_i': grid_pt1_i
    })


def compute_supervision_coarse(data, config):
    assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
    data_source = data['dataset_name'][0]
    if data_source.lower() in ['scannet', 'megadepth']:
        spvs_coarse(data, config)
    else:
        raise ValueError(f'Unknown data source: {data_source}')


##############  ↓  Fine-Level supervision  ↓  ##############


@torch.no_grad()
def spvs_fine(data, config):
    """
    Update:
        data (dict):{
            "expec_f_gt": [M, 2]}
    """
    # 1. misc
    # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
    w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
    scale = config['LOFTR']['RESOLUTION'][1]
    radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2

    # 2. get coarse prediction
    b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']

    # 3. compute gt
    scale = scale * data['scale1'][b_ids] if 'scale1' in data else scale
    # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
    expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius  # [M, 2]
    data.update({"expec_f_gt": expec_f_gt})


def compute_supervision_fine(data, config):
    data_source = data['dataset_name'][0]
    if data_source.lower() in ['scannet', 'megadepth']:
        spvs_fine(data, config)
    else:
        raise NotImplementedError
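
A worked instance of the fine-level gt normalization in `spvs_fine` above, with assumed config values `RESOLUTION[1] = 2` and `FINE_WINDOW_SIZE = 5` (radius 2): a 4 px offset lands exactly on the window border; anything larger is filtered later.

```python
scale, radius = 2, 5 // 2   # assumed RESOLUTION[1] and FINE_WINDOW_SIZE // 2
offset_px = 4.0             # warped gt point is 4 px from the coarse grid point
expec_f_gt = offset_px / scale / radius
print(expec_f_gt)           # 1.0 -> exactly on the fine-window border
```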

@@ -0,0 +1,192 @@
from loguru import logger

import torch
import torch.nn as nn


class LoFTRLoss(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # config under the global namespace
        self.loss_config = config['loftr']['loss']
        self.match_type = self.config['loftr']['match_coarse']['match_type']
        self.sparse_spvs = self.config['loftr']['match_coarse']['sparse_spvs']

        # coarse-level
        self.correct_thr = self.loss_config['fine_correct_thr']
        self.c_pos_w = self.loss_config['pos_weight']
        self.c_neg_w = self.loss_config['neg_weight']
        # fine-level
        self.fine_type = self.loss_config['fine_type']

    def compute_coarse_loss(self, conf, conf_gt, weight=None):
        """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
        Args:
            conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
            conf_gt (torch.Tensor): (N, HW0, HW1)
            weight (torch.Tensor): (N, HW0, HW1)
        """
        pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
        c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w
        # corner case: no gt coarse-level match at all
        if not pos_mask.any():  # assign a wrong gt
            pos_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.
            c_pos_w = 0.
        if not neg_mask.any():
            neg_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.
            c_neg_w = 0.

        if self.loss_config['coarse_type'] == 'cross_entropy':
            assert not self.sparse_spvs, 'Sparse supervision for cross-entropy is not implemented!'
            conf = torch.clamp(conf, 1e-6, 1-1e-6)
            loss_pos = - torch.log(conf[pos_mask])
            loss_neg = - torch.log(1 - conf[neg_mask])
            if weight is not None:
                loss_pos = loss_pos * weight[pos_mask]
                loss_neg = loss_neg * weight[neg_mask]
            return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
        elif self.loss_config['coarse_type'] == 'focal':
            conf = torch.clamp(conf, 1e-6, 1-1e-6)
            alpha = self.loss_config['focal_alpha']
            gamma = self.loss_config['focal_gamma']

            if self.sparse_spvs:
                pos_conf = conf[:, :-1, :-1][pos_mask] \
                    if self.match_type == 'sinkhorn' \
                    else conf[pos_mask]
                loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
                # calculate losses for negative samples
                if self.match_type == 'sinkhorn':
                    neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
                    neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0)
                    loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
                else:
                    # There is no dustbin for dual_softmax, so we leave unmatchable patches without supervision.
                    # We could also add 'pseudo negative-samples'.
                    pass
                # handle loss weights
                if weight is not None:
                    # Different from dense-spvs, the loss w.r.t. padded regions isn't directly zeroed out,
                    # but only through manually setting the corresponding regions in sim_matrix to '-inf'.
                    loss_pos = loss_pos * weight[pos_mask]
                    if self.match_type == 'sinkhorn':
                        neg_w0 = (weight.sum(-1) != 0)[neg0]
                        neg_w1 = (weight.sum(1) != 0)[neg1]
                        neg_mask = torch.cat([neg_w0, neg_w1], 0)
                        loss_neg = loss_neg[neg_mask]

                loss = c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \
                    if self.match_type == 'sinkhorn' \
                    else c_pos_w * loss_pos.mean()
                return loss
                # positive and negative elements occupy similar proportions => more balanced loss weights needed
            else:  # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
                loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
                loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
                if weight is not None:
                    loss_pos = loss_pos * weight[pos_mask]
                    loss_neg = loss_neg * weight[neg_mask]
                return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
                # each negative element occupies a smaller proportion than the positives => higher negative loss weight needed
        else:
            raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
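
A scalar sketch of the positive focal term used above; `alpha=0.25` and `gamma=2.0` are assumed stand-ins for `focal_alpha` / `focal_gamma` from the loss config.

```python
import torch

alpha, gamma = 0.25, 2.0   # assumed focal_alpha / focal_gamma
conf = torch.tensor([0.9, 0.5, 0.1]).clamp(1e-6, 1 - 1e-6)
loss_pos = -alpha * torch.pow(1 - conf, gamma) * conf.log()
print(loss_pos)  # ~[0.0003, 0.0433, 0.4663]: confident matches contribute little
```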
    def compute_fine_loss(self, expec_f, expec_f_gt):
        if self.fine_type == 'l2_with_std':
            return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
        elif self.fine_type == 'l2':
            return self._compute_fine_loss_l2(expec_f, expec_f_gt)
        else:
            raise NotImplementedError()

    def _compute_fine_loss_l2(self, expec_f, expec_f_gt):
        """
        Args:
            expec_f (torch.Tensor): [M, 2] <x, y>
            expec_f_gt (torch.Tensor): [M, 2] <x, y>
        """
        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
        if correct_mask.sum() == 0:
            if self.training:  # this seldom happens during training, since we pad predictions with gt
                logger.warning("assign a false supervision to avoid ddp deadlock")
                correct_mask[0] = True
            else:
                return None
        offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
        return offset_l2.mean()

    def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt):
        """
        Args:
            expec_f (torch.Tensor): [M, 3] <x, y, std>
            expec_f_gt (torch.Tensor): [M, 2] <x, y>
        """
        # correct_mask tells you which pairs to compute the fine-loss on
        correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr

        # use std as a weight that measures uncertainty
        std = expec_f[:, 2]
        inverse_std = 1. / torch.clamp(std, min=1e-10)
        weight = (inverse_std / torch.mean(inverse_std)).detach()  # avoid minimizing the loss by increasing std

        # corner case: no correct coarse match found
        if not correct_mask.any():
            if self.training:  # this seldom happens during training, since we pad predictions with gt
                # sometimes there is no coarse-level gt at all
                logger.warning("assign a false supervision to avoid ddp deadlock")
                correct_mask[0] = True
                weight[0] = 0.
            else:
                return None

        # l2 loss with std
        offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1)
        loss = (offset_l2 * weight[correct_mask]).mean()

        return loss

    @torch.no_grad()
    def compute_c_weight(self, data):
        """ Compute element-wise weights for the coarse-level loss. """
        if 'mask0' in data:
            c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
        else:
            c_weight = None
        return c_weight

    def forward(self, data):
        """
        Update:
            data (dict): update{
                'loss': [1] the reduced loss across a batch,
                'loss_scalars' (dict): loss scalars for tensorboard_record
            }
        """
        loss_scalars = {}
        # 0. compute element-wise loss weight
        c_weight = self.compute_c_weight(data)

        # 1. coarse-level loss
        loss_c = self.compute_coarse_loss(
            data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn'
            else data['conf_matrix'],
            data['conf_matrix_gt'],
            weight=c_weight)
        loss = loss_c * self.loss_config['coarse_weight']
        loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})

        # 2. fine-level loss
        loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
        if loss_f is not None:
            loss += loss_f * self.loss_config['fine_weight']
            loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
        else:
            assert self.training is False
            loss_scalars.update({'loss_f': torch.tensor(1.)})  # 1 is the upper bound

        loss_scalars.update({'loss': loss.clone().detach().cpu()})
        data.update({"loss": loss, "loss_scalars": loss_scalars})

@@ -0,0 +1,42 @@
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR

    if name == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
    elif name == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
    else:
        raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")


def build_scheduler(config, optimizer):
    """
    Returns:
        scheduler (dict):{
            'scheduler': lr_scheduler,
            'interval': 'step',  # or 'epoch'
            'monitor': 'val_f1', (optional)
            'frequency': x, (optional)
        }
    """
    scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER

    if name == 'MultiStepLR':
        scheduler.update(
            {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
    elif name == 'CosineAnnealing':
        scheduler.update(
            {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
    elif name == 'ExponentialLR':
        scheduler.update(
            {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
    else:
        raise NotImplementedError()

    return scheduler
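
A usage sketch of the scheduler-dict contract documented above, with a toy model; the milestones and gamma are made-up stand-ins for `MSLR_MILESTONES` / `MSLR_GAMMA`.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = {
    'scheduler': MultiStepLR(optimizer, milestones=[8, 12], gamma=0.5),
    'interval': 'epoch',  # Lightning steps the scheduler once per epoch
}
# inside a LightningModule.configure_optimizers: return [optimizer], [scheduler]
```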

@@ -39,6 +39,8 @@ class MobileAug(object):


def build_augmentor(method=None, **kwargs):
    if method is not None:
        raise NotImplementedError('Using augmentation functions is not supported yet!')
    if method == 'dark':
        return DarkAug()
    elif method == 'mobile':

@@ -4,13 +4,14 @@ import numpy as np


# --- PL-DATAMODULE ---

def get_local_split(items: list, world_size: int, rank: int, seed: int):
    """ The local rank only loads a split of the dataset. """
    n_items = len(items)
    items_permute = np.random.RandomState(seed).permutation(items)
    if n_items % world_size == 0:
        padded_items = items_permute
    else:
        padding = np.random.RandomState(seed).choice(
            items,
            world_size - (n_items % world_size),
            replace=True)
        padded_items = np.concatenate([items_permute, padding])
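
A behavior sketch of `get_local_split`, assuming the tail of the function (outside this hunk) returns the rank-th equal chunk of `padded_items`: 10 items across `world_size=4` are padded to 12, so every rank gets 3.

```python
from src.utils.dataloader import get_local_split

items = list(range(10))
splits = [get_local_split(items, world_size=4, rank=r, seed=66) for r in range(4)]
print([len(s) for s in splits])  # [3, 3, 3, 3]
```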

@@ -1,15 +1,46 @@
import io
from loguru import logger

import cv2
import numpy as np
import h5py
import torch
from numpy.linalg import inv

MEGADEPTH_CLIENT = SCANNET_CLIENT = None


# --- DATA IO ---

def load_array_from_s3(
    path, client, cv_type,
    use_h5py=False,
):
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            raw_array = np.fromstring(byte_str, np.uint8)
            data = cv2.imdecode(raw_array, cv_type)
        else:
            f = io.BytesIO(byte_str)
            data = np.array(h5py.File(f, 'r')['/depth'])
    except Exception as ex:
        print(f"==> Data loading failure: {path}")
        raise ex

    assert data is not None
    return data


def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
        else cv2.IMREAD_COLOR
    if str(path).startswith('s3://'):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), cv_type)

    if augment_fn is not None:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)

@@ -68,7 +99,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    # read image
    image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    # resize image
    w, h = image.shape[1], image.shape[0]

@@ -91,6 +122,9 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):


def read_megadepth_depth(path, pad_to=None):
    if str(path).startswith('s3://'):
        depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        depth = np.array(h5py.File(path, 'r')['depth'])
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)

@@ -120,6 +154,28 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):


def read_scannet_depth(path):
    if str(path).startswith('s3://'):
        depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
    else:
        depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    depth = depth / 1000
    depth = torch.from_numpy(depth).float()  # (h, w)
    return depth


def read_scannet_pose(path):
    """ Read ScanNet's Camera2World pose and transform it to World2Camera.

    Returns:
        pose_w2c (np.ndarray): (4, 4)
    """
    cam2world = np.loadtxt(path, delimiter=' ')
    world2cam = inv(cam2world)
    return world2cam


def read_scannet_intrinsic(path):
    """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
    """
    intrinsic = np.loadtxt(path, delimiter=' ')
    return intrinsic[:-1, :-1]
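
A toy check of the pose convention in `read_scannet_pose`: inverting a camera-to-world matrix yields world-to-camera (the pose below is made up, not real ScanNet data).

```python
import numpy as np
from numpy.linalg import inv

cam2world = np.eye(4)
cam2world[:3, 3] = [1., 2., 3.]  # hypothetical camera-to-world pose (pure translation)
world2cam = inv(cam2world)
print(world2cam[:3, 3])          # [-1. -2. -3.]
```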

@@ -1,7 +1,14 @@
import os
import contextlib
import joblib
from typing import Union
from loguru import _Logger, logger
from itertools import chain

import torch
from yacs.config import CfgNode as CN
from pytorch_lightning.utilities import rank_zero_only


def lower_config(yacs_cfg):
    if not isinstance(yacs_cfg, CN):

@@ -21,21 +28,74 @@ def log_on(condition, message, level):
        logger.log(level, message)


def get_rank_zero_only_logger(logger: _Logger):
    if rank_zero_only.rank == 0:
        return logger
    else:
        for _level in logger._core.levels.keys():
            level = _level.lower()
            setattr(logger, level,
                    lambda x: None)
        logger._log = lambda x: None
    return logger


def setup_gpus(gpus: Union[str, int]) -> int:
    """ A temporary fix for pytorch-lightning 1.3.x """
    gpus = str(gpus)
    gpu_ids = []

    if ',' not in gpus:
        n_gpus = int(gpus)
        return n_gpus if n_gpus != -1 else torch.cuda.device_count()
    else:
        gpu_ids = [i.strip() for i in gpus.split(',') if i != '']

    # setup environment variables
    visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
    if visible_devices is None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
        visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
        logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
    else:
        logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
    return len(gpu_ids)
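
A usage sketch of `setup_gpus`; both call forms resolve to a gpu count (note the comma form mutates `CUDA_VISIBLE_DEVICES` if it is unset).

```python
from src.utils.misc import setup_gpus

print(setup_gpus(2))        # -> 2 (plain gpu count)
print(setup_gpus('0,1,3'))  # -> 3; also sets CUDA_VISIBLE_DEVICES=0,1,3 if unset
print(setup_gpus(-1))       # -> torch.cuda.device_count()
```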
def flattenList(x):
    return list(chain(*x))


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into the tqdm progress bar given as argument.

    Usage:
        with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
            Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

        When iterating over a generator, direct use of tqdm is also a solution (but it monitors task queuing, instead of finishing):
        ret_vals = Parallel(n_jobs=args.world_size)(
                    delayed(lambda x: _compute_cov_score(pid, *x))(param)
                        for param in tqdm(combinations(image_ids, 2),
                                          desc=f'Computing cov_score of [{pid}]',
                                          total=len(image_ids)*(len(image_ids)-1)/2))
    Src: https://stackoverflow.com/a/58936697
    """
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()

@@ -1,13 +1,32 @@
import bisect
import numpy as np
import matplotlib.pyplot as plt
import matplotlib


def _compute_conf_thresh(data):
    dataset_name = data['dataset_name'][0].lower()
    if dataset_name == 'scannet':
        thr = 5e-4
    elif dataset_name == 'megadepth':
        thr = 1e-4
    else:
        raise ValueError(f'Unknown dataset: {dataset_name}')
    return thr


# --- VISUALIZATION --- #

def plot_keypoints(axes, kpts0, kpts1, color='w', ps=2):
    axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
    axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def make_matching_figure(
        img0, img1, mkpts0, mkpts1, color,
        kpts0=None, kpts1=None, text=[], dpi=75, path=None):
    # draw image pair
    assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0, cmap='gray')
    axes[1].imshow(img1, cmap='gray')
    for i in range(2):   # clear all frames

@@ -17,13 +36,21 @@ def make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=[], path=None):
        spine.set_visible(False)
    plt.tight_layout(pad=1)

    if kpts0 is not None:
        assert kpts1 is not None
        # plot_keypoints(axes, kpts0, kpts1, color='k', ps=4)
        plot_keypoints(axes, kpts0, kpts1, color='w', ps=2)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
                                             (fkpts0[i, 1], fkpts1[i, 1]),
                                             transform=fig.transFigure, c=color[i], linewidth=1)
                     for i in range(len(mkpts0))]

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
@@ -42,6 +69,91 @@ def make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=[], path=None):
        return fig


def _make_evaluation_figure(data, b_id, alpha='dynamic'):
    b_mask = data['m_bids'] == b_id
    conf_thr = _compute_conf_thresh(data)

    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()

    # for megadepth, we visualize matches on the resized image
    if 'scale0' in data:
        kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
        kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]

    epi_errs = data['epi_errs'][b_mask].cpu().numpy()
    correct_mask = epi_errs < conf_thr
    precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
    n_correct = np.sum(correct_mask)
    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
    recall = 0 if n_gt_matches == 0 else n_correct / n_gt_matches
    # recall might be larger than 1, since the calculation of conf_matrix_gt
    # uses groundtruth depths and camera poses, but the epipolar distance is used here

    # matching info
    if alpha == 'dynamic':
        alpha = dynamic_alpha(len(correct_mask))
    color = error_colormap(epi_errs, conf_thr, alpha=alpha)

    text = [
        f'Matches {len(kpts0)}',
        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
    ]

    # make the figure
    figure = make_matching_figure(img0, img1, kpts0, kpts1,
                                  color, text=text)
    return figure


def _make_confidence_figure(data, b_id):
    # TODO: Implement confidence figure
    raise NotImplementedError()


def make_matching_figures(data, config, mode='evaluation'):
    """ Make matching figures for a batch.

    Args:
        data (Dict): a batch updated by PL_LoFTR.
        config (Dict): matcher config
    Returns:
        figures (Dict[str, List[plt.figure]])

    TODO:
        - confidence mode plotting
        - parallel plotting
        - evaluation mode & confidence mode at the same time
    """
    assert mode in ['evaluation', 'confidence']  # 'confidence'
    figures = {mode: []}
    for b_id in range(data['image0'].size(0)):
        if mode == 'evaluation':
            fig = _make_evaluation_figure(
                data, b_id,
                alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
        elif mode == 'confidence':
            fig = _make_confidence_figure(data, b_id)
        else:
            raise ValueError(f'Unknown plot mode: {mode}')
        figures[mode].append(fig)
    return figures


def dynamic_alpha(n_matches,
                  milestones=[0, 300, 1000, 2000],
                  alphas=[1.0, 0.8, 0.4, 0.2]):
    if n_matches == 0:
        return 1.0
    ranges = list(zip(alphas, alphas[1:] + [None]))
    loc = bisect.bisect_right(milestones, n_matches) - 1
    _range = ranges[loc]
    if _range[1] is None:
        return _range[0]
    return _range[1] + (milestones[loc + 1] - n_matches) / (
        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
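
A few sample values of `dynamic_alpha`, which linearly interpolates match-line transparency between the milestones:

```python
from src.utils.plotting import dynamic_alpha

print(dynamic_alpha(0))     # 1.0 (no matches -> fully opaque lines)
print(dynamic_alpha(300))   # 0.8 (exactly at a milestone)
print(dynamic_alpha(650))   # 0.6 (halfway between 300 and 1000)
print(dynamic_alpha(5000))  # 0.2 (beyond the last milestone)
```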
def error_colormap(err, thr, alpha=1.0):
    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
    x = 1 - np.clip(err / (thr * 2), 0, 1)

@@ -32,7 +32,6 @@ def build_profiler(name):
        return InferenceProfiler()
    elif name == 'pytorch':
        from pytorch_lightning.profiler import PyTorchProfiler
        return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    elif name is None:
        return PassThroughProfiler()

@@ -0,0 +1 @@
Subproject commit c0626d58c843ee0464b0fa1dd4de4059bfae0ab4

@@ -0,0 +1,120 @@
import math
import argparse
import pprint
from distutils.util import strtobool
from pathlib import Path
from loguru import logger as loguru_logger

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin

from src.config.default import get_cfg_defaults
from src.utils.misc import get_rank_zero_only_logger, setup_gpus
from src.utils.profiler import build_profiler
from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_loftr import PL_LoFTR

loguru_logger = get_rank_zero_only_logger(loguru_logger)


def parse_args():
    # init a custom parser which will be added into the pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'data_cfg_path', type=str, help='data config path')
    parser.add_argument(
        'main_cfg_path', type=str, help='main config path')
    parser.add_argument(
        '--exp_name', type=str, default='default_exp_name')
    parser.add_argument(
        '--batch_size', type=int, default=4, help='batch_size per gpu')
    parser.add_argument(
        '--num_workers', type=int, default=4)
    parser.add_argument(
        '--pin_memory', type=lambda x: bool(strtobool(x)),
        nargs='?', default=True, help='whether loading data to pinned memory or not')
    parser.add_argument(
        '--ckpt_path', type=str, default=None,
        help='pretrained checkpoint path')
    parser.add_argument(
        '--disable_ckpt', action='store_true',
        help='disable checkpoint saving (useful for debugging).')
    parser.add_argument(
        '--profiler_name', type=str, default=None,
        help='options: [inference, pytorch], or leave it unset')
    parser.add_argument(
        '--parallel_load_data', action='store_true',
        help='load datasets with multiple processes.')

    parser = pl.Trainer.add_argparse_args(parser)
    return parser.parse_args()


def main():
    # parse arguments
    args = parse_args()
    rank_zero_only(pprint.pprint)(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility
    # TODO: Use different seeds for each dataloader worker
    # This is needed for data augmentation

    # scale lr and warmup-step automatically
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
    loguru_logger.info(f"LoFTR LightningModule initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info(f"LoFTR DataModule initialized!")

    # TensorBoard Logger
    logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
    ckpt_dir = Path(logger.log_dir) / 'checkpoints'

    # Callbacks
    # TODO: update ModelCheckpoint to monitor multiple metrics
    ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
                                    save_last=True,
                                    dirpath=str(ckpt_dir),
                                    filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor]
    if not args.disable_ckpt:
        callbacks.append(ckpt_callback)

    # Lightning Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=DDPPlugin(find_unused_parameters=False,
                          num_nodes=args.num_nodes,
                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        callbacks=callbacks,
        logger=logger,
        sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
        replace_sampler_ddp=False,  # use a custom sampler
        reload_dataloaders_every_epoch=False,  # avoid repeated samples!
        weights_summary='full',
        profiler=profiler)
    loguru_logger.info(f"Trainer initialized!")
    loguru_logger.info(f"Start training!")
    trainer.fit(model, datamodule=data_module)


if __name__ == '__main__':
    main()