From 3023d1a26f1f67e6af476eb51bd897a5191a9e8b Mon Sep 17 00:00:00 2001
From: SlongLiu
Date: Tue, 28 Mar 2023 16:30:45 +0800
Subject: [PATCH] fix bugs for CPU mode

---
 README.md                       |  2 +-
 demo/gradio_app.py              | 20 ++++++++++++--------
 groundingdino/util/inference.py | 11 ++++++-----
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 3f27a85..163ed78 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.0
 - **Flexible.** Collaboration with Stable Diffusion for Image Editting.
 
 ## News
-[2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Spce! \
+[2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
 [2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
 [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. Thanks to @Piotr! \
 [2023/03/22] Code is available Now!
diff --git a/demo/gradio_app.py b/demo/gradio_app.py
index f1f193b..15e0832 100644
--- a/demo/gradio_app.py
+++ b/demo/gradio_app.py
@@ -7,16 +7,21 @@ from io import BytesIO
 from PIL import Image
 import numpy as np
 from pathlib import Path
-import gradio as gr
+
 
 import warnings
 
 import torch
 
+# prepare the environment
 os.system("python setup.py build develop --user")
 os.system("pip install packaging==21.3")
+os.system("pip install gradio")
+
+
 warnings.filterwarnings("ignore")
 
+import gradio as gr
 
 from groundingdino.models import build_model
 from groundingdino.util.slconfig import SLConfig
@@ -34,10 +39,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
 
 
-def load_model_hf(model_config_path, repo_id, filename):
+def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     args = SLConfig.fromfile(model_config_path)
-    args.device = 'cuda'
+    args.device = device
     model = build_model(args)
 
     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
     checkpoint = torch.load(cache_file, map_location='cpu')
@@ -72,7 +77,7 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
     image_pil: Image = image_transform_grounding_for_vis(init_image)
 
     # run grounidng
-    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold)
+    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
 
     annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
@@ -83,14 +88,12 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
-    parser.add_argument("--non-share", action="store_true", help="not share the app")
+    parser.add_argument("--share", action="store_true", help="share the app")
     args = parser.parse_args()
 
-    args.share = (not args.non_share)
-
     block = gr.Blocks().queue()
     with block:
-        gr.Markdown("# Grounding DINO")
+        gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
         gr.Markdown("### Open-World Detection with Grounding DINO")
 
         with gr.Row():
@@ -117,5 +120,6 @@ if __name__ == "__main__":
         run_button.click(fn=run_grounding, inputs=[
                         input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
 
 
+
     block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py
index 73d3b9c..8168b96 100644
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -21,9 +21,9 @@ def preprocess_caption(caption: str) -> str:
     return result + "."
 
 
-def load_model(model_config_path: str, model_checkpoint_path: str):
+def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
     args = SLConfig.fromfile(model_config_path)
-    args.device = "cuda"
+    args.device = device
     model = build_model(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -50,12 +50,13 @@ def predict(
         image: torch.Tensor,
         caption: str,
         box_threshold: float,
-        text_threshold: float
+        text_threshold: float,
+        device: str = "cuda"
 ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
     caption = preprocess_caption(caption=caption)
 
-    model = model.cuda()
-    image = image.cuda()
+    model = model.to(device)
+    image = image.to(device)
 
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
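
Usage sketch (not part of the patch): with these changes applied, a CPU-only machine can load the model and run inference by passing device="cpu" through the new keyword arguments. The checkpoint and image paths below are placeholders, and load_image is assumed to be the helper already provided by groundingdino.util.inference:

    from groundingdino.util.inference import load_model, load_image, predict

    # device is the new keyword introduced by this patch; it defaults to "cuda".
    model = load_model(
        "groundingdino/config/GroundingDINO_SwinT_OGC.py",  # config shipped with the repo
        "weights/groundingdino_swint_ogc.pth",              # placeholder path to the downloaded checkpoint
        device="cpu",
    )

    image_source, image = load_image("demo.jpg")            # placeholder image path

    # predict() now moves both the model and the image tensor to the requested device
    # instead of calling .cuda() unconditionally.
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption="a cat . a dog .",
        box_threshold=0.35,
        text_threshold=0.25,
        device="cpu",
    )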