@@ -1,18 +1,9 @@
-# coding=utf-8
-# Copyright 2022 The IDEA Authors. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ------------------------------------------------------------------------------------------------
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
 # Deformable DETR
 # Copyright (c) 2020 SenseTime. All Rights Reserved.
 # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
@@ -26,12 +17,14 @@
 import math
 import warnings
 from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.init import constant_, xavier_uniform_
 
 from groundingdino import _C
@@ -290,7 +283,6 @@ class MultiScaleDeformableAttention(nn.Module):
         assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
 
         value = self.value_proj(value)
         if key_padding_mask is not None:
             value = value.masked_fill(key_padding_mask[..., None], float(0))
@@ -339,7 +331,6 @@ class MultiScaleDeformableAttention(nn.Module):
                 sampling_locations = sampling_locations.float()
                 attention_weights = attention_weights.float()
 
             output = MultiScaleDeformableAttnFunction.apply(
                 value,
                 spatial_shapes,
@@ -416,4 +407,3 @@ def create_dummy_func(func, dependency, message=""):
         raise ImportError(err)
 
     return _dummy