You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
319 lines
12 KiB
319 lines
12 KiB
2 years ago
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
@File : visualizer.py
|
||
|
@Time : 2022/04/05 11:39:33
|
||
|
@Author : Shilong Liu
|
||
|
@Contact : slongliu86@gmail.com
|
||
|
"""
|
||
|
|
||
|
import datetime
|
||
|
import os
|
||
|
|
||
|
import cv2
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from matplotlib import transforms
|
||
|
from matplotlib.collections import PatchCollection
|
||
|
from matplotlib.patches import Polygon
|
||
|
from pycocotools import mask as maskUtils
|
||
|
|
||
|
|
||
|
def renorm(
    img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> torch.FloatTensor:
    """Undo per-channel normalization, returning ``img * std + mean``.

    Args:
        img: tensor(3, H, W) or tensor(B, 3, H, W), presumably produced by
            ``(x - mean) / std`` with the same statistics.
        mean: per-channel means (defaults are the ImageNet statistics).
        std: per-channel standard deviations (defaults are the ImageNet
            statistics).

    Returns:
        A tensor with the same shape as ``img``, on ``img``'s device.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel dim is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # Channels are dim 0 for (3, H, W) and dim 1 for (B, 3, H, W).
    channel_dim = 0 if img.dim() == 3 else 1
    assert img.size(channel_dim) == 3, 'img.size(%d) should be 3 but "%d". (%s)' % (
        channel_dim,
        img.size(channel_dim),
        str(img.size()),
    )
    # Build the statistics on img's device (so CUDA inputs work) with the same
    # default float dtype torch.Tensor(...) would give. Shape (C, 1, 1)
    # broadcasts over both (3, H, W) and (B, 3, H, W).
    mean = torch.as_tensor(mean, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    return img * std + mean
|
||
|
|
||
|
|
||
|
class ColorMap:
    """Turn a single-channel uint8 map into an RGBA image of one fixed colour.

    The RGB channels are ``basergb`` everywhere; the alpha channel is the map.
    """

    def __init__(self, basergb=[255, 255, 0]):
        # np.array copies the given sequence, so the default list is never mutated.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return an (h, w, 4) uint8 RGBA image for an (h, w) uint8 ``attnmap``."""
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape
        # Tile the base colour over every pixel: (rows, cols) + basergb.shape.
        rgb = np.broadcast_to(self.basergb, (rows, cols) + self.basergb.shape)
        alpha = np.expand_dims(attnmap, axis=-1)
        rgba = np.concatenate((rgb, alpha), axis=-1)
        return rgba.astype(np.uint8)
|
||
|
|
||
|
|
||
|
def rainbow_text(x, y, ls, lc, *, vertical=False, **kw):
    """Place the strings ``ls`` next to each other, ``ls[i]`` drawn in colour ``lc[i]``.

    Each string is padded with one space on both sides. After drawing a
    string, its rendered extent is measured and the next string's transform
    is offset by that many display dots.

    Args:
        x, y: start position, in data coordinates of the current axes.
        ls: strings to draw.
        lc: colours, paired with ``ls`` via ``zip`` (extra items are ignored).
        vertical: if True, rotate each string 90 degrees and stack them
            upward; by default they are laid out left to right.
        **kw: forwarded to ``plt.text`` (font size, family, ...). In vertical
            mode ``rotation``, ``va`` and ``ha`` are set here and must not be
            passed again.
    """
    t = plt.gca().transData
    fig = plt.gcf()
    # A text must be drawn before its window extent is known; one renderer
    # serves every string. (The former stray plt.show() here could block or
    # close the figure before anything was drawn.)
    renderer = fig.canvas.get_renderer()

    for s, c in zip(ls, lc):
        label = " " + s + " "
        if vertical:
            text = plt.text(
                x, y, label, color=c, transform=t, rotation=90, va="bottom", ha="center", **kw
            )
        else:
            text = plt.text(x, y, label, color=c, transform=t, **kw)
        text.draw(renderer)
        ex = text.get_window_extent()
        # Shift the next string by this one's extent along the layout direction.
        if vertical:
            t = transforms.offset_copy(text.get_transform(), y=ex.height, units="dots")
        else:
            t = transforms.offset_copy(text.get_transform(), x=ex.width, units="dots")
|
||
|
|
||
|
|
||
|
class COCOVisualizer:
    """Draw images together with COCO-style targets or annotations via matplotlib.

    ``coco`` is only used by :meth:`showAnns`, which reads ``self.coco.imgs``
    and calls ``self.coco.loadCats``. It is presumably a
    ``pycocotools.coco.COCO`` instance (NOTE(review): confirm with callers).
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        # ``tokenlizer`` is accepted for interface compatibility but unused.
        self.coco = coco

    @staticmethod
    def _box_polygon(bbox_x, bbox_y, bbox_w, bbox_h):
        """Return a Polygon tracing the four corners of an (x, y, w, h) box."""
        corners = [
            [bbox_x, bbox_y],
            [bbox_x, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y],
        ]
        return Polygon(np.array(corners).reshape((4, 2)))

    @staticmethod
    def _label_boxes(ax, boxes, texts, colors):
        """Write ``texts[i]`` at the top-left of ``boxes[i]`` on a ``colors[i]`` patch."""
        for (bbox_x, bbox_y, _, _), text, c in zip(boxes, texts, colors):
            ax.text(
                bbox_x,
                bbox_y,
                text,
                color="black",
                bbox={"facecolor": c, "alpha": 0.6, "pad": 1},
            )

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """Render ``img`` with ``tgt`` drawn on top and save the figure as a PNG.

        Args:
            img: tensor(3, H, W) that :func:`renorm` maps back to RGB.
            tgt: dict of annotations or None; make sure tensors are on cpu.
                Box drawing needs 'image_id', 'boxes' and 'size'; see
                :meth:`addtgt` for the optional keys.
            caption: optional string inserted into the file name.
            dpi: figure resolution.
            savedir: output directory, created if missing.

        The path ``{savedir}/[{caption}-]{image_id}-{timestamp}.png`` is printed
        and written. The figure is always closed and the global
        ``font.size`` rcParam restored, even if drawing fails.
        """
        fig = plt.figure(dpi=dpi)
        prev_font_size = plt.rcParams["font.size"]
        plt.rcParams["font.size"] = "5"
        try:
            ax = plt.gca()
            img = renorm(img).permute(1, 2, 0)
            # imshow needs a host array without autograd history; clamping only
            # removes the overshoot imshow would clip (with a warning) anyway.
            ax.imshow(img.detach().cpu().clamp(0, 1).numpy())

            self.addtgt(tgt)

            if tgt is None or "image_id" not in tgt:
                image_id = 0
            else:
                image_id = tgt["image_id"]

            timestamp = str(datetime.datetime.now()).replace(" ", "-")
            if caption is None:
                savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
            else:
                savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)
        finally:
            # Don't leak the figure or the small font into the caller's later plots.
            plt.rcParams["font.size"] = prev_font_size
            plt.close(fig)

    def addtgt(self, tgt):
        """Draw the contents of ``tgt`` onto the current matplotlib axes.

        Keys read from ``tgt``:
            'boxes': per-box 4-vectors, scaled by (W, H, W, H) and treated as
                (cx, cy, w, h), i.e. normalized center format.
            'size': (H, W); required together with 'boxes'.
            'strings_positive' (with 'labels'): per-box phrase lists, drawn as
                "label:phrase phrase ...".
            'box_label': per-box label text.
            'caption': used as the axes title.
            'attn': an (attn_map, basergb) tuple or an iterable of such tuples;
                each map (a numpy array) is min-max scaled and overlaid with
                :class:`ColorMap`.
        ``tgt`` is not modified. When it is None or has no 'boxes', only the
        caption (if any) is drawn. The axes are always turned off.
        """
        ax = plt.gca()
        if tgt is None or "boxes" not in tgt:
            # Guard the lookup: ``"caption" in None`` raises TypeError.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]
        scale = torch.Tensor([W, H, W, H])

        color = []
        polygons = []
        boxes = []
        for box in tgt["boxes"].cpu():
            unnormbbox = box * scale
            # Center (cx, cy) -> top-left corner (x, y).
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            polygons.append(self._box_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
            # Random light colour: every channel in [0.4, 1.0).
            color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        # Faint fill first, then solid outlines in the same colours.
        ax.add_collection(PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1))
        ax.add_collection(
            PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        )

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} != {numbox}, "
            texts = [
                str(int(tgt["labels"][idx])) + ":" + " ".join(strlist)
                for idx, strlist in enumerate(tgt["strings_positive"])
            ]
            self._label_boxes(ax, boxes, texts, color)

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} != {numbox}, "
            self._label_boxes(ax, boxes, [str(bl) for bl in tgt["box_label"]], color)

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            attn_items = tgt["attn"]
            if isinstance(attn_items, tuple):
                # A single (attn_map, basergb) pair. Wrap it locally instead of
                # rewriting the caller's dict.
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max scale to [0, 1); the epsilon avoids dividing by zero.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                ax.imshow(ColorMap(basergb)(attn_map))
        ax.set_axis_off()

    def _show_mask(self, ax, ann):
        """Overlay ``ann``'s dict-form segmentation as a half-transparent mask."""
        segm = ann["segmentation"]
        if isinstance(segm["counts"], list):
            # List counts must be converted with the image size, which is looked
            # up in self.coco.imgs (the original code wrongly used self.imgs).
            img_info = self.coco.imgs[ann["image_id"]]
            rle = maskUtils.frPyObjects([segm], img_info["height"], img_info["width"])
        else:
            rle = [segm]
        m = maskUtils.decode(rle)
        if ann.get("iscrowd", 0) == 1:
            color_mask = np.array([2.0, 166.0, 101.0]) / 255
        else:
            # Any non-crowd value gets a random colour (previously anything other
            # than 0/1 left color_mask unbound).
            color_mask = np.random.random((1, 3)).tolist()[0]
        overlay = np.ones((m.shape[0], m.shape[1], 3))
        overlay[:, :, :] = color_mask
        # Alpha is 0.5 where the decoded mask is set and 0 elsewhere.
        ax.imshow(np.dstack((overlay, m * 0.5)))

    def showAnns(self, anns, draw_bbox=False):
        """
        Display the specified annotations.

        :param anns (array of object): annotation dicts to display. Instance
            annotations may carry 'segmentation' (a list of flat polygons or a
            dict with 'counts'), 'keypoints' and, with ``draw_bbox``, 'bbox'
            as (x, y, w, h). Caption annotations are printed to stdout.
        :param draw_bbox: also outline each annotation's 'bbox'.
        :return: 0 when ``anns`` is empty, otherwise None.
        :raises ValueError: if the first annotation is neither kind.
        Mask image sizes and keypoint skeletons are read via ``self.coco``.
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise ValueError("datasetType not supported")

        if datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
            return None

        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if "segmentation" in ann:
                if isinstance(ann["segmentation"], list):
                    # Each entry is a flat [x0, y0, x1, y1, ...] polygon.
                    for seg in ann["segmentation"]:
                        polygons.append(Polygon(np.array(seg).reshape((-1, 2))))
                        color.append(c)
                else:
                    self._show_mask(ax, ann)
            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # Skeleton indices in the category are 1-based; make them 0-based.
                sks = np.array(self.coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
                kp = np.array(ann["keypoints"])
                x, y, v = kp[0::3], kp[1::3], kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )

            if draw_bbox:
                [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
                polygons.append(self._box_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
                color.append(c)

        # Outlines only; polygon segmentations are not filled.
        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)
        return None
|