You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

76 lines
2.0 KiB

# code was heavily based on https://github.com/wtjiang98/PSGAN
# MIT License
# Copyright (c) 2020 Wentao Jiang
import os.path as osp
import numpy as np
import cv2
from PIL import Image
import paddle
import paddle.vision.transforms as T
from paddle.utils.download import get_path_from_url
import pickle
from .model import BiSeNet
# URL of the pretrained BiSeNet weights file; FaceParser.__init__ hands it to
# get_path_from_url to obtain a local path for the weights.
BISENET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/bisenet.pdparams'
class FaceParser:
    """Per-pixel face parsing built on a 19-class BiSeNet.

    The network's class indices are remapped through ``self.mapper`` before
    being returned, so callers receive the remapped label ids rather than the
    raw network classes.
    """

    def __init__(self, device="cpu"):
        """Build the network and resolve the local path of its weights.

        Args:
            device: Accepted for interface compatibility only; nothing in
                this class reads it.
        """
        # Network class index -> returned label id. Classes 6, 9, 15, 16 and
        # 18 share id 0 with class 0 (presumably "ignore/background" --
        # confirm against the consumer of these labels).
        self.mapper = {
            0: 0,
            1: 1,
            2: 2,
            3: 3,
            4: 4,
            5: 5,
            6: 0,
            7: 11,
            8: 12,
            9: 0,
            10: 6,
            11: 8,
            12: 7,
            13: 9,
            14: 13,
            15: 0,
            16: 0,
            17: 10,
            18: 0,
        }
        # The weights path is resolved with this module's own directory as
        # the root directory passed to get_path_from_url.
        self.save_pth = get_path_from_url(
            BISENET_WEIGHT_URL, osp.split(osp.realpath(__file__))[0])
        self.net = BiSeNet(n_classes=19)
        self.transforms = T.Compose([
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Weights are read from disk lazily, on the first parse() call, and
        # only once per instance.
        self._weights_loaded = False

    @staticmethod
    def _remap_labels(labels, mapper):
        """Translate an integer label array through ``mapper`` in one step.

        Args:
            labels: Integer NumPy array of network class indices.
            mapper: Dict from class index to output label id.

        Returns:
            A float32 NumPy array shaped like ``labels`` holding the mapped
            ids.

        Raises:
            KeyError: If ``labels`` contains a value that is not a key of
                ``mapper`` (same failure the per-pixel dict lookup produced).
        """
        keys = np.fromiter(mapper.keys(), dtype=np.int64, count=len(mapper))
        present = np.unique(labels)
        unmapped = present[~np.isin(present, keys)]
        if unmapped.size:
            raise KeyError(int(unmapped[0]))

        # Dense lookup table indexed by class id; fancy indexing then maps
        # every pixel at once instead of looping in Python.
        table = np.zeros(int(keys.max()) + 1, dtype=np.float32)
        table[keys] = np.fromiter(
            mapper.values(), dtype=np.float32, count=len(mapper))
        return table[labels]

    def parse(self, image):
        """Run the network on ``image`` and return the remapped label map.

        Args:
            image: Array whose first two dimensions are (512, 512). It is
                divided by 255 and its axis 2 is moved to the front before
                normalization, so it is assumed to be an HWC, 3-channel image
                with values in 0..255 -- confirm against callers.

        Returns:
            A float32 paddle tensor of shape (h, w) holding the remapped
            label id of each pixel, where (h, w) is the spatial size of the
            network output.

        Raises:
            AssertionError: If ``image.shape[:2]`` is not (512, 512).
            KeyError: If the network predicts a class absent from
                ``self.mapper``.
        """
        assert image.shape[:2] == (512, 512), (
            "FaceParser.parse expects a 512x512 image, got shape {}".format(
                tuple(image.shape)))

        image = image / 255.0
        image = image.transpose((2, 0, 1))
        image = self.transforms(image)

        # Load the pretrained weights once instead of re-reading the file on
        # every call.
        if not self._weights_loaded:
            self.net.set_dict(paddle.load(self.save_pth))
            self._weights_loaded = True
        self.net.eval()

        with paddle.no_grad():
            batch = paddle.to_tensor(image).unsqueeze(0)
            # The network returns a sequence; only its first element is used.
            out = self.net(batch)[0]
            # Drop the batch axis, then pick the highest-scoring class along
            # the leading (class) axis.
            parsing = out.squeeze(0).argmax(0)

        remapped = self._remap_labels(parsing.numpy(), self.mapper)
        return paddle.to_tensor(remapped).astype('float32')