You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.4 KiB
100 lines
3.4 KiB
import os |
|
import random |
|
from typing import List |
|
|
|
import torch |
|
|
|
|
|
def _char_to_token_with_fallback(tokenized, char_idx, fallback_idxs):
    """Map ``char_idx`` to a token index, probing ``fallback_idxs`` in order if it is unmapped.

    The primary lookup is not guarded: an exception raised by
    ``tokenized.char_to_token(char_idx)`` propagates to the caller.
    Fallback probing is best-effort. Negative indices are skipped, and any
    exception raised while probing ends the search and yields ``None``.

    Returns:
        The first non-None token index found, or ``None``.
    """
    pos = tokenized.char_to_token(char_idx)
    if pos is not None:
        return pos
    try:
        for alt_idx in fallback_idxs:
            if alt_idx < 0:
                # A negative index would wrap around (or raise) instead of
                # pointing at a character near the span boundary.
                continue
            pos = tokenized.char_to_token(alt_idx)
            if pos is not None:
                return pos
    except Exception:
        # Best-effort probing: treat a failing lookup as "unmapped".
        return None
    return None


def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
    """Construct a map such that ``positive_map[i, j]`` is non-zero iff box ``i`` is associated to token ``j``.

    Input:
        - tokenized: tokenizer output exposing ``char_to_token(char_idx)``
          (presumably a batch of one: input_ids / attention_mask of shape
          ``[1, ntokens]`` — confirm against callers).
        - token_span: list with length num_boxes. Each item is a list of
          ``[start_idx, end_idx]`` character spans (end exclusive).
        - max_text_len: number of token columns in the returned map.

    Returns:
        A float tensor of shape ``(num_boxes, max_text_len)``. Each row is
        divided by its sum (plus 1e-6), so non-empty rows sum to ~1 and
        rows with no mapped tokens stay all-zero.

    A span whose start or end cannot be mapped to a token (even after
    probing up to two characters inward) is skipped. If the environment
    variable ``SHILONG_DEBUG_ONLY_ONE_POS`` equals ``"TRUE"``, only the first
    token of the first mappable span of each box is marked.
    """
    positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
    # Read the debug switch once rather than on every span.
    only_first_token = os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE"

    for j, tok_list in enumerate(token_span):
        for beg, end in tok_list:
            # Boundary characters may not map to any token (presumably
            # whitespace), so probe up to two characters inward from each edge.
            beg_pos = _char_to_token_with_fallback(tokenized, beg, (beg + 1, beg + 2))
            end_pos = _char_to_token_with_fallback(tokenized, end - 1, (end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue

            if only_first_token:
                positive_map[j, beg_pos] = 1
                break  # stop after the first mappable span of this box
            # Slicing clips at max_text_len, so tokens beyond it are dropped.
            positive_map[j, beg_pos : end_pos + 1].fill_(1)

    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
|
|
|
|
|
def build_captions_and_token_span(cat_list, force_lowercase):
    """Concatenate category names into a caption and record each name's character spans.

    Each category contributes its space-separated words followed by ``" ."``,
    and consecutive pieces are joined by a single space. For example,
    ``["dog", "cat"]`` gives ``"dog . cat ."``.

    A name containing ``/`` (e.g. ``"sofa/couch"``) is written into the
    caption as one randomly chosen alternative: one of its ``/``-separated
    parts, or the full name. This uses the global ``random`` state.

    Args:
        cat_list: category names.
        force_lowercase: lowercase each name before use.

    Return:
        captions: str
        cat2tokenspan: dict mapping each category name (lowercased if
            ``force_lowercase``) to the ``[start_idx, end_idx]`` character
            spans (end exclusive) of its words in ``captions``, e.g.
            {
                'dog': [[0, 3]],
                ...
            }
            Entries are keyed by the given name, not by the random
            alternative, so callers can look them up. Names that yield no
            words (empty or whitespace-only) are omitted and add nothing to
            the caption. A repeated name keeps only its last spans.
    """
    cat2tokenspan = {}
    parts = []  # caption pieces, joined once at the end
    cursor = 0  # current caption length, i.e. len("".join(parts))

    for catname in cat_list:
        key = catname.lower() if force_lowercase else catname
        class_name = key
        if "/" in class_name:
            alternatives = class_name.strip().split("/")
            alternatives.append(class_name)
            class_name = random.choice(alternatives)

        spans = []
        for subname in (word.strip() for word in class_name.strip().split(" ")):
            if not subname:
                continue
            if cursor:
                parts.append(" ")
                cursor += 1
            spans.append([cursor, cursor + len(subname)])
            parts.append(subname)
            cursor += len(subname)

        if spans:
            parts.append(" .")
            cursor += 2
            cat2tokenspan[key] = spans

    return "".join(parts), cat2tokenspan
|
|
|
|
|
def build_id2posspan_and_caption(category_dict: List[dict]):
    """Build an id -> character-span map and the matching caption from category records.

    Args:
        category_dict: category records, each a mapping with ``"id"`` and
            ``"name"`` keys. Presumably these are COCO-style ``categories``;
            despite the parameter name, this is not a single dict. It is
            traversed only once, so any iterable works.

    Returns:
        (id2posspan, caption): ``id2posspan`` maps each category id to the
        ``[[start_idx, end_idx], ...]`` spans of its lowercased name in
        ``caption``, both as produced by ``build_captions_and_token_span``.
        If an id repeats, the last record wins.

    Raises:
        KeyError: if a record lacks ``"id"`` or ``"name"``, or if
            ``build_captions_and_token_span`` returns no span entry for a
            category's lowercased name (e.g. a whitespace-only name).
    """
    cat_list = []
    id2catname = {}
    for item in category_dict:
        name = item["name"].lower()
        cat_list.append(name)
        id2catname[item["id"]] = name

    caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
    id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
    return id2posspan, caption
|
|
|