|
|
@ -7,6 +7,7 @@ from typing import Any, Dict, List |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
|
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
from groundingdino.util.slconfig import SLConfig |
|
|
|
from groundingdino.util.slconfig import SLConfig |
|
|
|
|
|
|
|
|
|
|
@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device): |
|
|
|
] |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenlized, caption: str): |
|
|
|
def get_phrases_from_posmap( |
|
|
|
|
|
|
|
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer |
|
|
|
|
|
|
|
): |
|
|
|
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" |
|
|
|
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" |
|
|
|
if posmap.dim() == 1: |
|
|
|
if posmap.dim() == 1: |
|
|
|
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() |
|
|
|
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() |
|
|
|
words_list = caption.split() |
|
|
|
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx] |
|
|
|
|
|
|
|
return tokenizer.decode(token_ids) |
|
|
|
# build word idx list |
|
|
|
|
|
|
|
words_idx_used_list = [] |
|
|
|
|
|
|
|
for idx in non_zero_idx: |
|
|
|
|
|
|
|
word_idx = tokenlized.token_to_word(idx) |
|
|
|
|
|
|
|
if word_idx is not None: |
|
|
|
|
|
|
|
words_idx_used_list.append(word_idx) |
|
|
|
|
|
|
|
words_idx_used_list = set(words_idx_used_list) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# build phrase |
|
|
|
|
|
|
|
words_used_list = [] |
|
|
|
|
|
|
|
for idx, word in enumerate(words_list): |
|
|
|
|
|
|
|
if idx in words_idx_used_list: |
|
|
|
|
|
|
|
words_used_list.append(word) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sentence_res = " ".join(words_used_list) |
|
|
|
|
|
|
|
return sentence_res |
|
|
|
|
|
|
|
else: |
|
|
|
else: |
|
|
|
raise NotImplementedError("posmap must be 1-dim") |
|
|
|
raise NotImplementedError("posmap must be 1-dim") |
|
|
|