#!/usr/bin/env python3.11

import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk import FreqDist
from nltk.collocations import *
import collections
import string
from urllib import request
from IPython.display import SVG, display
from nltk.tree import Tree
from nltk.text import ConcordanceIndex

# Fetch every NLTK data package the pipeline below needs (tokenizers,
# taggers, chunkers, stopword list, WordNet, and tag documentation).
for _package in (
    "stopwords",
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
    "maxent_ne_chunker",
    "maxent_ne_chunker_tab",
    "words",
    "punkt_tab",
    "wordnet",
    "tagsets_json",
):
    nltk.download(_package)
#--------------------------------------------------------------------------------------------------
def get_text(which):
    """Return a text to analyse.

    Parameters
    ----------
    which : str
        "sample" for a short built-in paragraph, "corpus" to download a
        Project Gutenberg book over HTTP, anything else for "".

    Returns
    -------
    str
        The selected text (empty string for an unrecognised selector).
    """
    if which == "sample":
        text = "\
Winston Churchill learned rapidly because his first training was in how to learn. \
And the first lessons of all were the basic trust that he could learn. \
It's shocking that 35% of people in America hate to learn, \
and how many more believe learning to be impossible."
    elif which == "corpus":
        URL = "https://www.gutenberg.org/cache/epub/729/pg729.txt"
        print(f"Reading the book at {URL}")
        # Use a context manager so the HTTP response (and its socket) is
        # closed promptly instead of being leaked.
        with request.urlopen(URL) as response:
            text = response.read().decode('utf8')
    else:
        text = ""

    return text
#--------------------------------------------------------------------------------------------------
def clean_text(text):
    """Tokenize *text* and return progressively cleaned word lists.

    Parameters
    ----------
    text : str
        Raw text to process.

    Returns
    -------
    tuple
        (sentences, words, filtered_words, lemmatized_words):
        sentence tokens, word tokens, word tokens with English stopwords
        and punctuation-only tokens removed, and the WordNet-lemmatized
        form of the filtered tokens.
    """
    sentences = sent_tokenize(text)
    words = word_tokenize(text)

    stop_words = set(stopwords.words("english"))
    # Keep tokens that are not stopwords (compared case-insensitively) and
    # whose first character is alphanumeric, which drops pure punctuation
    # tokens produced by word_tokenize.
    filtered_words = [
        word for word in words
        if word.casefold() not in stop_words and word[0].isalnum()
    ]

    lemmatizer = WordNetLemmatizer()
    lemmatized_words = [lemmatizer.lemmatize(word) for word in filtered_words]

    return sentences, words, filtered_words, lemmatized_words
#--------------------------------------------------------------------------------------------------
def statistical_analysis(words, top_size, do_plots, min_bigram_freq=3):
    """Print frequency, collocation, and n-gram statistics for *words*.

    Parameters
    ----------
    words : list[str]
        Tokens to analyse (typically the lemmatized, filtered words).
    top_size : int
        How many top items to report for each statistic.
    do_plots : bool
        When True, also draw the frequency-distribution and dispersion plots.
    min_bigram_freq : int, optional
        Minimum number of occurrences for a bigram to be ranked (default 3).

    Notes
    -----
    NLTK's ``collocations`` and ``concordance`` print directly and return
    None, so their output cannot be captured and returned.
    """
    # Guard: every statistic below needs at least one token, and
    # top_words[0][0] would raise IndexError on an empty input.
    if not words:
        print("\nNo words to analyse.")
        return

#----Make into a corpus for analysis
    corpus_text = nltk.Text(words)
    frequency_distribution = FreqDist(corpus_text)
    top_words = frequency_distribution.most_common(top_size)
    if do_plots:
        print(f"\nThe frequency distribution for the top {top_size} words is plotted:\n{top_words}")
        frequency_distribution.plot(top_size, cumulative=True)

#----Analyse how words relate. These NLTK methods print as they analyse and
#----return None, so there is nothing useful to assign or return.
    print(f"\nThe collocations are:")
    corpus_text.collocations()
    print(f"\nThe concordance for the top word \"{top_words[0][0]}\" is:")
    corpus_text.concordance(top_words[0][0])
    if do_plots:
        plain_top_words = [row[0] for row in top_words]
        print(f"\nThe dispersion of the top {top_size} words is plotted:\n")
        corpus_text.dispersion_plot(plain_top_words[:top_size])

#----Extract bigrams and trigrams
    bigram_measures = nltk.collocations.BigramAssocMeasures()
    bigram_finder = BigramCollocationFinder.from_words(words)
    bigram_finder.apply_freq_filter(min_bigram_freq)
    top_bigrams = bigram_finder.nbest(bigram_measures.pmi, top_size)
    print(f"The top {top_size} bigrams with at least {min_bigram_freq} occurrences are:\n{top_bigrams}")

    trigram_measures = nltk.collocations.TrigramAssocMeasures()
    trigram_finder = TrigramCollocationFinder.from_words(words)
    top_trigrams = sorted(trigram_finder.nbest(trigram_measures.raw_freq, top_size))
    print(f"\nThe top {top_size} trigrams are:\n{top_trigrams}")
#--------------------------------------------------------------------------------------------------
def grammar_analysis(words):
    """POS-tag *words* and derive chunked, chinked, and named-entity trees.

    Returns the tuple
    (tagged_words, chunked_tree, chinked_tree, named_entities_tree).
    """
    tagged_words = nltk.pos_tag(words)
# Labtask: Use the tags to specify how the lemmatizing should work.

    # Chunk: a noun phrase is an optional determiner, any number of
    # adjectives, then a noun.
    noun_phrase_grammar = "NP: {<DT>?<JJ>*<NN>}"
    chunked_tree = nltk.RegexpParser(noun_phrase_grammar).parse(tagged_words)

    # Chink: chunk everything, then carve the adjectives back out.
    chink_grammar = "Chunk: {<.*>+}\n}<JJ>{"
    chinked_tree = nltk.RegexpParser(chink_grammar).parse(tagged_words)

    # Named-entity recognition over the same tagged tokens.
    named_entities_tree = nltk.ne_chunk(tagged_words)

    return tagged_words, chunked_tree, chinked_tree, named_entities_tree
#--------------------------------------------------------------------------------------------------
def display_tree(tree):
    """Render *tree* as inline SVG, falling back to ASCII pretty-printing
    when the tree object does not provide SVG data."""
    try:
        # Modern NLTK trees expose their SVG representation directly,
        # which Jupyter/Colab can display inline.
        display(SVG(tree._repr_svg_()))
    except AttributeError:
        print("SVG not supported, falling back to ASCII:")
        tree.pretty_print()
#--------------------------------------------------------------------------------------------------
def main():
    """Drive the pipeline: fetch a text, clean it, then run the
    statistical and grammatical analyses and display the results."""
    top_size = 20
    do_plots = True

    text = get_text("corpus")

    sentences, words, filtered_words, lemmatized_words = clean_text(text)
    print(f"\nThe words are:\n{words[:top_size]}")
    print(f"\nThe filtered words are:\n{filtered_words[:top_size]}")
    print(f"\nThe lemmatized filtered words are:\n{lemmatized_words[:top_size]}")
    print(f"\nThe sentences are:\n{sentences[:top_size]}")

#----The NLTK analysis routines print as they run, so their results appear
#----directly in the output rather than being returned.
    statistical_analysis(lemmatized_words, top_size, do_plots)

    tagged_words, chunked_tree, chinked_tree, named_entities_tree = grammar_analysis(filtered_words)
    print(f"\nThe tagged words are:\n{tagged_words[:top_size]}")
    if do_plots:
        labelled_trees = (
            ("chunked tree", chunked_tree),
            ("chinked_tree", chinked_tree),
            ("named entities tree", named_entities_tree),
        )
        for label, tree in labelled_trees:
            print(f"\nThe {label} is:")
            display_tree(tree)

    print("\n")
#--------------------------------------------------------------------------------------------------
# Run the full pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()
#--------------------------------------------------------------------------------------------------
