diff --git a/README.md b/README.md
index 9683193..ad9ce97 100644
--- a/README.md
+++ b/README.md
@@ -119,6 +119,7 @@ PaddleRS具有以下五大特色:
ReduceDim
SelectBand
RandomSwap
+ AppendIndex
...
@@ -138,6 +139,17 @@ PaddleRS具有以下五大特色:
辐射校正
...
+ 数据后处理
+
+ - 建筑边界规则化
+ - 道路断线连接
+ - ...
+
+ 数据可视化
+
遥感场景分类
@@ -177,8 +189,10 @@ PaddleRS目录树中关键部分如下:
│ ├── datasets # 数据集接口实现
│ ├── models # 视觉模型实现
│ ├── tasks # 训练器实现
-│ └── transforms # 数据预处理/数据增强实现
+│ ├── transforms # 数据预处理/数据增强实现
+│ └── utils # 数据下载/可视化/后处理等
├── tools # 遥感影像处理工具集
+├── examples # 相关实践案例
└── tutorials
└── train # 模型训练教程
```
diff --git a/paddlers/utils/postprocs/__init__.py b/paddlers/utils/postprocs/__init__.py
index 998a80f..36435b0 100644
--- a/paddlers/utils/postprocs/__init__.py
+++ b/paddlers/utils/postprocs/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.
from .regularization import building_regularization
+from .connection import cut_road_connection
diff --git a/paddlers/utils/postprocs/connection.py b/paddlers/utils/postprocs/connection.py
new file mode 100644
index 0000000..5dd4b21
--- /dev/null
+++ b/paddlers/utils/postprocs/connection.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import warnings
+
+import cv2
+import numpy as np
+from skimage import morphology
+from scipy import ndimage, optimize
+
+with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ from sklearn import metrics
+ from sklearn.cluster import KMeans
+
+from .utils import prepro_mask, calc_distance
+
+
+def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
+ """
+ Connecting cut road lines.
+
+ The original article refers to
+ Wang B, Chen Z, et al. "Road extraction of high-resolution satellite remote sensing images in U-Net network with consideration of connectivity."
+ (http://hgs.publish.founderss.cn/thesisDetails?columnId=4759509).
+
+ This algorithm has no public code.
+ The implementation procedure refers to original article,
+ and it is not fully consistent with the article:
+ 1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
+ 2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
+
+ Args:
+ mask (np.ndarray): Mask of road.
+ line_width (int, optional): Width of the line used for patching.
+ . Default is 6.
+
+ Returns:
+ np.ndarray: Mask of road after connecting cut road lines.
+ """
+ mask = prepro_mask(mask)
+ skeleton = morphology.skeletonize(mask).astype("uint8")
+ break_points = _find_breakpoint(skeleton)
+ labels = _k_means(break_points)
+ match_points = _get_match_points(break_points, labels)
+ res = _draw_curve(mask, skeleton, match_points, line_width)
+ return res
+
+
+def _find_breakpoint(skeleton):
+ kernel_3x3 = np.ones((3, 3), dtype="uint8")
+ k3 = ndimage.convolve(skeleton, kernel_3x3)
+ point_map = np.zeros_like(k3)
+ point_map[k3 == 2] = 1
+ point_map *= skeleton * 255
+ # boundary filtering
+ filter_w = 5
+ cropped = point_map[filter_w:-filter_w, filter_w:-filter_w]
+ padded = np.pad(cropped, (filter_w, filter_w), mode="constant")
+ breakpoints = np.column_stack(np.where(padded == 255))
+ return breakpoints
+
+
+def _k_means(data):
+ silhouette_int = -1 # threshold
+ labels = None
+ for k in range(2, data.shape[0]):
+ kms = KMeans(k, random_state=66)
+ labels_tmp = kms.fit_predict(data) # train
+ silhouette = metrics.silhouette_score(data, labels_tmp)
+ if silhouette > silhouette_int: # better
+ silhouette_int = silhouette
+ labels = labels_tmp
+ return labels
+
+
+def _get_match_points(break_points, labels):
+ match_points = {}
+ for point, lab in zip(break_points, labels):
+ if lab in match_points.keys():
+ match_points[lab].append(point)
+ else:
+ match_points[lab] = [point]
+ return match_points
+
+
+def _draw_curve(mask, skeleton, match_points, line_width):
+ result = mask * 255
+ for v in match_points.values():
+ p_num = len(v)
+ if p_num == 2:
+ points_list = _curve_backtracking(v, skeleton)
+ if points_list is not None:
+ result = _broken_wire_repair(result, points_list, line_width)
+ elif p_num == 3:
+ sim_v = list(itertools.combinations(v, 2))
+ min_di = 1e6
+ for vij in sim_v:
+ di = calc_distance(vij[0][np.newaxis], vij[1][np.newaxis])
+ if di < min_di:
+ vv = vij
+ min_di = di
+ points_list = _curve_backtracking(vv, skeleton)
+ if points_list is not None:
+ result = _broken_wire_repair(result, points_list, line_width)
+ return result
+
+
+def _curve_backtracking(add_lines, skeleton):
+ points_list = []
+ p1 = add_lines[0]
+ p2 = add_lines[1]
+ bpk1, ps1 = _calc_angle_by_road(p1, skeleton)
+ bpk2, ps2 = _calc_angle_by_road(p2, skeleton)
+ if _check_angle(bpk1, bpk2):
+ points_list.append((
+ np.array(
+ ps1, dtype="int64"),
+ add_lines[0],
+ add_lines[1],
+ np.array(
+ ps2, dtype="int64"), ))
+ return points_list
+ else:
+ return None
+
+
+def _broken_wire_repair(mask, points_list, line_width):
+ d_mask = mask.copy()
+ for points in points_list:
+ nx, ny = _line_cubic(points)
+ for i in range(len(nx) - 1):
+ loc_p1 = (int(ny[i]), int(nx[i]))
+ loc_p2 = (int(ny[i + 1]), int(nx[i + 1]))
+ cv2.line(d_mask, loc_p1, loc_p2, [255], line_width)
+ return d_mask
+
+
+def _calc_angle_by_road(p, skeleton, num_circle=10):
+ def _not_in(p1, ps):
+ for p in ps:
+ if p1[0] == p[0] and p1[1] == p[1]:
+ return False
+ return True
+
+ h, w = skeleton.shape
+ tmp_p = p.tolist() if isinstance(p, np.ndarray) else p
+ tmp_p = [int(tmp_p[0]), int(tmp_p[1])]
+ ps = []
+ ps.append(tmp_p)
+ for _ in range(num_circle):
+ t_x = 0 if tmp_p[0] - 1 < 0 else tmp_p[0] - 1
+ t_y = 0 if tmp_p[1] - 1 < 0 else tmp_p[1] - 1
+ b_x = w if tmp_p[0] + 1 >= w else tmp_p[0] + 1
+ b_y = h if tmp_p[1] + 1 >= h else tmp_p[1] + 1
+ if int(np.sum(skeleton[t_x:b_x + 1, t_y:b_y + 1])) <= 3:
+ for i in range(t_x, b_x + 1):
+ for j in range(t_y, b_y + 1):
+ if skeleton[i, j] == 1:
+ pp = [int(i), int(j)]
+ if _not_in(pp, ps):
+ tmp_p = pp
+ ps.append(tmp_p)
+ # calc angle
+ theta = _angle_regression(ps)
+ dx, dy = np.cos(theta), np.sin(theta)
+ # calc direction
+ start = ps[-1]
+ end = ps[0]
+ if end[1] < start[1] or (end[1] == start[1] and end[0] < start[0]):
+ dx *= -1
+ dy *= -1
+ return [dx, dy], start
+
+
+def _angle_regression(datas):
+ def _linear(x: float, k: float, b: float) -> float:
+ return k * x + b
+
+ xs = []
+ ys = []
+ for data in datas:
+ xs.append(data[0])
+ ys.append(data[1])
+ xs_arr = np.array(xs)
+ ys_arr = np.array(ys)
+ # horizontal
+ if len(np.unique(xs_arr)) == 1:
+ theta = np.pi / 2
+ # vertical
+ elif len(np.unique(ys_arr)) == 1:
+ theta = 0
+ # cross calc
+ else:
+ k1, b1 = optimize.curve_fit(_linear, xs_arr, ys_arr)[0]
+ k2, b2 = optimize.curve_fit(_linear, ys_arr, xs_arr)[0]
+ err1 = 0
+ err2 = 0
+ for x, y in zip(xs_arr, ys_arr):
+ err1 += abs(_linear(x, k1, b1) - y) / np.sqrt(k1**2 + 1)
+ err2 += abs(_linear(y, k2, b2) - x) / np.sqrt(k2**2 + 1)
+ if err1 <= err2:
+ theta = (np.arctan(k1) + 2 * np.pi) % (2 * np.pi)
+ else:
+ theta = (np.pi / 2.0 - np.arctan(k2) + 2 * np.pi) % (2 * np.pi)
+ # [0, 180)
+ theta = theta * 180 / np.pi + 90
+ while theta >= 180:
+ theta -= 180
+ theta -= 90
+ if theta < 0:
+ theta += 180
+ return theta * np.pi / 180
+
+
+def _cubic(x, y):
+ def _func(x, a, b, c, d):
+ return a * x**3 + b * x**2 + c * x + d
+
+ arr_x = np.array(x).reshape((4, ))
+ arr_y = np.array(y).reshape((4, ))
+ popt1 = np.polyfit(arr_x, arr_y, 3)
+ popt2 = np.polyfit(arr_y, arr_x, 3)
+ x_min = np.min(arr_x)
+ x_max = np.max(arr_x)
+ y_min = np.min(arr_y)
+ y_max = np.max(arr_y)
+ nx = np.arange(x_min, x_max + 1, 1)
+ y_estimate = [_func(i, popt1[0], popt1[1], popt1[2], popt1[3]) for i in nx]
+ ny = np.arange(y_min, y_max + 1, 1)
+ x_estimate = [_func(i, popt2[0], popt2[1], popt2[2], popt2[3]) for i in ny]
+ if np.max(y_estimate) - np.min(y_estimate) <= np.max(x_estimate) - np.min(
+ x_estimate):
+ return nx, y_estimate
+ else:
+ return x_estimate, ny
+
+
+def _line_cubic(points):
+ xs = []
+ ys = []
+ for p in points:
+ x, y = p
+ xs.append(x)
+ ys.append(y)
+ nx, ny = _cubic(xs, ys)
+ return nx, ny
+
+
+def _get_theta(dy, dx):
+ theta = np.arctan2(dy, dx) * 180 / np.pi
+ if theta < 0.0:
+ theta = 360.0 - abs(theta)
+ return float(theta)
+
+
+def _check_angle(bpk1, bpk2, ang_threshold=90):
+ af1 = _get_theta(bpk1[0], bpk1[1])
+ af2 = _get_theta(bpk2[0], bpk2[1])
+ ang_diff = abs(af1 - af2)
+ if ang_diff > 180:
+ ang_diff = 360 - ang_diff
+ if ang_diff > ang_threshold:
+ return True
+ else:
+ return False
diff --git a/paddlers/utils/postprocs/regularization.py b/paddlers/utils/postprocs/regularization.py
index dc6060b..be63b72 100644
--- a/paddlers/utils/postprocs/regularization.py
+++ b/paddlers/utils/postprocs/regularization.py
@@ -13,11 +13,11 @@
# limitations under the License.
import math
+
import cv2
import numpy as np
-from .utils import (calc_distance, calc_angle, calc_azimuth, rotation, line,
- intersection, calc_distance_between_lines,
- calc_project_in_line)
+
+from .utils import prepro_mask, calc_distance
S = 20
TD = 3
@@ -52,15 +52,7 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
np.ndarray: Mask of building after regularized.
"""
# check and pro processing
- mask_shape = mask.shape
- if len(mask_shape) != 2:
- mask = mask[..., 0]
- mask = cv2.medianBlur(mask, 5)
- class_num = len(np.unique(mask))
- if class_num != 2:
- _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
- cv2.THRESH_OTSU)
- mask = np.clip(mask, 0, 1).astype("uint8") # 0-255 / 0-1 -> 0-1
+ mask = prepro_mask(mask)
mask_shape = mask.shape
# find contours
contours, hierarchys = cv2.findContours(mask, cv2.RETR_TREE,
@@ -115,7 +107,7 @@ def _coarse(contour, img_shape):
continue
# remove over-sharp angles with threshold α.
# remove over-smooth angles with threshold β.
- angle = calc_angle(last_point, current_point, next_point)
+ angle = _calc_angle(last_point, current_point, next_point)
if (ALPHA > angle or angle > BETA) and _inline_check(current_point,
img_shape):
contour = np.delete(contour, idx, axis=0)
@@ -143,7 +135,7 @@ def _fine(contour, W):
next_idx = (idx + 1) % p_number
next_point = contour[next_idx]
distance_list.append(calc_distance(current_point, next_point))
- azimuth_list.append(calc_azimuth(current_point, next_point))
+ azimuth_list.append(_calc_azimuth(current_point, next_point))
indexs_list.append((idx, next_idx))
# add the direction of the longest edge to the list of main direction.
longest_distance_idx = np.argmax(distance_list)
@@ -177,11 +169,11 @@ def _fine(contour, W):
abs_rotate_ang = abs(rotate_ang)
# adjust long edges according to the list and angles.
if abs_rotate_ang < DELTA or abs_rotate_ang > (180 - DELTA):
- rp1 = rotation(p1, pm, rotate_ang)
- rp2 = rotation(p2, pm, rotate_ang)
+ rp1 = _rotation(p1, pm, rotate_ang)
+ rp2 = _rotation(p2, pm, rotate_ang)
elif (90 - DELTA) < abs_rotate_ang < (90 + DELTA):
- rp1 = rotation(p1, pm, rotate_ang - 90)
- rp2 = rotation(p2, pm, rotate_ang - 90)
+ rp1 = _rotation(p1, pm, rotate_ang - 90)
+ rp2 = _rotation(p2, pm, rotate_ang - 90)
else:
rp1, rp2 = p1, p2
# adjust short edges (judged by a threshold θ) according to the list and angles.
@@ -189,11 +181,11 @@ def _fine(contour, W):
rotate_ang = md_used_list[-1] - azimuth
abs_rotate_ang = abs(rotate_ang)
if abs_rotate_ang < THETA or abs_rotate_ang > (180 - THETA):
- rp1 = rotation(p1, pm, rotate_ang)
- rp2 = rotation(p2, pm, rotate_ang)
+ rp1 = _rotation(p1, pm, rotate_ang)
+ rp2 = _rotation(p2, pm, rotate_ang)
else:
- rp1 = rotation(p1, pm, rotate_ang - 90)
- rp2 = rotation(p2, pm, rotate_ang - 90)
+ rp1 = _rotation(p1, pm, rotate_ang - 90)
+ rp2 = _rotation(p2, pm, rotate_ang - 90)
# contour_by_lines.extend([rp1, rp2])
contour_by_lines.append([rp1[0], rp2[0]])
correct_points = np.array(contour_by_lines)
@@ -208,35 +200,35 @@ def _fine(contour, W):
cur_edge_p2 = correct_points[idx][1]
next_edge_p1 = correct_points[next_idx][0]
next_edge_p2 = correct_points[next_idx][1]
- L1 = line(cur_edge_p1, cur_edge_p2)
- L2 = line(next_edge_p1, next_edge_p2)
- A1 = calc_azimuth([cur_edge_p1], [cur_edge_p2])
- A2 = calc_azimuth([next_edge_p1], [next_edge_p2])
+ L1 = _line(cur_edge_p1, cur_edge_p2)
+ L2 = _line(next_edge_p1, next_edge_p2)
+ A1 = _calc_azimuth([cur_edge_p1], [cur_edge_p2])
+ A2 = _calc_azimuth([next_edge_p1], [next_edge_p2])
dif_azi = abs(A1 - A2)
# find intersection point if not parallel
if (90 - DELTA) < dif_azi < (90 + DELTA):
- point_intersection = intersection(L1, L2)
+ point_intersection = _intersection(L1, L2)
if point_intersection is not None:
final_points.append(point_intersection)
# move or add lines when parallel
elif dif_azi < 1e-6:
- marg = calc_distance_between_lines(L1, L2)
+ marg = _calc_distance_between_lines(L1, L2)
if marg < D:
# move
- point_move = calc_project_in_line(next_edge_p1, cur_edge_p1,
- cur_edge_p2)
+ point_move = _calc_project_in_line(next_edge_p1, cur_edge_p1,
+ cur_edge_p2)
final_points.append(point_move)
# update next
correct_points[next_idx][0] = point_move
- correct_points[next_idx][1] = calc_project_in_line(
+ correct_points[next_idx][1] = _calc_project_in_line(
next_edge_p2, cur_edge_p1, cur_edge_p2)
else:
# add line
add_mid_point = (cur_edge_p2 + next_edge_p1) / 2
- rp1 = calc_project_in_line(add_mid_point, cur_edge_p1,
- cur_edge_p2)
- rp2 = calc_project_in_line(add_mid_point, next_edge_p1,
- next_edge_p2)
+ rp1 = _calc_project_in_line(add_mid_point, cur_edge_p1,
+ cur_edge_p2)
+ rp2 = _calc_project_in_line(add_mid_point, next_edge_p1,
+ next_edge_p2)
final_points.extend([rp1, rp2])
else:
final_points.extend(
@@ -262,3 +254,96 @@ def _fill(img, coarse_conts):
else:
cv2.fillPoly(result, [contour.astype(np.int32)], (255, 255, 255))
return result
+
+
+def _calc_angle(p1, vertex, p2):
+ x1, y1 = p1[0]
+ xv, yv = vertex[0]
+ x2, y2 = p2[0]
+ a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5
+ b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5
+ c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5
+ return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c)))
+
+
+def _calc_azimuth(p1, p2):
+ x1, y1 = p1[0]
+ x2, y2 = p2[0]
+ if y1 == y2:
+ return 0.0
+ if x1 == x2:
+ return 90.0
+ elif x1 < x2:
+ if y1 < y2:
+ ang = math.atan((y2 - y1) / (x2 - x1))
+ return math.degrees(ang)
+ else:
+ ang = math.atan((y1 - y2) / (x2 - x1))
+ return 180 - math.degrees(ang)
+ else: # x1 > x2
+ if y1 < y2:
+ ang = math.atan((y2 - y1) / (x1 - x2))
+ return 180 - math.degrees(ang)
+ else:
+ ang = math.atan((y1 - y2) / (x1 - x2))
+ return math.degrees(ang)
+
+
+def _rotation(point, center, angle):
+ if angle == 0:
+ return point
+ x, y = point[0]
+ cx, cy = center[0]
+ radian = math.radians(abs(angle))
+ if angle > 0: # clockwise
+ rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
+ ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
+ else:
+ rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
+ ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
+ return np.array([[rx, ry]])
+
+
+def _line(p1, p2):
+ A = (p1[1] - p2[1])
+ B = (p2[0] - p1[0])
+ C = (p1[0] * p2[1] - p2[0] * p1[1])
+ return A, B, -C
+
+
+def _intersection(L1, L2):
+ D = L1[0] * L2[1] - L1[1] * L2[0]
+ Dx = L1[2] * L2[1] - L1[1] * L2[2]
+ Dy = L1[0] * L2[2] - L1[2] * L2[0]
+ if D != 0:
+ x = Dx / D
+ y = Dy / D
+ return np.array([[x, y]])
+ else:
+ return None
+
+
+def _calc_distance_between_lines(L1, L2):
+ eps = 1e-16
+ A1, _, C1 = L1
+ A2, B2, C2 = L2
+ new_C1 = C1 / (A1 + eps)
+ new_A2 = 1
+ new_B2 = B2 / (A2 + eps)
+ new_C2 = C2 / (A2 + eps)
+ dist = (np.abs(new_C1 - new_C2)) / (
+ np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
+ return dist
+
+
+def _calc_project_in_line(point, line_point1, line_point2):
+ eps = 1e-16
+ m, n = point
+ x1, y1 = line_point1
+ x2, y2 = line_point2
+ F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
+ x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
+ (x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
+ y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
+ (x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
+ return np.array([[x, y]])
diff --git a/paddlers/utils/postprocs/utils.py b/paddlers/utils/postprocs/utils.py
index 9393b0d..84cfa0f 100644
--- a/paddlers/utils/postprocs/utils.py
+++ b/paddlers/utils/postprocs/utils.py
@@ -13,101 +13,22 @@
# limitations under the License.
import numpy as np
-import math
+import cv2
-def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
- return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
-
-
-def calc_angle(p1: np.ndarray, vertex: np.ndarray, p2: np.ndarray) -> float:
- x1, y1 = p1[0]
- xv, yv = vertex[0]
- x2, y2 = p2[0]
- a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5
- b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5
- c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5
- return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c)))
-
-
-def calc_azimuth(p1: np.ndarray, p2: np.ndarray) -> float:
- x1, y1 = p1[0]
- x2, y2 = p2[0]
- if y1 == y2:
- return 0.0
- if x1 == x2:
- return 90.0
- elif x1 < x2:
- if y1 < y2:
- ang = math.atan((y2 - y1) / (x2 - x1))
- return math.degrees(ang)
- else:
- ang = math.atan((y1 - y2) / (x2 - x1))
- return 180 - math.degrees(ang)
- else: # x1 > x2
- if y1 < y2:
- ang = math.atan((y2 - y1) / (x1 - x2))
- return 180 - math.degrees(ang)
- else:
- ang = math.atan((y1 - y2) / (x1 - x2))
- return math.degrees(ang)
-
-
-def rotation(point: np.ndarray, center: np.ndarray, angle: float) -> np.ndarray:
- if angle == 0:
- return point
- x, y = point[0]
- cx, cy = center[0]
- radian = math.radians(abs(angle))
- if angle > 0: # clockwise
- rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
- ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
- else:
- rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
- ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
- return np.array([[rx, ry]])
+def prepro_mask(mask: np.ndarray):
+ mask_shape = mask.shape
+ if len(mask_shape) != 2:
+ mask = mask[..., 0]
+ mask = mask.astype("uint8")
+ mask = cv2.medianBlur(mask, 5)
+ class_num = len(np.unique(mask))
+ if class_num != 2:
+ _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
+ cv2.THRESH_OTSU)
+ mask = np.clip(mask, 0, 1).astype("uint8") # 0-255 / 0-1 -> 0-1
+ return mask
-def line(p1, p2):
- A = (p1[1] - p2[1])
- B = (p2[0] - p1[0])
- C = (p1[0] * p2[1] - p2[0] * p1[1])
- return A, B, -C
-
-
-def intersection(L1, L2):
- D = L1[0] * L2[1] - L1[1] * L2[0]
- Dx = L1[2] * L2[1] - L1[1] * L2[2]
- Dy = L1[0] * L2[2] - L1[2] * L2[0]
- if D != 0:
- x = Dx / D
- y = Dy / D
- return np.array([[x, y]])
- else:
- return None
-
-
-def calc_distance_between_lines(L1, L2):
- eps = 1e-16
- A1, _, C1 = L1
- A2, B2, C2 = L2
- new_C1 = C1 / (A1 + eps)
- new_A2 = 1
- new_B2 = B2 / (A2 + eps)
- new_C2 = C2 / (A2 + eps)
- dist = (np.abs(new_C1 - new_C2)) / (
- np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
- return dist
-
-
-def calc_project_in_line(point, line_point1, line_point2):
- eps = 1e-16
- m, n = point
- x1, y1 = line_point1
- x2, y2 = line_point2
- F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
- x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
- (x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
- y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
- (x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
- return np.array([[x, y]])
+def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
+ return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
|