1. fix warnings.
2. support CPU mode.
3. update README.

fix/11_simplify_getting_output_labels
SlongLiu 2 years ago
parent 2309f9f468
commit 858efccbad
6 files changed:
1. README.md (27 lines)
2. demo/inference_on_a_image.py (25 lines)
3. groundingdino/models/GroundingDINO/backbone/position_encoding.py (4 lines)
4. groundingdino/models/GroundingDINO/ms_deform_attn.py (6 lines)
5. groundingdino/models/GroundingDINO/utils.py (2 lines)
6. requirements.txt (3 lines)

@@ -1,9 +1,10 @@
# Grounding DINO
[📃Paper](https://arxiv.org/abs/2303.05499) |
[📽Video](https://www.youtube.com/watch?v=wxWDt5UiwY8) |
[📯Demo on Colab](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) |
[🤗Demo on HF (Coming soon)]()
---
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
@@ -20,8 +21,10 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.0
- **High Performance.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
- **Flexible.** Collaboration with Stable Diffusion for Image Editing.
<!-- [![Watch the video](https://i.imgur.com/vKb2F1B.png)](https://youtu.be/wxWDt5UiwY8)
<iframe width="560" height="315" src="https://youtu.be/wxWDt5UiwY8" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe> -->
## News
+[2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
[2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available on Colab. Thanks to @Piotr! \
[2023/03/22] Code is now available!
<details open>
<summary><font size="4">
@@ -30,15 +33,18 @@ Description
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
</details>
## TODO
- [x] Release inference code and demo.
- [x] Release checkpoints.
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
- [ ] Release training codes.
## Install
-If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
+If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. If CUDA is not available, the ops will be compiled in CPU-only mode.
```bash
pip install -e .
```
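Before running `pip install -e .`, it can help to confirm what the build will see. A minimal sketch (not part of the repo) using PyTorch's build utilities:

```python
# Reports whether the build will find a CUDA toolchain; CUDA_HOME is None
# when only the CPU-only build of the custom ops is possible.
import torch
from torch.utils.cpp_extension import CUDA_HOME

print("torch sees CUDA:", torch.cuda.is_available())
print("CUDA_HOME:", CUDA_HOME)
```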
@@ -46,15 +52,16 @@ pip install -e .
## Demo
```bash
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
-c /path/to/config \
-p /path/to/checkpoint \
-i .asset/cats.png \
-o "outputs/0" \
--t "cat ear."
+-t "cat ear." \
+[--cpu-only] # add this flag to run in CPU-only mode
```
See the `demo/inference_on_a_image.py` for more details.
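The same flow can be driven from Python. A hedged sketch using the helpers this commit changes in `demo/inference_on_a_image.py` (paths are placeholders; assumes the script is importable from the repo root):

```python
from demo.inference_on_a_image import load_image, load_model, get_grounding_output

image_pil, image = load_image(".asset/cats.png")
model = load_model("/path/to/config", "/path/to/checkpoint", cpu_only=True)

# cpu_only=True keeps both the model and the image tensor on the CPU
boxes_filt, pred_phrases = get_grounding_output(
    model, image, "cat ear.", box_threshold=0.3, text_threshold=0.25, cpu_only=True
)
print(pred_phrases)
```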
## Checkpoints
@@ -68,6 +75,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
<th>Data</th>
<th>box AP on COCO</th>
<th>Checkpoint</th>
+<th>Config</th>
</tr>
</thead>
<tbody>
@@ -78,6 +86,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
<td>O365,GoldG,Cap4M</td>
<td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">link</a></td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
</tr>
</tbody>
</table>

@@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)
-       bbox = draw.textbbox((x0, y0), str(label))
+       font = ImageFont.load_default()
+       if hasattr(font, "getbbox"):
+           bbox = draw.textbbox((x0, y0), str(label), font)
+       else:
+           w, h = draw.textsize(str(label), font)
+           bbox = (x0, y0, w + x0, y0 + h)
+       # bbox = draw.textbbox((x0, y0), str(label))
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), str(label), fill="white")
@@ -63,9 +69,9 @@ def load_image(image_path):
    return image_pil, image

-def load_model(model_config_path, model_checkpoint_path):
+def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    args = SLConfig.fromfile(model_config_path)
-   args.device = "cuda"
+   args.device = "cuda" if not cpu_only else "cpu"
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
    return model

-def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."
-   model = model.cuda()
-   image = image.cuda()
+   device = "cuda" if not cpu_only else "cpu"
+   model = model.to(device)
+   image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
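The hunk ends before the filtering step. As a hedged sketch (shapes follow the comment above; the exact repo code may differ), queries are typically kept when their best token score clears `box_threshold`:

```python
import torch

logits = torch.rand(900, 256)  # stand-in for pred_logits.sigmoid()[0]
boxes = torch.rand(900, 4)     # stand-in for pred_boxes[0]
box_threshold = 0.3

# keep queries whose strongest token activation exceeds the box threshold
mask = logits.max(dim=1)[0] > box_threshold
logits_filt, boxes_filt = logits[mask], boxes[mask]
print(boxes_filt.shape)
```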
@@ -125,6 +132,8 @@ if __name__ == "__main__":
    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+   parser.add_argument("--cpu-only", action="store_true", help="run on CPU only (default: False)")
    args = parser.parse_args()

    # cfg
@@ -141,14 +150,14 @@ if __name__ == "__main__":
    # load image
    image_pil, image = load_image(image_path)

    # load model
-   model = load_model(config_file, checkpoint_path)
+   model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # run model
    boxes_filt, pred_phrases = get_grounding_output(
-       model, image, text_prompt, box_threshold, text_threshold
+       model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
    )

    # visualize pred

@@ -111,11 +111,11 @@ class PositionEmbeddingSineHW(nn.Module):
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-       dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
+       dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx

        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-       dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
+       dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty

        pos_x = torch.stack(
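These two edits are the "fix warnings" part of the commit: on tensors, the `//` operator routed through `torch.floor_divide`, which emitted a deprecation `UserWarning` on some PyTorch 1.x versions; `torch.div(..., rounding_mode='floor')` is the warning-free equivalent. A quick check:

```python
import torch

dim_t = torch.arange(8, dtype=torch.float32)

# old form (warned on some PyTorch 1.x releases): dim_t // 2
new = torch.div(dim_t, 2, rounding_mode="floor")
print(new)  # tensor([0., 0., 1., 1., 2., 2., 3., 3.])
```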

@@ -25,7 +25,10 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_

-from groundingdino import _C
+try:
+    from groundingdino import _C
+except Exception:
+    warnings.warn("Failed to load custom C++ ops; running in CPU-only mode.")

# helpers
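Wrapping the extension import lets the module load on machines where the compiled ops are absent, instead of failing at import time. An illustrative restatement with a capability flag (the flag name is ours, not the repo's):

```python
import warnings

try:
    from groundingdino import _C  # compiled C++/CUDA extension
    HAS_CUDA_OPS = True
except Exception:
    _C, HAS_CUDA_OPS = None, False
    warnings.warn("Custom C++ ops unavailable; falling back to pure PyTorch.")
```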
@@ -323,6 +326,7 @@ class MultiScaleDeformableAttention(nn.Module):
                    reference_points.shape[-1]
                )
            )
+       if torch.cuda.is_available() and value.is_cuda:
            halffloat = False
            if value.dtype == torch.float16:
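The new gate makes the forward pass choose a code path at runtime. A minimal sketch of the same check (the helper name is hypothetical):

```python
import torch

def uses_fused_kernel(value: torch.Tensor) -> bool:
    # the custom kernel needs both a CUDA runtime and CUDA-resident tensors;
    # anything else should take the pure-PyTorch fallback
    return torch.cuda.is_available() and value.is_cuda

print(uses_fused_kernel(torch.zeros(1)))  # False for a CPU tensor
```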

@@ -206,7 +206,7 @@ def gen_sineembed_for_position(pos_tensor):
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
-   dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+   dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
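For reference, `gen_sineembed_for_position` builds standard sinusoidal embeddings: each coordinate is scaled by 2π, divided by a geometric frequency ladder, and interleaved as sin/cos. A compact, hedged sketch for one coordinate (function name and defaults are ours):

```python
import math
import torch

def sine_embed_1d(coord: torch.Tensor, num_feats: int = 128) -> torch.Tensor:
    # geometric frequency ladder, in the warning-free form used above
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=coord.device)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_feats)
    pos = coord[..., None] * (2 * math.pi) / dim_t
    # interleave: sin on even channels, cos on odd channels
    return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)

print(sine_embed_1d(torch.tensor([0.25, 0.5])).shape)  # torch.Size([2, 128])
```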

@@ -6,4 +6,5 @@ yapf
timm
numpy
opencv-python
supervision==0.3.2
+pycocotools