Compare commits

...

1 Commits

Author SHA1 Message Date
SkalskiP 0c2931b8cd Test fix for #11 2 years ago
  1. 2
      demo/inference_on_a_image.py
  2. 2
      groundingdino/util/inference.py
  3. 25
      groundingdino/util/utils.py

@ -108,7 +108,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
# build pred
pred_phrases = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, caption)
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:

@ -71,7 +71,7 @@ def predict(
tokenized = tokenizer(caption)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit
in logits
]

@ -7,6 +7,7 @@ from typing import Any, Dict, List
import numpy as np
import torch
from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig
@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
]
def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str):
def get_phrases_from_posmap(
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1:
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
words_list = caption.split()
# build word idx list
words_idx_used_list = []
for idx in non_zero_idx:
word_idx = tokenlized.token_to_word(idx)
if word_idx is not None:
words_idx_used_list.append(word_idx)
words_idx_used_list = set(words_idx_used_list)
# build phrase
words_used_list = []
for idx, word in enumerate(words_list):
if idx in words_idx_used_list:
words_used_list.append(word)
sentence_res = " ".join(words_used_list)
return sentence_res
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")

Loading…
Cancel
Save