diff --git a/.gitignore b/.gitignore
index 9798d5a..fdc0e71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+# IDE
+.idea/
+.vscode/
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/README.md b/README.md
index cfed998..349dd68 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,21 @@
+# Grounding DINO
+
+---
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
-# Grounding DINO
+
Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
## Highlight
+
- **Open-Set Detection.** Detect **everything** with language!
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
@@ -23,21 +30,22 @@ Description
-## TODO List
+## TODO
- [x] Release inference code and demo.
- [x] Release checkpoints.
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
+## Install
-## Usage
-### 1. Install
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
+
```bash
pip install -e .
```
-### 2. Run an inference demo
+## Demo
+
See the `demo/inference_on_a_image.py` for more details.
```bash
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
@@ -48,7 +56,8 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
-t "cat ear."
```
-### Checkpoints
+## Checkpoints
+
@@ -74,6 +83,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
## Results
+
COCO Object Detection Results
@@ -102,11 +112,6 @@ Marrying Grounding DINO with GLIGEN
-
-
-
-
-
## Model
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
@@ -114,7 +119,8 @@ Includes: a text backbone, an image backbone, a feature enhancer, a language-gui
![arch](.asset/arch.png)
-# Links
+## Acknowledgement
+
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
@@ -122,8 +128,10 @@ We also thank great previous work including DETR, Deformable DETR, SMCA, Conditi
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
-# Bibtex
+## Citation
+
If you find our work helpful for your research, please consider citing the following BibTeX entry.
+
```bibtex
@inproceedings{ShilongLiu2023GroundingDM,
title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
new file mode 100644
index 0000000..72c84ed
--- /dev/null
+++ b/groundingdino/util/inference.py
@@ -0,0 +1,97 @@
+from typing import Tuple, List
+
+import cv2
+import numpy as np
+import supervision as sv
+import torch
+from PIL import Image
+from torchvision.ops import box_convert
+
+import groundingdino.datasets.transforms as T
+from groundingdino.models import build_model
+from groundingdino.util.misc import clean_state_dict
+from groundingdino.util.slconfig import SLConfig
+from groundingdino.util.utils import get_phrases_from_posmap
+
+
+def preprocess_caption(caption: str) -> str:
+ result = caption.lower().strip()
+ if result.endswith("."):
+ return result
+ return result + "."
+
+
+def load_model(model_config_path: str, model_checkpoint_path: str):
+ args = SLConfig.fromfile(model_config_path)
+ args.device = "cuda"
+ model = build_model(args)
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+ model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+ model.eval()
+ return model
+
+
+def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image_source = Image.open(image_path).convert("RGB")
+ image = np.asarray(image_source)
+ image_transformed, _ = transform(image_source, None)
+ return image, image_transformed
+
+
+def predict(
+ model,
+ image: torch.Tensor,
+ caption: str,
+ box_threshold: float,
+ text_threshold: float
+) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
+ caption = preprocess_caption(caption=caption)
+
+ model = model.cuda()
+ image = image.cuda()
+
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+
+ prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
+ prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
+
+ mask = prediction_logits.max(dim=1)[0] > box_threshold
+ logits = prediction_logits[mask] # logits.shape = (n, 256)
+ boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
+
+ tokenizer = model.tokenizer
+ tokenized = tokenizer(caption)
+
+ phrases = [
+ get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
+ for logit
+ in logits
+ ]
+
+ return boxes, logits.max(dim=1)[0], phrases
+
+
+def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
+ h, w, _ = image_source.shape
+ boxes = boxes * torch.Tensor([w, h, w, h])
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+ detections = sv.Detections(xyxy=xyxy)
+
+ labels = [
+ f"{phrase} {logit:.2f}"
+ for phrase, logit
+ in zip(phrases, logits)
+ ]
+
+ box_annotator = sv.BoxAnnotator()
+ annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+ return annotated_frame
diff --git a/requirements.txt b/requirements.txt
index 5924562..63773b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,9 @@
-transformers==4.5.1
\ No newline at end of file
+torch
+torchvision
+transformers
+addict
+yapf
+timm
+numpy
+opencv-python
+supervision==0.3.2
\ No newline at end of file