Repository for OpenCV's extra modules
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

98 lines
3.3 KiB

import argparse
import sys
import os
import time
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
def k_means(K, data, max_iter, n_jobs, image_file, step_size=0.2):
    """Cluster 2-D points with K-Means, plot the result and print the centroids.

    The rows of ``data`` are shuffled and clustered with scikit-learn's
    ``KMeans``. A figure is then drawn showing the cluster regions (the fitted
    model evaluated on a regular grid), the input points and the centroids.
    The figure is saved to ``image_file`` and then shown.

    Args:
        K: Number of clusters to fit.
        data: Sequence of ``[w, h]`` pairs; must convert to an ``(N, 2)`` array.
        max_iter: Maximum number of K-Means iterations.
        n_jobs: Parallel jobs passed to ``KMeans``. It is dropped on
            scikit-learn releases whose ``KMeans`` no longer accepts it.
        image_file: Path the final plot is written to with ``plt.savefig``.
            Nothing is saved when it is empty or ``None``.
        step_size: Spacing of the grid used to paint the cluster regions.
            The grid has roughly ``(x_range / step) * (y_range / step)``
            points, and every one is passed to ``predict``.

    Raises:
        ValueError: If ``data`` is empty or is not made of 2-value rows.
    """
    X = np.array(data, dtype=float)
    # Fail early with a clear message. Otherwise the X[:, 0] indexing or the
    # 2-column grid prediction below would fail obscurely.
    if X.ndim != 2 or X.shape[0] == 0 or X.shape[1] != 2:
        raise ValueError('k_means expects a non-empty list of [w, h] pairs')
    # Rows are shuffled in place before fitting. The shuffle is unseeded, so
    # runs are not reproducible.
    np.random.shuffle(X)

    begin = time.time()
    print('Running kmeans')
    try:
        model = KMeans(n_clusters=K, max_iter=max_iter, n_jobs=n_jobs, verbose=1)
    except TypeError:
        # ``n_jobs`` was deprecated in scikit-learn 0.23 and later removed, so
        # newer releases reject the keyword.
        model = KMeans(n_clusters=K, max_iter=max_iter, verbose=1)
    kmeans = model.fit(X)
    print('K-Means took {} seconds to complete'.format(time.time() - begin))

    # Predict a cluster for every point of a grid padded by 1 on each side, to
    # paint the region each centroid owns.
    xmin, xmax = X[:, 0].min() - 1, X[:, 0].max() + 1
    ymin, ymax = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(xmin, xmax, step_size), np.arange(ymin, ymax, step_size))
    preds = kmeans.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    plt.figure()
    plt.imshow(preds, interpolation='nearest', extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               cmap=plt.cm.Paired, aspect='auto', origin='lower')
    plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)
    centroids = kmeans.cluster_centers_
    plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=169, linewidths=5, color='r', zorder=10)
    plt.title("Anchor shapes generated using K-Means")
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)

    print('Mean centroids are:')
    for i, center in enumerate(centroids):
        print('{}: {}, {}'.format(i, center[0], center[1]))

    # Save before show(): once the window is closed the figure may be
    # destroyed, and a later savefig would write an empty image.
    if image_file:
        plt.savefig(image_file)
    plt.show()
def _print_progress(done, total, width=30):
    """Redraw a one-line ``[===>   ] NN%`` progress bar on stdout."""
    percentage = float(done) / total
    filled = int(percentage * width)
    sys.stdout.write('\r')
    sys.stdout.write('[{}>{}] {:.0f}%'.format('=' * filled, ' ' * (width - 1 - filled), percentage * 100))
    sys.stdout.flush()


def pre_process(directory, data_list):
    """Collect ``[width, height]`` pairs for every box in an annotation directory.

    Only the first line of each regular file in ``directory`` is read. That
    line must hold whitespace-separated numbers in groups of five. The first
    four numbers of each group are used as ``xmin ymin xmax ymax``, and the
    fifth is ignored here. For each group, ``[xmax - xmin, ymax - ymin]`` is
    appended to ``data_list``.

    A progress bar is written to stdout while loading. Boxes wider or taller
    than 1000 are echoed to stdout.

    Files are skipped, with a message on stderr, when their number count is
    not a multiple of five or when a value is not numeric. The remaining
    files are still loaded.

    Args:
        directory: Directory containing the annotation files.
        data_list: List the ``[w, h]`` pairs are appended to, in place.

    Returns:
        ``data_list``, or ``None`` if ``directory`` does not exist.
    """
    if not os.path.exists(directory):
        print("Path {} doesn't exist".format(directory))
        return
    # Only regular files. open() would fail on sub-directories.
    files = [name for name in os.listdir(directory)
             if os.path.isfile(os.path.join(directory, name))]
    print('Loading data...')
    for index, fname in enumerate(files):
        _print_progress(index + 1, len(files))
        with open(os.path.join(directory, fname), 'r') as ann:
            # split() without a separator tolerates repeated spaces and tabs.
            # An empty line yields [] instead of [''].
            tokens = ann.readline().split()
        try:
            values = [float(tok) for tok in tokens]
        except ValueError:
            sys.stderr.write('File {} contains non-numeric annotations\n'.format(fname))
            continue
        if len(values) % 5 != 0:
            # Skip only this file so the rest of the directory is still loaded.
            sys.stderr.write('File {} contains incorrect number of annotations\n'.format(fname))
            continue
        # Integer division: '/' gives a float on Python 3, which range() rejects.
        for obj in range(len(values) // 5):
            xmin, ymin, xmax, ymax = values[obj * 5:obj * 5 + 4]
            w = xmax - xmin
            h = ymax - ymin
            data_list.append([w, h])
            if w > 1000 or h > 1000:
                sys.stdout.write("[{}, {}]".format(w, h))
    sys.stdout.write('\nProcessed {} files containing {} objects'.format(len(files), len(data_list)))
    return data_list
def main():
    """Parse command-line arguments, load the annotations and run K-Means on them.

    Exits with an error message (status 1) when no boxes could be collected
    from the given directory.
    """
    # The text must be passed as ``description``. A bare first positional
    # argument of ArgumentParser is ``prog`` and would replace the program
    # name in the usage line.
    parser = argparse.ArgumentParser(description="Parse hyperparameters")
    parser.add_argument("clusters", help="Number of clusters", type=int)
    parser.add_argument("dir", help="Directory containing annotations")
    parser.add_argument("image_file", help="File to generate the final cluster of image")
    # type=int lets argparse reject bad values with a usage message.
    parser.add_argument('-jobs', help="Number of jobs for parallel computation", type=int, default=1)
    parser.add_argument('-iter', help="Max Iterations to run algorithm for", type=int, default=1000)
    p = parser.parse_args(sys.argv[1:])

    data_list = []
    pre_process(p.dir, data_list)
    if not data_list:
        # A missing directory or no usable boxes: clustering would only fail
        # later with a less helpful error.
        sys.exit('No annotations collected from {}'.format(p.dir))
    sys.stdout.write('\nDone collecting data\n')
    k_means(p.clusters, data_list, p.iter, p.jobs, p.image_file)
    print('Done !')
if __name__ == '__main__':
    try:
        main()
    except Exception as E:
        # Report on stderr and exit non-zero so calling shells and scripts can
        # tell the run failed. Printing to stdout would still exit with 0.
        sys.stderr.write('{}\n'.format(E))
        sys.exit(1)