|
|
|
@@ -689,8 +689,15 @@ class YOLACTProtonet(BaseModule):
         prototypes = prototypes.permute(0, 2, 3, 1).contiguous()
 
         num_imgs = x.size(0)
-        # Training state
-        if self.training:
+
+        # The reason for not using self.training is that
+        # val workflow will have a dimension mismatch error.
+        # Note that this writing method is very tricky.
+        # Fix https://github.com/open-mmlab/mmdetection/issues/5978
+        is_train_or_val_workflow = (coeff_pred[0].dim() == 4)
+
+        # Train or val workflow
+        if is_train_or_val_workflow:
             coeff_pred_list = []
             for coeff_pred_per_level in coeff_pred:
                 coeff_pred_per_level = \
@@ -707,7 +714,7 @@ class YOLACTProtonet(BaseModule):
             cur_img_meta = img_meta[idx]
 
             # Testing state
-            if not self.training:
+            if not is_train_or_val_workflow:
                 bboxes_for_cropping = cur_bboxes
             else:
                 cur_sampling_results = sampling_results[idx]
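
A sketch of why the coeff_pred[0].dim() == 4 check separates the two paths (not part of the patch): gating on self.training fails presumably because the val workflow runs the training-style forward while the module is in eval mode, so the testing branch is taken even though coeff_pred still holds the raw multi-level head outputs, which is the dimension mismatch referenced in issue 5978. The shapes below are assumptions for illustration, not taken from this diff.

import torch

num_imgs, num_anchors, num_protos = 2, 3, 32

# Train / val workflow: coeff_pred is assumed to be the raw per-level head
# output, a list of 4-D tensors shaped (N, num_anchors * num_protos, H, W),
# so coeff_pred[0].dim() == 4 and the reshape/concat branch runs.
coeff_pred_train = [
    torch.rand(num_imgs, num_anchors * num_protos, h, w)
    for h, w in [(69, 69), (35, 35), (18, 18)]
]
assert coeff_pred_train[0].dim() == 4   # is_train_or_val_workflow -> True

# Testing: coefficients are assumed to have already been gathered per image
# into 2-D tensors shaped (num_dets, num_protos), so the check is False and
# the per-image testing branch (bboxes_for_cropping = cur_bboxes) runs.
num_dets = 5
coeff_pred_test = [torch.rand(num_dets, num_protos) for _ in range(num_imgs)]
assert coeff_pred_test[0].dim() == 2    # is_train_or_val_workflow -> False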
|
|
|
|