|
|
@ -270,7 +270,7 @@ class RTDETRDecoder(nn.Module): |
|
|
|
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) |
|
|
|
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) |
|
|
|
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 |
|
|
|
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 |
|
|
|
anchors = torch.log(anchors / (1 - anchors)) |
|
|
|
anchors = torch.log(anchors / (1 - anchors)) |
|
|
|
anchors = torch.where(valid_mask, anchors, torch.inf) |
|
|
|
anchors = anchors.masked_fill(~valid_mask, float('inf')) |
|
|
|
return anchors, valid_mask |
|
|
|
return anchors, valid_mask |
|
|
|
|
|
|
|
|
|
|
|
def _get_encoder_input(self, x): |
|
|
|
def _get_encoder_input(self, x): |
|
|
@ -294,7 +294,7 @@ class RTDETRDecoder(nn.Module): |
|
|
|
bs = len(feats) |
|
|
|
bs = len(feats) |
|
|
|
# prepare input for decoder |
|
|
|
# prepare input for decoder |
|
|
|
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) |
|
|
|
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) |
|
|
|
features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256 |
|
|
|
features = self.enc_output(valid_mask * feats) # bs, h*w, 256 |
|
|
|
|
|
|
|
|
|
|
|
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) |
|
|
|
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) |
|
|
|
# dynamic anchors + static content |
|
|
|
# dynamic anchors + static content |
|
|
|