|
|
|
@ -34,19 +34,21 @@ def linear_assignment(cost_matrix, thresh, use_lap=True): |
|
|
|
|
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) |
|
|
|
|
|
|
|
|
|
if use_lap: |
|
|
|
|
# https://github.com/gatagat/lap |
|
|
|
|
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) |
|
|
|
|
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0] |
|
|
|
|
unmatched_a = np.where(x < 0)[0] |
|
|
|
|
unmatched_b = np.where(y < 0)[0] |
|
|
|
|
else: |
|
|
|
|
# Scipy linear sum assignment is NOT working correctly, DO NOT USE |
|
|
|
|
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x |
|
|
|
|
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh]) |
|
|
|
|
unmatched = np.ones(cost_matrix.shape) |
|
|
|
|
for i, xi in matches: |
|
|
|
|
unmatched[i, xi] = 0.0 |
|
|
|
|
unmatched_a = np.where(unmatched.all(1))[0] |
|
|
|
|
unmatched_b = np.where(unmatched.all(0))[0] |
|
|
|
|
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html |
|
|
|
|
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y |
|
|
|
|
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh]) |
|
|
|
|
if len(matches) == 0: |
|
|
|
|
unmatched_a = list(np.arange(cost_matrix.shape[0])) |
|
|
|
|
unmatched_b = list(np.arange(cost_matrix.shape[1])) |
|
|
|
|
else: |
|
|
|
|
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0])) |
|
|
|
|
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1])) |
|
|
|
|
|
|
|
|
|
return matches, unmatched_a, unmatched_b |
|
|
|
|
|
|
|
|
|