You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.4 KiB
101 lines
3.4 KiB
2 years ago
|
import os
|
||
|
import random
|
||
|
from typing import List
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
|
||
|
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j
|
||
|
Input:
|
||
|
- tokenized:
|
||
|
- input_ids: Tensor[1, ntokens]
|
||
|
- attention_mask: Tensor[1, ntokens]
|
||
|
- token_span: list with length num_boxes.
|
||
|
- each item: [start_idx, end_idx]
|
||
|
"""
|
||
|
positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
|
||
|
for j, tok_list in enumerate(token_span):
|
||
|
for (beg, end) in tok_list:
|
||
|
beg_pos = tokenized.char_to_token(beg)
|
||
|
end_pos = tokenized.char_to_token(end - 1)
|
||
|
if beg_pos is None:
|
||
|
try:
|
||
|
beg_pos = tokenized.char_to_token(beg + 1)
|
||
|
if beg_pos is None:
|
||
|
beg_pos = tokenized.char_to_token(beg + 2)
|
||
|
except:
|
||
|
beg_pos = None
|
||
|
if end_pos is None:
|
||
|
try:
|
||
|
end_pos = tokenized.char_to_token(end - 2)
|
||
|
if end_pos is None:
|
||
|
end_pos = tokenized.char_to_token(end - 3)
|
||
|
except:
|
||
|
end_pos = None
|
||
|
if beg_pos is None or end_pos is None:
|
||
|
continue
|
||
|
|
||
|
assert beg_pos is not None and end_pos is not None
|
||
|
if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
|
||
|
positive_map[j, beg_pos] = 1
|
||
|
break
|
||
|
else:
|
||
|
positive_map[j, beg_pos : end_pos + 1].fill_(1)
|
||
|
|
||
|
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
|
||
|
|
||
|
|
||
|
def build_captions_and_token_span(cat_list, force_lowercase):
|
||
|
"""
|
||
|
Return:
|
||
|
captions: str
|
||
|
cat2tokenspan: dict
|
||
|
{
|
||
|
'dog': [[0, 2]],
|
||
|
...
|
||
|
}
|
||
|
"""
|
||
|
|
||
|
cat2tokenspan = {}
|
||
|
captions = ""
|
||
|
for catname in cat_list:
|
||
|
class_name = catname
|
||
|
if force_lowercase:
|
||
|
class_name = class_name.lower()
|
||
|
if "/" in class_name:
|
||
|
class_name_list: List = class_name.strip().split("/")
|
||
|
class_name_list.append(class_name)
|
||
|
class_name: str = random.choice(class_name_list)
|
||
|
|
||
|
tokens_positive_i = []
|
||
|
subnamelist = [i.strip() for i in class_name.strip().split(" ")]
|
||
|
for subname in subnamelist:
|
||
|
if len(subname) == 0:
|
||
|
continue
|
||
|
if len(captions) > 0:
|
||
|
captions = captions + " "
|
||
|
strat_idx = len(captions)
|
||
|
end_idx = strat_idx + len(subname)
|
||
|
tokens_positive_i.append([strat_idx, end_idx])
|
||
|
captions = captions + subname
|
||
|
|
||
|
if len(tokens_positive_i) > 0:
|
||
|
captions = captions + " ."
|
||
|
cat2tokenspan[class_name] = tokens_positive_i
|
||
|
|
||
|
return captions, cat2tokenspan
|
||
|
|
||
|
|
||
|
def build_id2posspan_and_caption(category_dict: dict):
|
||
|
"""Build id2pos_span and caption from category_dict
|
||
|
|
||
|
Args:
|
||
|
category_dict (dict): category_dict
|
||
|
"""
|
||
|
cat_list = [item["name"].lower() for item in category_dict]
|
||
|
id2catname = {item["id"]: item["name"].lower() for item in category_dict}
|
||
|
caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
|
||
|
id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
|
||
|
return id2posspan, caption
|