Compare commits

..

19 Commits

Author SHA1 Message Date
Shilong Liu 9dac4c605b
fix windows bugs (#30) 2 years ago
SlongLiu 3bb2c86c9a update readme with gd-swinb hf links 2 years ago
SlongLiu d3bc35fdea update gligen 2 years ago
SlongLiu 15ade007a8 add grounding dino - B 2 years ago
Shilong Liu 22292c4b78
add grounding dino with stable diffusion for image editing (#20) 2 years ago
rentainhe 4c8f9206b6 refine readme 2 years ago
rentainhe 97ad9935ac add grounded-segment-anything 2 years ago
George Pearse e93548c805
Create create_coco_dataset.py (#17) 2 years ago
Piotr Skalski e45c11c4c3
⚙️ more compact inference API - single class to load, process and infer (#16) 2 years ago
Piotr Skalski f6b1145481
🎬 Add Roboflow YouTube video to README.md (#13) 2 years ago
SlongLiu 3023d1a26f fix bugs for CPU mode 2 years ago
SlongLiu a02cf79301 update readme 2 years ago
SlongLiu 67a3c1940d update readme 2 years ago
SlongLiu ac00bd4a36 add webUI 2 years ago
Piotr Skalski c974f60d73
Test fix for #11 (#12) 2 years ago
SlongLiu 858efccbad 1. fix warnings. \n 2. support CPU mode. \n 3. update README. 2 years ago
Piotr Skalski 2309f9f468
feature/first_batch_of_model_usability_upgrades (#9) 2 years ago
SlongLiu 12ef464f9e upadate README 2 years ago
Shilong Liu 3e7a8ca2dc
Release code (#2) 2 years ago
  1. 4
      .gitignore
  2. 133
      README.md
  3. 83
      demo/create_coco_dataset.py
  4. 125
      demo/gradio_app.py
  5. 703
      demo/image_editing_with_groundingdino_gligen.ipynb
  6. 524
      demo/image_editing_with_groundingdino_stablediffusion.ipynb
  7. 27
      demo/inference_on_a_image.py
  8. 43
      groundingdino/config/GroundingDINO_SwinB.cfg.py
  9. 4
      groundingdino/models/GroundingDINO/backbone/position_encoding.py
  10. 4
      groundingdino/models/GroundingDINO/ms_deform_attn.py
  11. 2
      groundingdino/models/GroundingDINO/utils.py
  12. 242
      groundingdino/util/inference.py
  13. 3
      groundingdino/util/slconfig.py
  14. 25
      groundingdino/util/utils.py
  15. 11
      requirements.txt

4
.gitignore vendored

@ -1,3 +1,7 @@
# IDE
.idea/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

@ -1,62 +1,104 @@
# Grounding DINO
# :sauropod: Grounding DINO
---
Grounding DINO Methods | [![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/IDEA-Research/GroundingDINO)
[![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8)
Grounding DINO Demos |
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/C4NqaRBz_Kw)
Extensions | [Grounding DINO with Segment Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything); [Grounding DINO with Stable Diffusion](demo/image_editing_with_groundingdino_stablediffusion.ipynb); [Grounding DINO with GLIGEN](demo/image_editing_with_groundingdino_gligen.ipynb)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499). Code will be available soon!
## Highlight
Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
## :bulb: Highlight
- **Open-Set Detection.** Detect **everything** with language!
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
<!-- [![Watch the video](https://i.imgur.com/vKb2F1B.png)](https://youtu.be/wxWDt5UiwY8)
<iframe width="560" height="315" src="https://youtu.be/wxWDt5UiwY8" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe> -->
## :fire: News
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
- **`2023/04/06`**: We build a new demo by marrying GroundingDINO with [Segment-Anything](https://github.com/facebookresearch/segment-anything) named **[Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)** aims to support segmentation in GroundingDINO.
- **`2023/03/28`**: A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)]
- **`2023/03/28`**: Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space!
- **`2023/03/27`**: Support CPU-only mode. Now the model can run on machines without GPUs.
- **`2023/03/25`**: A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)]
- **`2023/03/22`**: Code is available Now!
<details open>
<summary><font size="4">
Description
</font></summary>
<a href="https://arxiv.org/abs/2303.05499">Paper</a> introduction.
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
Marrying <a href="https://github.com/IDEA-Research/GroundingDINO">Grounding DINO</a> and <a href="https://github.com/gligen/GLIGEN">GLIGEN</a>
<img src="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GD_GLIGEN.png" alt="gd_gligen" width="100%">
</details>
## TODO List
<div>
<input type="checkbox" name="uchk" checked>
<label for="uchk">Release inference code and demo.</label>
</div>
<div>
<input type="checkbox" name="uchk" checked>
<label for="uchk">Release checkpoints.</label>
</div>
<div>
<input type="checkbox" name="uchk">
<label for="uchk">Grounding DINO with Stable Diffusion and GLIGEN demos.</label>
</div>
## Usage
### 1. Install
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
## :label: TODO
- [x] Release inference code and demo.
- [x] Release checkpoints.
- [x] Grounding DINO with Stable Diffusion and GLIGEN demos.
- [ ] Release training codes.
## :hammer_and_wrench: Install
If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available.
```bash
pip install -e .
```
### 2. Run an inference demo
See the `demo/inference_on_a_image.py` for more details.
## :arrow_forward: Demo
```bash
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
-c /path/to/config \
-p /path/to/checkpoint \
-i .asset/cats.png \
-o "outputs/0" \
-t "cat ear."
-t "cat ear." \
[--cpu-only] # open it for cpu mode
```
See the `demo/inference_on_a_image.py` for more details.
**Web UI**
We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See the file `demo/gradio_app.py` for more details.
**Notebooks**
- We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
- We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
## :luggage: Checkpoints
### Checkpoints
<!-- insert a table -->
<table>
<thead>
@ -67,6 +109,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
<th>Data</th>
<th>box AP on COCO</th>
<th>Checkpoint</th>
<th>Config</th>
</tr>
</thead>
<tbody>
@ -76,12 +119,23 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
<td>Swin-T</td>
<td>O365,GoldG,Cap4M</td>
<td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">link</a></td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">Github link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
</tr>
<tr>
<th>2</th>
<td>GroundingDINO-B</td>
<td>Swin-B</td>
<td>COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO</td>
<td>56.7 </td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth">Github link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth">HF link</a>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinB.cfg.py">link</a></td>
</tr>
</tbody>
</table>
## Results
## :medal_military: Results
<details open>
<summary><font size="4">
COCO Object Detection Results
@ -100,29 +154,28 @@ ODinW Object Detection Results
<summary><font size="4">
Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
</font></summary>
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_stablediffusion.ipynb">notebook</a> for more details.
<img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
</details>
<details open>
<summary><font size="4">
Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing
Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing.
</font></summary>
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_gligen.ipynb">notebook</a> for more details.
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
</details>
## Model
## :sauropod: Model: Grounding DINO
Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
![arch](.asset/arch.png)
# Links
## :hearts: Acknowledgement
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
@ -130,8 +183,10 @@ We also thank great previous work including DETR, Deformable DETR, SMCA, Conditi
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
# Bibtex
## :black_nib: Citation
If you find our work helpful for your research, please consider citing the following BibTeX entry.
```bibtex
@inproceedings{ShilongLiu2023GroundingDM,
title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},

@ -0,0 +1,83 @@
import typer
from groundingdino.util.inference import load_model, load_image, predict
from tqdm import tqdm
import torchvision
import torch
import fiftyone as fo
def main(
image_directory: str = 'test_grounding_dino',
text_prompt: str = 'bus, car',
box_threshold: float = 0.15,
text_threshold: float = 0.10,
export_dataset: bool = False,
view_dataset: bool = False,
export_annotated_images: bool = True,
weights_path : str = "groundingdino_swint_ogc.pth",
config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
subsample: int = None,
):
model = load_model(config_path, weights_path)
dataset = fo.Dataset.from_images_dir(image_directory)
samples = []
if subsample is not None:
if subsample < len(dataset):
dataset = dataset.take(subsample).clone()
for sample in tqdm(dataset):
image_source, image = load_image(sample.filepath)
boxes, logits, phrases = predict(
model=model,
image=image,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold,
)
detections = []
for box, logit, phrase in zip(boxes, logits, phrases):
rel_box = torchvision.ops.box_convert(box, 'cxcywh', 'xywh')
detections.append(
fo.Detection(
label=phrase,
bounding_box=rel_box,
confidence=logit,
))
# Store detections in a field name of your choice
sample["detections"] = fo.Detections(detections=detections)
sample.save()
# loads the voxel fiftyone UI ready for viewing the dataset.
if view_dataset:
session = fo.launch_app(dataset)
session.wait()
# exports COCO dataset ready for training
if export_dataset:
dataset.export(
'coco_dataset',
dataset_type=fo.types.COCODetectionDataset,
)
# saves bounding boxes plotted on the input images to disk
if export_annotated_images:
dataset.draw_labels(
'images_with_bounding_boxes',
label_fields=['detections']
)
if __name__ == '__main__':
typer.run(main)

@ -0,0 +1,125 @@
import argparse
from functools import partial
import cv2
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import warnings
import torch
# prepare the environment
os.system("python setup.py build develop --user")
os.system("pip install packaging==21.3")
os.system("pip install gradio")
warnings.filterwarnings("ignore")
import gradio as gr
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
# Use this command for evaluate the Grounding DINO model
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swint_ogc.pth"
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
args = SLConfig.fromfile(model_config_path)
model = build_model(args)
args.device = device
cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
checkpoint = torch.load(cache_file, map_location='cpu')
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
print("Model loaded from {} \n => {}".format(cache_file, log))
_ = model.eval()
return model
def image_transform_grounding(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(init_image, None) # 3, h, w
return init_image, image
def image_transform_grounding_for_vis(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
])
image, _ = transform(init_image, None) # 3, h, w
return image
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
init_image = input_image.convert("RGB")
original_size = init_image.size
_, image_tensor = image_transform_grounding(init_image)
image_pil: Image = image_transform_grounding_for_vis(init_image)
# run grounidng
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
return image_with_box
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
args = parser.parse_args()
block = gr.Blocks().queue()
with block:
gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
gr.Markdown("### Open-World Detection with Grounding DINO")
with gr.Row():
with gr.Column():
input_image = gr.Image(source='upload', type="pil")
grounding_caption = gr.Textbox(label="Detection Prompt")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
box_threshold = gr.Slider(
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
with gr.Column():
gallery = gr.outputs.Image(
type="pil",
# label="grounding results"
).style(full_width=True, full_height=True)
# gallery = gr.Gallery(label="Generated images", show_label=False).style(
# grid=[1], height="auto", container=True, full_width=True, full_height=True)
run_button.click(fn=run_grounding, inputs=[
input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
# draw.text((x0, y0), str(label), fill=color)
bbox = draw.textbbox((x0, y0), str(label))
font = ImageFont.load_default()
if hasattr(font, "getbbox"):
bbox = draw.textbbox((x0, y0), str(label), font)
else:
w, h = draw.textsize(str(label), font)
bbox = (x0, y0, w + x0, y0 + h)
# bbox = draw.textbbox((x0, y0), str(label))
draw.rectangle(bbox, fill=color)
draw.text((x0, y0), str(label), fill="white")
@ -63,9 +69,9 @@ def load_image(image_path):
return image_pil, image
def load_model(model_config_path, model_checkpoint_path):
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
args = SLConfig.fromfile(model_config_path)
args.device = "cuda"
args.device = "cuda" if not cpu_only else "cpu"
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
return model
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
model = model.cuda()
image = image.cuda()
device = "cuda" if not cpu_only else "cpu"
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
@ -101,7 +108,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
# build pred
pred_phrases = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, caption)
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
@ -125,6 +132,8 @@ if __name__ == "__main__":
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
args = parser.parse_args()
# cfg
@ -141,14 +150,14 @@ if __name__ == "__main__":
# load image
image_pil, image = load_image(image_path)
# load model
model = load_model(config_file, checkpoint_path)
model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
# visualize raw image
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
# run model
boxes_filt, pred_phrases = get_grounding_output(
model, image, text_prompt, box_threshold, text_threshold
model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
)
# visualize pred

@ -0,0 +1,43 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_B_384_22k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

@ -111,11 +111,11 @@ class PositionEmbeddingSineHW(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack(

@ -25,7 +25,10 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_
try:
from groundingdino import _C
except:
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
# helpers
@ -323,6 +326,7 @@ class MultiScaleDeformableAttention(nn.Module):
reference_points.shape[-1]
)
)
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:

@ -206,7 +206,7 @@ def gen_sineembed_for_position(pos_tensor):
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (dim_t // 2) / 128)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t

@ -0,0 +1,242 @@
from typing import Tuple, List
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torchvision.ops import box_convert
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
# ----------------------------------------------------------------------------------------------------------------------
# OLD API
# ----------------------------------------------------------------------------------------------------------------------
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip()
if result.endswith("."):
return result
return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def predict(
model,
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit
in logits
]
return boxes, logits.max(dim=1)[0], phrases
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [
f"{phrase} {logit:.2f}"
for phrase, logit
in zip(phrases, logits)
]
box_annotator = sv.BoxAnnotator()
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
return annotated_frame
# ----------------------------------------------------------------------------------------------------------------------
# NEW API
# ----------------------------------------------------------------------------------------------------------------------
class Model:
def __init__(
self,
model_config_path: str,
model_checkpoint_path: str,
device: str = "cuda"
):
self.model = load_model(
model_config_path=model_config_path,
model_checkpoint_path=model_checkpoint_path,
device=device
).to(device)
self.device = device
def predict_with_caption(
self,
image: np.ndarray,
caption: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25
) -> Tuple[sv.Detections, List[str]]:
"""
import cv2
image = cv2.imread(IMAGE_PATH)
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
detections, labels = model.predict_with_caption(
image=image,
caption=caption,
box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD
)
import supervision as sv
box_annotator = sv.BoxAnnotator()
annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
"""
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold)
source_h, source_w, _ = image.shape
detections = Model.post_process_result(
source_h=source_h,
source_w=source_w,
boxes=boxes,
logits=logits)
return detections, phrases
def predict_with_classes(
self,
image: np.ndarray,
classes: List[str],
box_threshold: float,
text_threshold: float
) -> sv.Detections:
"""
import cv2
image = cv2.imread(IMAGE_PATH)
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
detections = model.predict_with_classes(
image=image,
classes=CLASSES,
box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD
)
import supervision as sv
box_annotator = sv.BoxAnnotator()
annotated_image = box_annotator.annotate(scene=image, detections=detections)
"""
caption = ", ".join(classes)
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold)
source_h, source_w, _ = image.shape
detections = Model.post_process_result(
source_h=source_h,
source_w=source_w,
boxes=boxes,
logits=logits)
class_id = Model.phrases2classes(phrases=phrases, classes=classes)
detections.class_id = class_id
return detections
@staticmethod
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
image_transformed, _ = transform(image_pillow, None)
return image_transformed
@staticmethod
def post_process_result(
source_h: int,
source_w: int,
boxes: torch.Tensor,
logits: torch.Tensor
) -> sv.Detections:
boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
confidence = logits.numpy()
return sv.Detections(xyxy=xyxy, confidence=confidence)
@staticmethod
def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
class_ids = []
for phrase in phrases:
try:
class_ids.append(classes.index(phrase))
except ValueError:
class_ids.append(None)
return np.array(class_ids)

@ -2,6 +2,7 @@
# Modified from mmcv
# ==========================================================
import ast
import os
import os.path as osp
import shutil
import sys
@ -80,6 +81,8 @@ class SLConfig(object):
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
temp_config_name = osp.basename(temp_config_file.name)
if os.name == 'nt':
temp_config_file.close()
shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)

@ -7,6 +7,7 @@ from typing import Any, Dict, List
import numpy as np
import torch
from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig
@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
]
def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str):
def get_phrases_from_posmap(
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1:
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
words_list = caption.split()
# build word idx list
words_idx_used_list = []
for idx in non_zero_idx:
word_idx = tokenlized.token_to_word(idx)
if word_idx is not None:
words_idx_used_list.append(word_idx)
words_idx_used_list = set(words_idx_used_list)
# build phrase
words_used_list = []
for idx, word in enumerate(words_list):
if idx in words_idx_used_list:
words_used_list.append(word)
sentence_res = " ".join(words_used_list)
return sentence_res
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")

@ -1 +1,10 @@
transformers==4.5.1
torch
torchvision
transformers
addict
yapf
timm
numpy
opencv-python
supervision==0.4.0
pycocotools
Loading…
Cancel
Save