fix bugs for CPU mode

Branch: feature/add_roboflow_video_to_readme
Author: SlongLiu (2 years ago)
parent a02cf79301
commit 3023d1a26f

Changed files:
  README.md                         (2 lines changed)
  demo/gradio_app.py                (20 lines changed)
  groundingdino/util/inference.py   (11 lines changed)
README.md

@@ -22,7 +22,7 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499)
 - **Flexible.** Collaboration with Stable Diffusion for Image Editing.

 ## News
-[2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Spce! \
+[2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
 [2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
 [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. Thanks to @Piotr! \
 [2023/03/22] Code is available Now!

demo/gradio_app.py

@@ -7,16 +7,21 @@ from io import BytesIO
 from PIL import Image
 import numpy as np
 from pathlib import Path
-import gradio as gr
 import warnings
 import torch

+# prepare the environment
 os.system("python setup.py build develop --user")
 os.system("pip install packaging==21.3")
+os.system("pip install gradio")

 warnings.filterwarnings("ignore")

+import gradio as gr
+
 from groundingdino.models import build_model
 from groundingdino.util.slconfig import SLConfig
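
The import reordering above is the real fix in this hunk: the Space installs gradio at runtime via os.system, so importing it at the top of the file fails on a fresh container. A minimal sketch of the install-then-import pattern (using sys.executable is my addition, not what the demo does):

```python
import os
import sys

# Install the dependency first; importing it before this line would raise
# ModuleNotFoundError in an environment where it is not preinstalled.
os.system(f"{sys.executable} -m pip install gradio")

import gradio as gr  # safe now that the package is on disk

print(gr.__version__)
```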
@@ -34,10 +39,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"

-def load_model_hf(model_config_path, repo_id, filename):
+def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     args = SLConfig.fromfile(model_config_path)
-    args.device = 'cuda'
     model = build_model(args)
+    args.device = device

     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
     checkpoint = torch.load(cache_file, map_location='cpu')
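
With device defaulting to 'cpu', the demo can pull the checkpoint from the Hub and build the model without a GPU. A hedged usage sketch; ckpt_repo_id and ckpt_filenmae are the values defined just above, while the config path is assumed from the repository layout:

```python
# Sketch only: the config path is an assumption, not part of this hunk.
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"

model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
```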
@@ -72,7 +77,7 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
     image_pil: Image = image_transform_grounding_for_vis(init_image)

     # run grounding
-    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold)
+    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
     annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
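
The demo now pins predict to device='cpu', which is what a GPU-less Space needs. If the same code has to run on machines that may or may not have a GPU, one option (not part of this commit) is to pick the device once at startup:

```python
import torch

# Fall back to CPU only when CUDA is genuinely unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

boxes, logits, phrases = predict(
    model, image_tensor, grounding_caption,
    box_threshold, text_threshold, device=device,
)
```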
@@ -83,14 +88,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
-    parser.add_argument("--non-share", action="store_true", help="not share the app")
+    parser.add_argument("--share", action="store_true", help="share the app")
     args = parser.parse_args()
-    args.share = (not args.non_share)

     block = gr.Blocks().queue()
     with block:
-        gr.Markdown("# Grounding DINO")
+        gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
         gr.Markdown("### Open-World Detection with Grounding DINO")
         with gr.Row():
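
Replacing --non-share with --share removes the double negative and also flips the default: the old code shared the app unless --non-share was passed, while the new flag makes sharing opt-in. The difference, sketched with explicit argv:

```python
# New behaviour: sharing is opt-in.
args = parser.parse_args([])           # args.share == False
args = parser.parse_args(["--share"])  # args.share == True

# Old behaviour (for comparison): args.share = (not args.non_share)
# defaulted to True, so the app was shared unless --non-share was given.
```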
@@ -117,5 +120,6 @@ if __name__ == "__main__":
     run_button.click(fn=run_grounding, inputs=[
                     input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])

     block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)

groundingdino/util/inference.py

@@ -21,9 +21,9 @@ def preprocess_caption(caption: str) -> str:
     return result + "."

-def load_model(model_config_path: str, model_checkpoint_path: str):
+def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
     args = SLConfig.fromfile(model_config_path)
-    args.device = "cuda"
+    args.device = device
     model = build_model(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
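
The library-level loader gets the same treatment, but keeps "cuda" as the default so existing callers are unaffected. A sketch of loading on CPU; both paths are assumptions for illustration:

```python
from groundingdino.util.inference import load_model

model = load_model(
    "groundingdino/config/GroundingDINO_SwinT_OGC.py",  # assumed config path
    "weights/groundingdino_swint_ogc.pth",              # assumed checkpoint path
    device="cpu",
)
```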
@@ -50,12 +50,13 @@ def predict(
         image: torch.Tensor,
         caption: str,
         box_threshold: float,
-        text_threshold: float
+        text_threshold: float,
+        device: str = "cuda"
 ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
     caption = preprocess_caption(caption=caption)

-    model = model.cuda()
-    image = image.cuda()
+    model = model.to(device)
+    image = image.to(device)

     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
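
Taken together, the inference.py changes let the whole pipeline run on CPU by threading a single device argument through load_model and predict. A sketch of end-to-end CPU inference; the paths, image, and caption are hypothetical, while load_model, load_image, predict, and annotate are the module's existing helpers:

```python
from groundingdino.util.inference import load_model, load_image, predict, annotate

model = load_model(
    "groundingdino/config/GroundingDINO_SwinT_OGC.py",  # assumed config path
    "weights/groundingdino_swint_ogc.pth",              # assumed checkpoint path
    device="cpu",
)
image_source, image = load_image("assets/demo.jpg")     # hypothetical image

boxes, logits, phrases = predict(
    model, image,
    caption="a running dog .",
    box_threshold=0.35,
    text_threshold=0.25,
    device="cpu",
)
annotated = annotate(image_source=image_source, boxes=boxes,
                     logits=logits, phrases=phrases)
```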
