#!/usr/bin/env python3.11

import warnings


def warn(*args, **kwargs):
    """Accept any positional/keyword arguments and do nothing.

    Used as a stand-in for ``warnings.warn`` so that calls routed through
    that attribute produce no output.
    """
    return None


# Replace the module attribute so that later ``warnings.warn(...)`` calls hit
# the no-op above. This runs before the libraries below are imported.
warnings.warn = warn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans

#----Load the digit dataset (taken from https://huggingface.co/datasets/ylecun/mnist)
features_train = np.load('KMeans-Digits-Train.npy')
features_test = np.load('KMeans-Digits-Test.npy')

#----Preview the data: show every 1000th training image in the same small window
# Interactive mode is switched on for the preview so the loop can keep
# redrawing the same axes, pausing briefly on each image.
plt.ion()
figure, plot_area = plt.subplots(figsize=(3, 3))
image_count = features_train.shape[0]
for row in range(0, image_count, 1000):
    # Each row is reshaped to a 28x28 image for display.
    digit_image = features_train[row].reshape(28, 28)
    plot_area.imshow(digit_image, cmap='gray')
    # NOTE(review): the label is read from features_test using the index of a
    # *training* image. Presumably that file holds one label per training
    # image -- confirm against the data files.
    plot_area.set_title(f"label: {int(features_test[row])}", fontsize=16)
    plot_area.axis('off')
    plt.pause(0.1)
    plt.show()
plt.close()
plt.ioff()

#----Apply K-means to divide the digit images into 10 clusters and find the cluster centers
# Only every 30th training image is used for the fit (the original note marks
# this step as slow).
training_subset = features_train[::30]
kmeans = KMeans(n_clusters=10, random_state=0)
# fit_predict fits the model and returns the cluster index of each subset image.
clusters = kmeans.fit_predict(training_subset)
#----Faster alternative: MiniBatchKMeans fitted on the full training set
# kmeans = MiniBatchKMeans(n_clusters=10, batch_size=100, random_state=0)
# clusters = kmeans.fit_predict(features_train)
# print(f"{kmeans.cluster_centers_.shape}")

#----Reshape into 10 28x28 images
centers = kmeans.cluster_centers_.reshape(10, 28, 28)

#----Plot the 10 shapes, 2 rows, 5 per row
figure, plot_area = plt.subplots(nrows=2, ncols=5, figsize=(8, 3))
for panel, center_image in zip(plot_area.flat, centers):
    # Remove the tick marks so only the cluster-center image is visible.
    panel.set_xticks([])
    panel.set_yticks([])
    panel.imshow(center_image, cmap=plt.cm.binary, interpolation='nearest')
    plt.draw()
plt.show()
plt.close()

# Close any figures that are still open, then wait for the user before exiting.
plt.close('all')
input("Press Enter to end")
