#!/usr/bin/env python3.11

#----K-means Clustering for Customer Segmentation Analysis
#----Assume you work in the credit card department of a bank. Your job is to understand the
#----behaviors of the customers (credit card holders) and improve marketing strategies. To do
#----that, you need to categorize the customers based on their characteristics (income, age,
#----buying behavior, etc).
#----Find the clusters/groups that contain valuable customers: e.g., high income but low annual
#----spend
#----The dataset consists of annual income (divided by $1000) and total spend (divided by $1000) of
#----~300 people for a period of one year.  https://github.com/sowmyacr/kmeans_cluster
#----Each row is a feature vector of a customer
#--------------------------------------------------------------------------------------------------
def warn(*args, **kwargs):
    """Accept any arguments, ignore them all, and return None.

    This stub is assigned to ``warnings.warn`` right after ``warnings`` is imported in this
    file, so warnings raised through that function are silently discarded.
    """
    return None
import warnings
warnings.warn = warn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#Plot styling
import seaborn as sns; sns.set()  
from sklearn.cluster import KMeans
#--------------------------------------------------------------------------------------------------
def get_data(filename='KMeans-Customers.csv'):
    """Load the customer table, show two exploratory plots, and return it as a numpy array.

    Two figures are produced in turn: overlaid histograms of the 'income' and 'spend'
    columns, then a scatter plot of 'spend' against 'income' with one dot per row.

    Args:
        filename: Path of the CSV file to read. The file must contain 'income' and
            'spend' columns. Defaults to the file this script was written for.

    Returns:
        A numpy array holding every column of the CSV, one row per CSV row.
    """
    #----Read the dataset as a pandas dataframe (a table with one row per customer).
    #----For a quick look at it, print dataset.head(), len(dataset) or dataset.describe() here.
    dataset = pd.read_csv(filename)

    #----Show the distribution/histogram of income and spending on the same axes.
    #----alpha=0.5 makes the bars semi-transparent so both histograms stay visible.
    _, plot_area = plt.subplots()
    plot_area.hist(dataset['income'], color='red', alpha=0.5)
    plot_area.hist(dataset['spend'], color='blue', alpha=0.5)
    plot_area.set_title('spend(blue), income(red)', fontsize=16)
    plt.draw()
    plt.show()
    plt.close()

    #----Each dot in the figure corresponds to a customer, i.e. to the (income, spend)
    #----feature vector stored in one row of the table.
    _, plot_area = plt.subplots()
    plot_area.scatter(dataset['income'], dataset['spend'], c='blue', s=10)
    plot_area.set_title('data points', fontsize=20)
    plot_area.set_xlabel('income', fontsize=20)
    plot_area.set_ylabel('spend', fontsize=20)
    plt.draw()
    plt.show()
    plt.close()

    #----Hand a plain numpy array (not the dataframe) to the scikit-learn code downstream.
    return dataset.to_numpy()
#--------------------------------------------------------------------------------------------------
def plot_clusters(number_of_clusters, data, centers, label):
    """Show the points of every cluster, together with the cluster centers, in one figure.

    Args:
        number_of_clusters: Number of clusters to draw; cluster numbers run from 0 to
            number_of_clusters - 1.
        data: 2-D array of samples. Column 0 is drawn on the x axis ('income') and
            column 1 on the y axis ('spend').
        centers: 2-D array whose row k is the center of cluster k, with the same column
            order as data.
        label: 1-D array giving the cluster number of each row of data.
    """
    #----Accept lists as well as numpy arrays: the boolean masks and [row, column]
    #----indexing below only work on numpy arrays.
    data = np.asarray(data)
    centers = np.asarray(centers)
    label = np.asarray(label)

    color = ['red', 'blue', 'green', 'purple', 'cyan', 'orange']
    _, plot_area = plt.subplots(figsize=(10, 6))
    for cluster_number in range(number_of_clusters):
        #----Cycle through the palette so that more than len(color) clusters reuse colors
        #----instead of raising IndexError.
        cluster_color = color[cluster_number % len(color)]
        members = label == cluster_number
        plot_area.scatter(data[members, 0], data[members, 1],
                          c=cluster_color, s=20, label=f'Cluster {cluster_number}')
        #----The center is drawn as a larger square in the color of its cluster.
        plot_area.plot(centers[cluster_number, 0], centers[cluster_number, 1],
                       c=cluster_color, marker='s', markersize=10)
    plot_area.set_title('Clusters', fontsize=20)
    plot_area.set_xlabel('income', fontsize=20)
    plot_area.set_ylabel('spend', fontsize=20)
    #----NOTE(review): the x range is fixed to 100..500; points outside it are not visible.
    plot_area.set_xlim(100, 500)
    plot_area.legend(fontsize=20)
    #----Same scale on both axes, so distances in the figure are not distorted.
    plot_area.set_aspect('equal')
    plt.draw()
    plt.show()
    plt.close()
#--------------------------------------------------------------------------------------------------
def do_cluster(number_of_clusters,data):
    """Cluster ``data`` with k-means and display the result.

    Args:
        number_of_clusters: Number of clusters requested from KMeans.
        data: Samples to cluster, one row per sample. The same array is handed to
            plot_clusters for drawing.

    Returns:
        None. The labels and centers are only passed on to plot_clusters.
    """
    #----Build the k-means model and fit it to the data; fit() returns the fitted model.
    #----random_state=0 keeps the seed fixed from run to run.
    kmeans = KMeans(n_clusters=number_of_clusters, random_state=0).fit(data)

    #----One cluster label per row of data ...
    assignments = kmeans.predict(data)
    #----... and one center per cluster.
    centroids = kmeans.cluster_centers_

    #----Plot the clusters together with their centers
    plot_clusters(number_of_clusters, data, centroids, assignments)
#--------------------------------------------------------------------------------------------------
def plot_losses(data, max_clusters=20):
    """Plot the k-means clustering error against the number of clusters (elbow method).

    K-means is fitted once for every cluster count from 1 up to max_clusters, and the
    model's ``inertia_`` is recorded as the error/loss of that clustering. The resulting
    curve is shown so that the user can pick the number of clusters after which adding
    another one no longer reduces the loss noticeably.

    Args:
        data: Samples to cluster, one row per sample.
        max_clusters: Largest number of clusters to try. Defaults to 20. The value is
            capped at the number of rows in data.
    """
    #----KMeans cannot build more clusters than there are samples, so never ask for more.
    largest_k = min(max_clusters, len(data))
    cluster_counts = range(1, largest_k + 1)

    #----Same fixed seed for every fit, so the curve is comparable from run to run.
    error_list = [
        KMeans(n_clusters=K, random_state=0).fit(data).inertia_
        for K in cluster_counts
    ]

    _, plot_area = plt.subplots(figsize=(10, 6))
    plot_area.plot(cluster_counts, error_list)
    plot_area.set_title('Elbow Method', fontsize=20)
    plot_area.set_xlabel('Number of clusters', fontsize=20)
    plot_area.set_ylabel('clustering Error/Loss', fontsize=20)
    #----One tick per cluster count that was tried.
    plot_area.set_xticks(cluster_counts)
    plt.draw()
    plt.show()
    plt.close()
#--------------------------------------------------------------------------------------------------
def main():
    """Run the demo: load the data, cluster with k=3, draw the elbow curve, cluster with k=6.

    Once the last plot is done, all figures are closed and the script waits for Enter.
    """
    samples = get_data()

    #----First pass with three clusters. Original author's reading of that plot:
    #----  Cluster-0: the customers have low annual income
    #----  Cluster-1: the customers have medium annual income
    #----  Cluster-2: the customers have high annual income
    #----NOTE(review): the cluster numbers are assigned by KMeans; confirm against the figure.
    do_cluster(number_of_clusters=3, data=samples)

    #----Elbow method to choose the number of clusters: pick the count after which adding
    #----one more cluster does not significantly decrease the clustering error/loss.
    #----The choice is subjective because it depends on the judgement of the user.
    plot_losses(samples)

    #----The original author judged that 6 clusters might be enough.
    do_cluster(number_of_clusters=6, data=samples)

    plt.close('all')
    input("Press Enter to end")
#--------------------------------------------------------------------------------------------------
#----Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
#--------------------------------------------------------------------------------------------------
