#!/usr/bin/env python3.11

import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
#--------------------------------------------------------------------------------------------------
def get_data():
    """Return a small hard-coded fruit dataset for K-NN demos.

    Each sample is [weight in grams, texture on a 1-10 scale].

    Returns:
        (features_train, classes_train, features_test, classes_test,
         feature_names, class_names) — class_names is ordered by label
        (0 -> Apple, 1 -> Orange).
    """
#----Training samples: three apples followed by three oranges
    apples = [[140, 8], [150, 7], [130, 9]]
    oranges = [[200, 2], [220, 3], [210, 1]]
    features_train = apples + oranges
    classes_train = ['Apple'] * len(apples) + ['Orange'] * len(oranges)

#----Held-out samples paired with their known classes
    features_test = [[170, 8], [190, 4]]
    classes_test = ['Apple', 'Orange']

#----Human-readable names; class_names must follow label order
    feature_names = ['weight in grams', 'texture']
    class_names = ['Apple', 'Orange']

    return (features_train, classes_train, features_test, classes_test,
            feature_names, class_names)
#--------------------------------------------------------------------------------------------------
def get_data_from_csv(file_name):
    """Load a K-NN dataset from a CSV file with a header row.

    Every column except the last is a numeric feature; the last column
    holds each row's class label.

    Args:
        file_name: path to a comma-separated file whose first row names
            the columns.

    Returns:
        (features, classes, feature_names, class_names):
        features — list of per-row feature tuples;
        classes — per-row label strings;
        feature_names — header names of the feature columns;
        class_names — sorted unique labels.
    """
#----names=True reads the header row into dtype.names. encoding='utf-8'
#----is required: the legacy default ('bytes') returns byte strings, so
#----str(item) below would yield "b'Apple'" and class_names would be a
#----list of bytes. ndmin=1 keeps a one-row file from collapsing to 0-d.
    data = np.genfromtxt(file_name, delimiter=',', names=True,
                         filling_values=0, dtype=None, encoding='utf-8',
                         ndmin=1)
#----Extract all except the last column name as feature names
    feature_names = list(data.dtype.names[:-1])
    print(f"Feature names {feature_names}")
#----Extract unique values from the last column as class names (sorted)
    class_names = np.unique(data[data.dtype.names[-1]]).tolist()

#----Extract all except the last column as features (list of row tuples)
    features = data[feature_names].tolist()
#----Extract last column as classes; str() also normalizes numeric labels
    classes = [str(item) for item in data[data.dtype.names[-1]]]

    return features, classes, feature_names, class_names
#--------------------------------------------------------------------------------------------------
def get_predictions(K,features_train,classes_train,features_test):
    """Fit a K-NN classifier and classify the test samples.

    Args:
        K: number of nearest neighbours to consult.
        features_train: training feature rows.
        classes_train: training labels, aligned with features_train.
        features_test: feature rows to classify.

    Returns:
        (knn, predictions): the fitted classifier and its predicted
        labels for features_test.
    """
#----Build the model, train it, and classify the test rows
    model = KNeighborsClassifier(n_neighbors=K)
    model.fit(features_train, classes_train)
    return model, model.predict(features_test)
#--------------------------------------------------------------------------------------------------
def print_predictions(knn,predictions,features_test,classes_test):
    """Print each prediction with its confidence, then overall accuracy.

    Args:
        knn: fitted classifier exposing predict_proba().
        predictions: predicted labels for features_test.
        features_test: the feature rows that were classified.
        classes_test: the true labels, aligned with predictions.
    """
#----Per-class probabilities; the largest per row is the confidence
    probabilities = knn.predict_proba(features_test)
    confidences = np.max(probabilities, axis=1)
#----Report each test sample's true class, prediction, and confidence
    for truth, guess, confidence in zip(classes_test, predictions, confidences):
        print(f"The {truth} is classified as: {guess}")
        print(f"Confidence is {confidence}")

#----Fraction of test samples classified correctly
    accuracy = accuracy_score(classes_test, predictions)
    print(f"The accuracy is {accuracy}")
#--------------------------------------------------------------------------------------------------
def main():
    """Entry point: load train/test CSVs, fit a 3-NN model, report results."""
#----Load the training and test sets from their CSV files
    features_train, classes_train, feature_names, class_names = \
        get_data_from_csv("KNN-fruit_train.csv")
    features_test, classes_test, *_ = get_data_from_csv("KNN-fruit_test.csv")

#----Classify the test set using the 3 nearest neighbours (k=3)
    knn, predictions = get_predictions(3, features_train, classes_train,
                                       features_test)

#----Report per-sample predictions and overall accuracy
    print_predictions(knn, predictions, features_test, classes_test)
#--------------------------------------------------------------------------------------------------
#----Run the demo only when executed as a script, not when imported
if __name__ == "__main__":
    main()
#--------------------------------------------------------------------------------------------------
