#!/usr/bin/env python3.11

#----Silence warnings (e.g. from scikit-learn) by replacing warnings.warn with a no-op
import warnings


def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``.

    Accepts and discards whatever arguments ``warnings.warn`` would take, so
    any warning issued through ``warnings.warn`` after the patch below is
    dropped silently.
    """
    return None


warnings.warn = warn
import numpy as np
import time
import matplotlib.pyplot as plt
#--------------------------------------------------------------------------------------------------
def get_data(number_of_samples, sigma=2, random_state=None, show=True):
    """Generate three 2-D Gaussian clusters and optionally plot them.

    Each cluster holds ``number_of_samples`` points drawn from an isotropic
    normal distribution with standard deviation ``sigma``. The centers are
    (4, 4), (1, 1) and (4, 1).

    Parameters
    ----------
    number_of_samples : int
        Number of points per cluster.
    sigma : float, default 2
        Standard deviation of every cluster. Try 0.5 for isolated clusters
        and 2 for overlapping ones.
    random_state : int or None, default None
        Seed for ``np.random.RandomState``. None gives different data on
        every call.
    show : bool, default True
        If True, plot the three groups (red, green, blue) with matplotlib
        before returning.

    Returns
    -------
    ndarray of shape (3 * number_of_samples, 2)
        The groups stacked in order: group 1, group 2, group 3.
    """
    random_number_generator = np.random.RandomState(random_state)
    group_centers = ([4, 4], [1, 1], [4, 1])
    # The comprehension draws the groups in order, so the random stream is
    # consumed exactly as with three separate randn calls.
    clusters = [sigma * random_number_generator.randn(number_of_samples, 2) + center
                for center in group_centers]
    # Put the data in one 2D array.
    all_samples = np.concatenate(clusters, axis=0)

    if show:
        _, plot_area = plt.subplots(figsize=(6, 6))
        for cluster, style in zip(clusters, ('ro', 'go', 'bo')):
            plot_area.plot(cluster[:, 0], cluster[:, 1], style)
        plot_area.set_aspect('equal')
        plot_area.set_title('data points', fontsize=16)
        plt.draw()
        plt.show()
        plt.close()

    return all_samples
#--------------------------------------------------------------------------------------------------
def get_initial_clusters(K, all_samples, random_state=None, show=True):
    """Draw K random initial cluster centers and optionally plot them.

    Centers are sampled uniformly from the square [0, 6) x [0, 6).

    Parameters
    ----------
    K : int
        Number of centers to draw.
    all_samples : ndarray of shape (N, 2)
        Data points. Only used for the plot.
    random_state : int or None, default None
        Seed for ``np.random.RandomState``. None gives different centers on
        every call.
    show : bool, default True
        If True, plot the data (black) with the initial centers (squares).

    Returns
    -------
    ndarray of shape (K, 2)
        The initial centers.
    """
    random_number_generator = np.random.RandomState(random_state)
    # Random centers in the current iteration.
    initial_clusters = 6 * random_number_generator.rand(K, 2)

    if show:
        # Show the initial centers and the data.
        _, plot_area = plt.subplots(figsize=(6, 6))
        plot_area.plot(all_samples[:, 0], all_samples[:, 1], 'ko')
        colors = ['rs', 'bs', 'gs']
        for cluster_number in range(K):
            # Cycle through the styles so K > 3 does not raise IndexError.
            plot_area.plot(initial_clusters[cluster_number, 0],
                           initial_clusters[cluster_number, 1],
                           colors[cluster_number % len(colors)], markersize=10)
        plot_area.set_aspect('equal')
        plot_area.set_title('Initialization', fontsize=20)
        plt.draw()
        plt.show()
        plt.close()

    return initial_clusters
#--------------------------------------------------------------------------------------------------
#----Update the centers and the assignment (cluster labels of data points) in a loop
def iterative_cluster(K, all_samples, initial_cluster_centers, show=True,
                      random_state=None, max_iterations=None):
    """Run k-means (Lloyd's algorithm) starting from the given centers.

    Each iteration does two steps:

    1. Assign each sample to its nearest center (squared Euclidean distance).
    2. Move each center to the mean of the samples assigned to it.

    A center whose cluster ends up empty is re-initialized uniformly at
    random in [0, 6) x [0, 6). The loop stops when either:

    - the sum of squared center displacements between two consecutive
      iterations drops below 1e-10, or
    - ``max_iterations`` update steps have run.

    Parameters
    ----------
    K : int
        Number of clusters. Must equal the number of rows of
        ``initial_cluster_centers``.
    all_samples : ndarray of shape (N, 2)
        Data points.
    initial_cluster_centers : array-like of shape (K, 2)
        Starting centers. They are not modified. They are converted to float
        so integer inputs are not truncated by the mean update.
    show : bool, default True
        If True, animate each iteration in an interactive matplotlib figure
        (1 s pause per frame) and wait for Enter at the end. If False, run
        with no plotting and no prompt.
    random_state : int or None, default None
        Seed for the generator used to re-initialize empty clusters.
    max_iterations : int or None, default None
        Maximum number of update steps. None means run until converged.
        This is a safety valve, because re-seeding an empty cluster can
        keep the centers moving indefinitely.

    Returns
    -------
    list of float
        One value per update step: the sum of squared distances from each
        sample to its cluster's updated center. Empty clusters contribute 0.
    """
    rng = np.random.RandomState(random_state)
    cluster_centers = np.array(initial_cluster_centers, dtype=float)
    cluster_centers_old = cluster_centers.copy()
    loss_list = []
    # Time recorder.
    tic = time.time()
    plot_area = None
    if show:
        # Interactive mode so the figure can be redrawn inside the loop.
        plt.ion()
        _, plot_area = plt.subplots(figsize=(6, 6))

    iteration = 0
    while max_iterations is None or iteration < max_iterations:
        if iteration > 0:
            difference = np.sum((cluster_centers - cluster_centers_old) ** 2)
            # Converged: the centers have (almost) stopped moving.
            if difference < 1e-10:
                break
        # Record the previous centers.
        cluster_centers_old = cluster_centers.copy()

        # Assignment step. ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. The
        # ||x||^2 term is the same for every center, so dropping it does
        # not change the argmin.
        cluster_centers_norm = np.sum(cluster_centers ** 2, axis=1).reshape(1, K)
        dist_sq = cluster_centers_norm - 2 * np.dot(all_samples, cluster_centers.T)
        # label[n] is the cluster label of sample n.
        label = np.argmin(dist_sq, axis=1)

        # Update step.
        loss = 0.0
        for cluster_number in range(K):
            members = all_samples[label == cluster_number]
            if members.shape[0] == 0:
                # No samples were assigned: re-initialize this center at random.
                cluster_centers[cluster_number, :] = 6 * rng.rand(2)
            else:
                cluster_centers[cluster_number, :] = members.mean(axis=0)
                loss += np.sum((members - cluster_centers[cluster_number, :]) ** 2)
        loss_list.append(loss)
        iteration += 1

        if show:
            _plot_clustering_step(plot_area, K, all_samples, label,
                                  cluster_centers, iteration, loss)

    # Stop the clock before the Enter prompt, so the reported time does not
    # include however long the user takes to respond.
    toc = time.time()
    if show:
        input("Press Enter to continue")
        plt.close()
        plt.ioff()
    print('Elapsed time is %f seconds \n' % float(toc - tic))

    return loss_list


def _plot_clustering_step(plot_area, K, all_samples, label, cluster_centers,
                          iteration, loss):
    """Redraw one k-means iteration: samples colored by cluster, plus the centers."""
    point_styles = ['ro', 'bo', 'go']
    center_styles = ['rs', 'bs', 'gs']
    plot_area.clear()
    for cluster_number in range(K):
        members = all_samples[label == cluster_number]
        # Plot a cluster only if it has members. Styles cycle when K > 3.
        if members.shape[0] > 0:
            plot_area.plot(members[:, 0], members[:, 1],
                           point_styles[cluster_number % len(point_styles)])
    for cluster_number in range(K):
        plot_area.plot(cluster_centers[cluster_number, 0],
                       cluster_centers[cluster_number, 1],
                       center_styles[cluster_number % len(center_styles)],
                       markersize=16)
    plot_area.set_title(f"iteration={iteration}, loss={loss:.3f}", fontsize=16)
    plot_area.set_aspect('equal')
    plt.pause(1)
    plt.show()
#--------------------------------------------------------------------------------------------------
#----Plot the loss values over the iterations
def plot_losses(loss_list):
    """Display the recorded loss values as a connected dot plot.

    Parameters
    ----------
    loss_list : sequence of float
        Loss values, plotted against their index (e.g. the list returned by
        ``iterative_cluster``).
    """
    _, axes = plt.subplots()
    axes.plot(loss_list, '.-')
    label_size = 16
    axes.set_title('Loss vs Iteration', fontsize=label_size)
    axes.set_xlabel('iteration', fontsize=label_size)
    axes.set_ylabel('Loss', fontsize=label_size)
    plt.draw()
    plt.show()
    plt.close()
#--------------------------------------------------------------------------------------------------
def main():
    """Run the k-means demo end to end.

    The steps are: generate data, draw initial centers, iterate k-means,
    then plot the loss curve.
    """
    number_of_clusters = 3  # K, the number of clusters
    samples = get_data(100)
    starting_centers = get_initial_clusters(number_of_clusters, samples)
    plot_losses(iterative_cluster(number_of_clusters, samples, starting_centers))

    plt.close('all')
    input("Press Enter to end")
#--------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the interactive demo only when the file is executed as a script.
    main()
#--------------------------------------------------------------------------------------------------
