You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
3.3 KiB
98 lines
3.3 KiB
import argparse |
|
import sys |
|
import os |
|
import time |
|
import numpy as np |
|
from sklearn.cluster import KMeans |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def k_means(K, data, max_iter, n_jobs, image_file, step_size=0.2):
    """Cluster (width, height) samples with K-Means and plot the result.

    The fitted model's cluster assignment is evaluated on a regular grid that
    covers the data and drawn as coloured regions.  The samples are overlaid
    as black dots and the cluster centres as red crosses.  The centroid
    coordinates are printed to stdout.  The figure is saved to
    ``image_file`` and then shown.

    Args:
        K: Number of clusters.
        data: Sequence of ``[w, h]`` pairs; it must convert with ``np.array``
            into an array of shape ``(n_samples, 2)``.
        max_iter: Maximum number of K-Means iterations.
        n_jobs: Number of parallel jobs.  It is passed to scikit-learn only
            when the installed ``KMeans`` still accepts it.
        image_file: Path the figure is saved to.  A falsy value skips saving.
        step_size: Spacing of the grid used to draw the cluster regions.
            The grid holds about ``(x_range / step_size) * (y_range /
            step_size)`` points, so small steps on large boxes cost a lot of
            memory.

    Returns:
        The fitted ``KMeans`` model.

    Raises:
        ValueError: If ``data`` is empty or is not a list of 2-value pairs.
    """
    # np.array() copies `data`, so the shuffle below never touches the
    # caller's list.
    X = np.array(data)
    if X.ndim != 2 or X.shape[0] == 0 or X.shape[1] != 2:
        raise ValueError(
            'k_means expects a non-empty list of [w, h] pairs, got array of '
            'shape {}'.format(X.shape))
    np.random.shuffle(X)

    begin = time.time()
    print('Running kmeans')
    try:
        model = KMeans(n_clusters=K, max_iter=max_iter, n_jobs=n_jobs, verbose=1)
    except TypeError:
        # scikit-learn >= 1.0 removed the ``n_jobs`` argument from KMeans.
        model = KMeans(n_clusters=K, max_iter=max_iter, verbose=1)
    kmeans = model.fit(X)
    print('K-Means took {} seconds to complete'.format(time.time() - begin))

    # Evaluate the model on a grid padded by 1 unit around the data so the
    # cluster regions can be painted behind the samples.
    xmin, xmax = X[:, 0].min() - 1, X[:, 0].max() + 1
    ymin, ymax = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(xmin, xmax, step_size),
                         np.arange(ymin, ymax, step_size))
    preds = kmeans.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    plt.figure()
    plt.clf()
    plt.imshow(preds, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               cmap=plt.cm.Paired, aspect='auto', origin='lower')
    plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)

    centroids = kmeans.cluster_centers_
    plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=169,
                linewidths=5, color='r', zorder=10)
    plt.title("Anchor shapes generated using K-Means")
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)

    print('Mean centroids are:')
    for i, center in enumerate(centroids):
        print('{}: {}, {}'.format(i, center[0], center[1]))

    # Save before show(): with non-interactive use, saving after show()
    # can write an empty image because the figure has already been closed.
    if image_file:
        plt.savefig(image_file)
    plt.show()
    return kmeans
|
|
|
def pre_process(directory, data_list):
    """Append the (width, height) of every annotated box in ``directory``.

    Each regular file in ``directory`` is opened, and only its FIRST line is
    parsed, as whitespace-separated numbers.  The numbers are consumed in
    groups of five.  The first four of each group are read as
    ``xmin ymin xmax ymax``, and ``[xmax - xmin, ymax - ymin]`` is appended
    to ``data_list``.  The fifth value of each group is not used here
    (NOTE(review): presumably a class label; confirm against the dataset).

    A file is reported on stderr and skipped when its first line contains
    non-numeric tokens, or when its value count is not a multiple of five.
    The remaining files are still processed.

    Args:
        directory: Path of the directory holding the annotation files.
        data_list: List the ``[w, h]`` pairs are appended to (mutated in
            place).

    Returns:
        ``data_list``.  If ``directory`` is not an existing directory, a
        message is printed and ``data_list`` is returned unchanged.
    """
    if not os.path.isdir(directory):
        print("Path {} doesn't exist or is not a directory".format(directory))
        return data_list

    # Sorted for a reproducible load order.  Sub-directories are skipped
    # because open() would fail on them.
    files = sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )
    print('Loading data...')

    total = len(files)
    for index, name in enumerate(files):
        # Progress bar, redrawn in place on the same line via '\r'.
        fraction = (index + 1.0) / total
        filled = int(fraction * 30)
        sys.stdout.write('\r[{}>{}] {:.0f}%'.format(
            '=' * filled, ' ' * (29 - filled), fraction * 100))
        sys.stdout.flush()

        with open(os.path.join(directory, name), 'r') as ann:
            first_line = ann.readline()

        # split() with no argument tolerates repeated spaces/tabs and yields
        # [] for an empty file.  split(' ') produced '' tokens there, and
        # float('') raised.
        try:
            values = [float(token) for token in first_line.split()]
        except ValueError:
            sys.stderr.write(
                '\nFile {} contains non-numeric annotations, skipping\n'
                .format(name))
            continue

        if len(values) % 5 != 0:
            sys.stderr.write(
                '\nFile {} contains incorrect number of annotations, '
                'skipping\n'.format(name))
            continue

        # Step through the groups directly; this avoids the division that
        # returned a float (and broke range()) under Python 3.
        for start in range(0, len(values), 5):
            xmin, ymin, xmax, ymax = values[start:start + 4]
            w = xmax - xmin
            h = ymax - ymin
            data_list.append([w, h])
            # Echo boxes larger than 1000 in either dimension.
            if w > 1000 or h > 1000:
                sys.stdout.write("[{}, {}]".format(w, h))

    sys.stdout.write('\nProcessed {} files containing {} objects'.format(
        total, len(data_list)))
    return data_list
|
|
|
def main():
    """Command-line entry point: load annotation box sizes and cluster them.

    Positional arguments are the number of clusters, the annotation
    directory and the path of the output image.  ``-jobs`` and ``-iter`` are
    optional integers forwarded to ``k_means``.  Exits with an error
    message when no boxes were loaded.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in usage lines), so free text belongs in ``description``.
    parser = argparse.ArgumentParser(
        description="Cluster annotation box sizes with K-Means to pick "
                    "anchor shapes")
    parser.add_argument("clusters", help="Number of clusters", type=int)
    parser.add_argument("dir", help="Directory containing annotations")
    parser.add_argument("image_file",
                        help="File to generate the final cluster of image")
    parser.add_argument('-jobs', help="Number of jobs for parallel computation",
                        type=int, default=1)
    parser.add_argument('-iter', help="Max Iterations to run algorithm for",
                        type=int, default=1000)

    p = parser.parse_args(sys.argv[1:])

    data_list = []
    pre_process(p.dir, data_list)
    if not data_list:
        # Stop here instead of letting k_means fail on an empty dataset.
        sys.exit('No annotations were loaded from {}'.format(p.dir))

    sys.stdout.write('\nDone collecting data\n')
    k_means(p.clusters, data_list, p.iter, p.jobs, p.image_file)
    print('Done !')
|
|
|
if __name__ == '__main__':
    try:
        main()
    except Exception as err:
        # Report on stderr and exit non-zero so shells/CI see the failure.
        # Printing to stdout and falling off the end exited with status 0.
        sys.stderr.write('Error: {}\n'.format(err))
        sys.exit(1)
|
|
|