[Feat] Add New Postproc (Cut Road Connection) (#52)

*
own
Yizhou Chen 2 years ago committed by GitHub
parent 37677385af
commit 7390b9e6ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      README.md
  2. 1
      paddlers/utils/postprocs/__init__.py
  3. 278
      paddlers/utils/postprocs/connection.py
  4. 155
      paddlers/utils/postprocs/regularization.py
  5. 109
      paddlers/utils/postprocs/utils.py

@ -119,6 +119,7 @@ PaddleRS具有以下五大特色:
<li>ReduceDim</li> <li>ReduceDim</li>
<li>SelectBand</li> <li>SelectBand</li>
<li>RandomSwap</li> <li>RandomSwap</li>
<li>AppendIndex</li>
<li>...</li> <li>...</li>
</ul> </ul>
</td> </td>
@ -138,6 +139,17 @@ PaddleRS具有以下五大特色:
<li>辐射校正</li> <li>辐射校正</li>
<li>...</li> <li>...</li>
</ul> </ul>
<b>数据后处理</b><br>
<ul>
<li>建筑边界规则化</li>
<li>道路断线连接</li>
<li>...</li>
</ul>
<b>数据可视化</b><br>
<ul>
<li>地图-栅格可视化</li>
<li>...</li>
</ul>
</td> </td>
<td> <td>
<b>遥感场景分类</b><br> <b>遥感场景分类</b><br>
@ -177,8 +189,10 @@ PaddleRS目录树中关键部分如下:
│ ├── datasets # 数据集接口实现 │ ├── datasets # 数据集接口实现
│ ├── models # 视觉模型实现 │ ├── models # 视觉模型实现
│ ├── tasks # 训练器实现 │ ├── tasks # 训练器实现
│ └── transforms # 数据预处理/数据增强实现 │ ├── transforms # 数据预处理/数据增强实现
│ └── utils # 数据下载/可视化/后处理等
├── tools # 遥感影像处理工具集 ├── tools # 遥感影像处理工具集
├── examples # 相关实践案例
└── tutorials └── tutorials
└── train # 模型训练教程 └── train # 模型训练教程
``` ```

@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from .regularization import building_regularization from .regularization import building_regularization
from .connection import cut_road_connection

@ -0,0 +1,278 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import warnings
import cv2
import numpy as np
from skimage import morphology
from scipy import ndimage, optimize
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
from sklearn import metrics
from sklearn.cluster import KMeans
from .utils import prepro_mask, calc_distance
def cut_road_connection(mask: np.ndarray, line_width: int=6) -> np.ndarray:
"""
Connecting cut road lines.
The original article refers to
Wang B, Chen Z, et al. "Road extraction of high-resolution satellite remote sensing images in U-Net network with consideration of connectivity."
(http://hgs.publish.founderss.cn/thesisDetails?columnId=4759509).
This algorithm has no public code.
The implementation procedure refers to original article,
and it is not fully consistent with the article:
1. The way to determine the optimal number of clusters k used in k-means clustering is not described in the original article. In this implementation, we use the k that reports the highest silhouette score.
2. We unmark the breakpoints if the angle between the two road extensions is less than 90°.
Args:
mask (np.ndarray): Mask of road.
line_width (int, optional): Width of the line used for patching.
. Default is 6.
Returns:
np.ndarray: Mask of road after connecting cut road lines.
"""
mask = prepro_mask(mask)
skeleton = morphology.skeletonize(mask).astype("uint8")
break_points = _find_breakpoint(skeleton)
labels = _k_means(break_points)
match_points = _get_match_points(break_points, labels)
res = _draw_curve(mask, skeleton, match_points, line_width)
return res
def _find_breakpoint(skeleton):
kernel_3x3 = np.ones((3, 3), dtype="uint8")
k3 = ndimage.convolve(skeleton, kernel_3x3)
point_map = np.zeros_like(k3)
point_map[k3 == 2] = 1
point_map *= skeleton * 255
# boundary filtering
filter_w = 5
cropped = point_map[filter_w:-filter_w, filter_w:-filter_w]
padded = np.pad(cropped, (filter_w, filter_w), mode="constant")
breakpoints = np.column_stack(np.where(padded == 255))
return breakpoints
def _k_means(data):
silhouette_int = -1 # threshold
labels = None
for k in range(2, data.shape[0]):
kms = KMeans(k, random_state=66)
labels_tmp = kms.fit_predict(data) # train
silhouette = metrics.silhouette_score(data, labels_tmp)
if silhouette > silhouette_int: # better
silhouette_int = silhouette
labels = labels_tmp
return labels
def _get_match_points(break_points, labels):
match_points = {}
for point, lab in zip(break_points, labels):
if lab in match_points.keys():
match_points[lab].append(point)
else:
match_points[lab] = [point]
return match_points
def _draw_curve(mask, skeleton, match_points, line_width):
result = mask * 255
for v in match_points.values():
p_num = len(v)
if p_num == 2:
points_list = _curve_backtracking(v, skeleton)
if points_list is not None:
result = _broken_wire_repair(result, points_list, line_width)
elif p_num == 3:
sim_v = list(itertools.combinations(v, 2))
min_di = 1e6
for vij in sim_v:
di = calc_distance(vij[0][np.newaxis], vij[1][np.newaxis])
if di < min_di:
vv = vij
min_di = di
points_list = _curve_backtracking(vv, skeleton)
if points_list is not None:
result = _broken_wire_repair(result, points_list, line_width)
return result
def _curve_backtracking(add_lines, skeleton):
points_list = []
p1 = add_lines[0]
p2 = add_lines[1]
bpk1, ps1 = _calc_angle_by_road(p1, skeleton)
bpk2, ps2 = _calc_angle_by_road(p2, skeleton)
if _check_angle(bpk1, bpk2):
points_list.append((
np.array(
ps1, dtype="int64"),
add_lines[0],
add_lines[1],
np.array(
ps2, dtype="int64"), ))
return points_list
else:
return None
def _broken_wire_repair(mask, points_list, line_width):
d_mask = mask.copy()
for points in points_list:
nx, ny = _line_cubic(points)
for i in range(len(nx) - 1):
loc_p1 = (int(ny[i]), int(nx[i]))
loc_p2 = (int(ny[i + 1]), int(nx[i + 1]))
cv2.line(d_mask, loc_p1, loc_p2, [255], line_width)
return d_mask
def _calc_angle_by_road(p, skeleton, num_circle=10):
def _not_in(p1, ps):
for p in ps:
if p1[0] == p[0] and p1[1] == p[1]:
return False
return True
h, w = skeleton.shape
tmp_p = p.tolist() if isinstance(p, np.ndarray) else p
tmp_p = [int(tmp_p[0]), int(tmp_p[1])]
ps = []
ps.append(tmp_p)
for _ in range(num_circle):
t_x = 0 if tmp_p[0] - 1 < 0 else tmp_p[0] - 1
t_y = 0 if tmp_p[1] - 1 < 0 else tmp_p[1] - 1
b_x = w if tmp_p[0] + 1 >= w else tmp_p[0] + 1
b_y = h if tmp_p[1] + 1 >= h else tmp_p[1] + 1
if int(np.sum(skeleton[t_x:b_x + 1, t_y:b_y + 1])) <= 3:
for i in range(t_x, b_x + 1):
for j in range(t_y, b_y + 1):
if skeleton[i, j] == 1:
pp = [int(i), int(j)]
if _not_in(pp, ps):
tmp_p = pp
ps.append(tmp_p)
# calc angle
theta = _angle_regression(ps)
dx, dy = np.cos(theta), np.sin(theta)
# calc direction
start = ps[-1]
end = ps[0]
if end[1] < start[1] or (end[1] == start[1] and end[0] < start[0]):
dx *= -1
dy *= -1
return [dx, dy], start
def _angle_regression(datas):
def _linear(x: float, k: float, b: float) -> float:
return k * x + b
xs = []
ys = []
for data in datas:
xs.append(data[0])
ys.append(data[1])
xs_arr = np.array(xs)
ys_arr = np.array(ys)
# horizontal
if len(np.unique(xs_arr)) == 1:
theta = np.pi / 2
# vertical
elif len(np.unique(ys_arr)) == 1:
theta = 0
# cross calc
else:
k1, b1 = optimize.curve_fit(_linear, xs_arr, ys_arr)[0]
k2, b2 = optimize.curve_fit(_linear, ys_arr, xs_arr)[0]
err1 = 0
err2 = 0
for x, y in zip(xs_arr, ys_arr):
err1 += abs(_linear(x, k1, b1) - y) / np.sqrt(k1**2 + 1)
err2 += abs(_linear(y, k2, b2) - x) / np.sqrt(k2**2 + 1)
if err1 <= err2:
theta = (np.arctan(k1) + 2 * np.pi) % (2 * np.pi)
else:
theta = (np.pi / 2.0 - np.arctan(k2) + 2 * np.pi) % (2 * np.pi)
# [0, 180)
theta = theta * 180 / np.pi + 90
while theta >= 180:
theta -= 180
theta -= 90
if theta < 0:
theta += 180
return theta * np.pi / 180
def _cubic(x, y):
def _func(x, a, b, c, d):
return a * x**3 + b * x**2 + c * x + d
arr_x = np.array(x).reshape((4, ))
arr_y = np.array(y).reshape((4, ))
popt1 = np.polyfit(arr_x, arr_y, 3)
popt2 = np.polyfit(arr_y, arr_x, 3)
x_min = np.min(arr_x)
x_max = np.max(arr_x)
y_min = np.min(arr_y)
y_max = np.max(arr_y)
nx = np.arange(x_min, x_max + 1, 1)
y_estimate = [_func(i, popt1[0], popt1[1], popt1[2], popt1[3]) for i in nx]
ny = np.arange(y_min, y_max + 1, 1)
x_estimate = [_func(i, popt2[0], popt2[1], popt2[2], popt2[3]) for i in ny]
if np.max(y_estimate) - np.min(y_estimate) <= np.max(x_estimate) - np.min(
x_estimate):
return nx, y_estimate
else:
return x_estimate, ny
def _line_cubic(points):
xs = []
ys = []
for p in points:
x, y = p
xs.append(x)
ys.append(y)
nx, ny = _cubic(xs, ys)
return nx, ny
def _get_theta(dy, dx):
theta = np.arctan2(dy, dx) * 180 / np.pi
if theta < 0.0:
theta = 360.0 - abs(theta)
return float(theta)
def _check_angle(bpk1, bpk2, ang_threshold=90):
af1 = _get_theta(bpk1[0], bpk1[1])
af2 = _get_theta(bpk2[0], bpk2[1])
ang_diff = abs(af1 - af2)
if ang_diff > 180:
ang_diff = 360 - ang_diff
if ang_diff > ang_threshold:
return True
else:
return False

@ -13,11 +13,11 @@
# limitations under the License. # limitations under the License.
import math import math
import cv2 import cv2
import numpy as np import numpy as np
from .utils import (calc_distance, calc_angle, calc_azimuth, rotation, line,
intersection, calc_distance_between_lines, from .utils import prepro_mask, calc_distance
calc_project_in_line)
S = 20 S = 20
TD = 3 TD = 3
@ -52,15 +52,7 @@ def building_regularization(mask: np.ndarray, W: int=32) -> np.ndarray:
np.ndarray: Mask of building after regularized. np.ndarray: Mask of building after regularized.
""" """
# check and pro processing # check and pro processing
mask_shape = mask.shape mask = prepro_mask(mask)
if len(mask_shape) != 2:
mask = mask[..., 0]
mask = cv2.medianBlur(mask, 5)
class_num = len(np.unique(mask))
if class_num != 2:
_, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
cv2.THRESH_OTSU)
mask = np.clip(mask, 0, 1).astype("uint8") # 0-255 / 0-1 -> 0-1
mask_shape = mask.shape mask_shape = mask.shape
# find contours # find contours
contours, hierarchys = cv2.findContours(mask, cv2.RETR_TREE, contours, hierarchys = cv2.findContours(mask, cv2.RETR_TREE,
@ -115,7 +107,7 @@ def _coarse(contour, img_shape):
continue continue
# remove over-sharp angles with threshold α. # remove over-sharp angles with threshold α.
# remove over-smooth angles with threshold β. # remove over-smooth angles with threshold β.
angle = calc_angle(last_point, current_point, next_point) angle = _calc_angle(last_point, current_point, next_point)
if (ALPHA > angle or angle > BETA) and _inline_check(current_point, if (ALPHA > angle or angle > BETA) and _inline_check(current_point,
img_shape): img_shape):
contour = np.delete(contour, idx, axis=0) contour = np.delete(contour, idx, axis=0)
@ -143,7 +135,7 @@ def _fine(contour, W):
next_idx = (idx + 1) % p_number next_idx = (idx + 1) % p_number
next_point = contour[next_idx] next_point = contour[next_idx]
distance_list.append(calc_distance(current_point, next_point)) distance_list.append(calc_distance(current_point, next_point))
azimuth_list.append(calc_azimuth(current_point, next_point)) azimuth_list.append(_calc_azimuth(current_point, next_point))
indexs_list.append((idx, next_idx)) indexs_list.append((idx, next_idx))
# add the direction of the longest edge to the list of main direction. # add the direction of the longest edge to the list of main direction.
longest_distance_idx = np.argmax(distance_list) longest_distance_idx = np.argmax(distance_list)
@ -177,11 +169,11 @@ def _fine(contour, W):
abs_rotate_ang = abs(rotate_ang) abs_rotate_ang = abs(rotate_ang)
# adjust long edges according to the list and angles. # adjust long edges according to the list and angles.
if abs_rotate_ang < DELTA or abs_rotate_ang > (180 - DELTA): if abs_rotate_ang < DELTA or abs_rotate_ang > (180 - DELTA):
rp1 = rotation(p1, pm, rotate_ang) rp1 = _rotation(p1, pm, rotate_ang)
rp2 = rotation(p2, pm, rotate_ang) rp2 = _rotation(p2, pm, rotate_ang)
elif (90 - DELTA) < abs_rotate_ang < (90 + DELTA): elif (90 - DELTA) < abs_rotate_ang < (90 + DELTA):
rp1 = rotation(p1, pm, rotate_ang - 90) rp1 = _rotation(p1, pm, rotate_ang - 90)
rp2 = rotation(p2, pm, rotate_ang - 90) rp2 = _rotation(p2, pm, rotate_ang - 90)
else: else:
rp1, rp2 = p1, p2 rp1, rp2 = p1, p2
# adjust short edges (judged by a threshold θ) according to the list and angles. # adjust short edges (judged by a threshold θ) according to the list and angles.
@ -189,11 +181,11 @@ def _fine(contour, W):
rotate_ang = md_used_list[-1] - azimuth rotate_ang = md_used_list[-1] - azimuth
abs_rotate_ang = abs(rotate_ang) abs_rotate_ang = abs(rotate_ang)
if abs_rotate_ang < THETA or abs_rotate_ang > (180 - THETA): if abs_rotate_ang < THETA or abs_rotate_ang > (180 - THETA):
rp1 = rotation(p1, pm, rotate_ang) rp1 = _rotation(p1, pm, rotate_ang)
rp2 = rotation(p2, pm, rotate_ang) rp2 = _rotation(p2, pm, rotate_ang)
else: else:
rp1 = rotation(p1, pm, rotate_ang - 90) rp1 = _rotation(p1, pm, rotate_ang - 90)
rp2 = rotation(p2, pm, rotate_ang - 90) rp2 = _rotation(p2, pm, rotate_ang - 90)
# contour_by_lines.extend([rp1, rp2]) # contour_by_lines.extend([rp1, rp2])
contour_by_lines.append([rp1[0], rp2[0]]) contour_by_lines.append([rp1[0], rp2[0]])
correct_points = np.array(contour_by_lines) correct_points = np.array(contour_by_lines)
@ -208,35 +200,35 @@ def _fine(contour, W):
cur_edge_p2 = correct_points[idx][1] cur_edge_p2 = correct_points[idx][1]
next_edge_p1 = correct_points[next_idx][0] next_edge_p1 = correct_points[next_idx][0]
next_edge_p2 = correct_points[next_idx][1] next_edge_p2 = correct_points[next_idx][1]
L1 = line(cur_edge_p1, cur_edge_p2) L1 = _line(cur_edge_p1, cur_edge_p2)
L2 = line(next_edge_p1, next_edge_p2) L2 = _line(next_edge_p1, next_edge_p2)
A1 = calc_azimuth([cur_edge_p1], [cur_edge_p2]) A1 = _calc_azimuth([cur_edge_p1], [cur_edge_p2])
A2 = calc_azimuth([next_edge_p1], [next_edge_p2]) A2 = _calc_azimuth([next_edge_p1], [next_edge_p2])
dif_azi = abs(A1 - A2) dif_azi = abs(A1 - A2)
# find intersection point if not parallel # find intersection point if not parallel
if (90 - DELTA) < dif_azi < (90 + DELTA): if (90 - DELTA) < dif_azi < (90 + DELTA):
point_intersection = intersection(L1, L2) point_intersection = _intersection(L1, L2)
if point_intersection is not None: if point_intersection is not None:
final_points.append(point_intersection) final_points.append(point_intersection)
# move or add lines when parallel # move or add lines when parallel
elif dif_azi < 1e-6: elif dif_azi < 1e-6:
marg = calc_distance_between_lines(L1, L2) marg = _calc_distance_between_lines(L1, L2)
if marg < D: if marg < D:
# move # move
point_move = calc_project_in_line(next_edge_p1, cur_edge_p1, point_move = _calc_project_in_line(next_edge_p1, cur_edge_p1,
cur_edge_p2) cur_edge_p2)
final_points.append(point_move) final_points.append(point_move)
# update next # update next
correct_points[next_idx][0] = point_move correct_points[next_idx][0] = point_move
correct_points[next_idx][1] = calc_project_in_line( correct_points[next_idx][1] = _calc_project_in_line(
next_edge_p2, cur_edge_p1, cur_edge_p2) next_edge_p2, cur_edge_p1, cur_edge_p2)
else: else:
# add line # add line
add_mid_point = (cur_edge_p2 + next_edge_p1) / 2 add_mid_point = (cur_edge_p2 + next_edge_p1) / 2
rp1 = calc_project_in_line(add_mid_point, cur_edge_p1, rp1 = _calc_project_in_line(add_mid_point, cur_edge_p1,
cur_edge_p2) cur_edge_p2)
rp2 = calc_project_in_line(add_mid_point, next_edge_p1, rp2 = _calc_project_in_line(add_mid_point, next_edge_p1,
next_edge_p2) next_edge_p2)
final_points.extend([rp1, rp2]) final_points.extend([rp1, rp2])
else: else:
final_points.extend( final_points.extend(
@ -262,3 +254,96 @@ def _fill(img, coarse_conts):
else: else:
cv2.fillPoly(result, [contour.astype(np.int32)], (255, 255, 255)) cv2.fillPoly(result, [contour.astype(np.int32)], (255, 255, 255))
return result return result
def _calc_angle(p1, vertex, p2):
x1, y1 = p1[0]
xv, yv = vertex[0]
x2, y2 = p2[0]
a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5
b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5
c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5
return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c)))
def _calc_azimuth(p1, p2):
x1, y1 = p1[0]
x2, y2 = p2[0]
if y1 == y2:
return 0.0
if x1 == x2:
return 90.0
elif x1 < x2:
if y1 < y2:
ang = math.atan((y2 - y1) / (x2 - x1))
return math.degrees(ang)
else:
ang = math.atan((y1 - y2) / (x2 - x1))
return 180 - math.degrees(ang)
else: # x1 > x2
if y1 < y2:
ang = math.atan((y2 - y1) / (x1 - x2))
return 180 - math.degrees(ang)
else:
ang = math.atan((y1 - y2) / (x1 - x2))
return math.degrees(ang)
def _rotation(point, center, angle):
if angle == 0:
return point
x, y = point[0]
cx, cy = center[0]
radian = math.radians(abs(angle))
if angle > 0: # clockwise
rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
else:
rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
return np.array([[rx, ry]])
def _line(p1, p2):
A = (p1[1] - p2[1])
B = (p2[0] - p1[0])
C = (p1[0] * p2[1] - p2[0] * p1[1])
return A, B, -C
def _intersection(L1, L2):
D = L1[0] * L2[1] - L1[1] * L2[0]
Dx = L1[2] * L2[1] - L1[1] * L2[2]
Dy = L1[0] * L2[2] - L1[2] * L2[0]
if D != 0:
x = Dx / D
y = Dy / D
return np.array([[x, y]])
else:
return None
def _calc_distance_between_lines(L1, L2):
eps = 1e-16
A1, _, C1 = L1
A2, B2, C2 = L2
new_C1 = C1 / (A1 + eps)
new_A2 = 1
new_B2 = B2 / (A2 + eps)
new_C2 = C2 / (A2 + eps)
dist = (np.abs(new_C1 - new_C2)) / (
np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
return dist
def _calc_project_in_line(point, line_point1, line_point2):
eps = 1e-16
m, n = point
x1, y1 = line_point1
x2, y2 = line_point2
F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
(x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
(x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
return np.array([[x, y]])

@ -13,101 +13,22 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import math import cv2
def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float: def prepro_mask(mask: np.ndarray):
return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2)))) mask_shape = mask.shape
if len(mask_shape) != 2:
mask = mask[..., 0]
def calc_angle(p1: np.ndarray, vertex: np.ndarray, p2: np.ndarray) -> float: mask = mask.astype("uint8")
x1, y1 = p1[0] mask = cv2.medianBlur(mask, 5)
xv, yv = vertex[0] class_num = len(np.unique(mask))
x2, y2 = p2[0] if class_num != 2:
a = ((xv - x2) * (xv - x2) + (yv - y2) * (yv - y2))**0.5 _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY |
b = ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2))**0.5 cv2.THRESH_OTSU)
c = ((x1 - xv) * (x1 - xv) + (y1 - yv) * (y1 - yv))**0.5 mask = np.clip(mask, 0, 1).astype("uint8") # 0-255 / 0-1 -> 0-1
return math.degrees(math.acos((b**2 - a**2 - c**2) / (-2 * a * c))) return mask
def calc_azimuth(p1: np.ndarray, p2: np.ndarray) -> float:
x1, y1 = p1[0]
x2, y2 = p2[0]
if y1 == y2:
return 0.0
if x1 == x2:
return 90.0
elif x1 < x2:
if y1 < y2:
ang = math.atan((y2 - y1) / (x2 - x1))
return math.degrees(ang)
else:
ang = math.atan((y1 - y2) / (x2 - x1))
return 180 - math.degrees(ang)
else: # x1 > x2
if y1 < y2:
ang = math.atan((y2 - y1) / (x1 - x2))
return 180 - math.degrees(ang)
else:
ang = math.atan((y1 - y2) / (x1 - x2))
return math.degrees(ang)
def rotation(point: np.ndarray, center: np.ndarray, angle: float) -> np.ndarray:
if angle == 0:
return point
x, y = point[0]
cx, cy = center[0]
radian = math.radians(abs(angle))
if angle > 0: # clockwise
rx = (x - cx) * math.cos(radian) - (y - cy) * math.sin(radian) + cx
ry = (x - cx) * math.sin(radian) + (y - cy) * math.cos(radian) + cy
else:
rx = (x - cx) * math.cos(radian) + (y - cy) * math.sin(radian) + cx
ry = (y - cy) * math.cos(radian) - (x - cx) * math.sin(radian) + cy
return np.array([[rx, ry]])
def line(p1, p2): def calc_distance(p1: np.ndarray, p2: np.ndarray) -> float:
A = (p1[1] - p2[1]) return float(np.sqrt(np.sum(np.power((p1[0] - p2[0]), 2))))
B = (p2[0] - p1[0])
C = (p1[0] * p2[1] - p2[0] * p1[1])
return A, B, -C
def intersection(L1, L2):
D = L1[0] * L2[1] - L1[1] * L2[0]
Dx = L1[2] * L2[1] - L1[1] * L2[2]
Dy = L1[0] * L2[2] - L1[2] * L2[0]
if D != 0:
x = Dx / D
y = Dy / D
return np.array([[x, y]])
else:
return None
def calc_distance_between_lines(L1, L2):
eps = 1e-16
A1, _, C1 = L1
A2, B2, C2 = L2
new_C1 = C1 / (A1 + eps)
new_A2 = 1
new_B2 = B2 / (A2 + eps)
new_C2 = C2 / (A2 + eps)
dist = (np.abs(new_C1 - new_C2)) / (
np.sqrt(new_A2 * new_A2 + new_B2 * new_B2) + eps)
return dist
def calc_project_in_line(point, line_point1, line_point2):
eps = 1e-16
m, n = point
x1, y1 = line_point1
x2, y2 = line_point2
F = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
x = (m * (x2 - x1) * (x2 - x1) + n * (y2 - y1) * (x2 - x1) +
(x1 * y2 - x2 * y1) * (y2 - y1)) / (F + eps)
y = (m * (x2 - x1) * (y2 - y1) + n * (y2 - y1) * (y2 - y1) +
(x2 * y1 - x1 * y2) * (x2 - x1)) / (F + eps)
return np.array([[x, y]])

Loading…
Cancel
Save