# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import warnings                         # Disable some warnings
warnings.filterwarnings("ignore",category=DeprecationWarning)
import dataiku
from dataiku import pandasutils as pdu
import pandas as pd,  seaborn as sns
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.feature_extraction import text
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import LatentDirichletAllocation
import pyLDAvis.sklearn
from dataiku import insights
import joblib
from nltk.stem import PorterStemmer
import regex as re
from wordcloud import WordCloud
import io
from dku_utils.core import get_current_project_and_variables
from dku_utils.folders.pickles.folder_pickles import read_pickle_from_managed_folder

# Resolve the current project and its variables through the dku_utils helper;
# `project` is handed to the pickle reader in the loading loop.
project, variables = get_current_project_and_variables()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Folder whose paths are enumerated and loaded, one record per path.
keyword = dataiku.Folder("WSwcnOo2")

paths = keyword.list_paths_in_partition()
all_data = []

# NOTE: keep this as an explicit loop - `path` stays bound at module level
# once the loop finishes, and code further down the file may read that name.
# NOTE(review): paths are listed from folder id "WSwcnOo2" but the reader is
# given the folder name 'category' - presumably the same folder; confirm.
# NOTE(review): unpickling can execute arbitrary code; presumably this folder
# is only written by trusted flow steps - confirm.
for path in paths:
    all_data.append(read_pickle_from_managed_folder(project, 'category', path))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Key under which each loaded record exposes its raw text.
raw_text_col = "key_word_category_text"

# stem -> {original word: occurrence count}; filled in while cleaning the
# documents and used afterwards to map a stem back to a readable word.
token_word = {}

def clean_text(text):
    """Return a stemmed, space-separated version of ``text``.

    The first 20% of the characters are dropped, runs of digits are removed,
    and only whitespace-separated words of at least four characters are kept.
    Each kept word is Porter-stemmed, and the module-level ``token_word``
    mapping is updated with how often every original word produced each stem.

    Args:
        text: The raw document. ``None`` is treated as an empty document.

    Returns:
        The stems joined by single spaces. ``""`` when ``text`` is ``None``.
        If cleaning fails for any other reason, the exception is printed and
        the text as it stood at that point is returned (best effort).
    """
    if text is None:
        # A missing document cannot be cleaned. Returning an empty string
        # (instead of None) keeps one entry per input and stays a valid
        # document for the vectorizers and for slicing.
        print("clean_text: received None, using an empty document")
        return ""
    try:
        # Skip the leading fifth of the document before tokenizing.
        text = text[int(len(text) * .20):]
        stemmer = PorterStemmer()
        tokens = []
        # Strip digit runs, then split on any run of whitespace.
        for word in re.sub(r'\d+', '', text).split():
            if len(word) < 4:
                continue
            stemmed_word = stemmer.stem(word)
            tokens.append(stemmed_word)
            # Count which surface form produced this stem.
            surface_counts = token_word.setdefault(stemmed_word, {})
            surface_counts[word] = surface_counts.get(word, 0) + 1
        return " ".join(tokens)
    except Exception as e:
        # Deliberate best effort: report the problem and keep the document.
        print(e)
        return text

# Clean the raw text of every loaded record.
# NOTE: this rebinds the name `clean_text` from the function to the list of
# cleaned documents, so the function can no longer be called after this line;
# the rest of the file uses `clean_text` as the corpus.
clean_text = [clean_text(data.get(raw_text_col)) for data in all_data]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# For every stem, keep the original word that produced it most often.
top_word_token = {
    stem: max(surface_counts, key=surface_counts.get)
    for stem, surface_counts in token_word.items()
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Extra words to drop on top of scikit-learn's built-in English list.
# NOTE(review): the documents are Porter-stemmed before being vectorized, so
# an unstemmed entry such as 'filing' presumably never matches a token -
# confirm whether these entries should be stemmed as well.
custom_stop_words = [u'filing', u'bank', u'congress']

# union() returns a frozenset, but the vectorizers document `stop_words` as
# 'english', a list or None, and recent scikit-learn releases validate that
# type - so hand them a plain list.
stop_words = list(text.ENGLISH_STOP_WORDS.union(custom_stop_words))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Raw term counts over the cleaned corpus (alphabetic tokens of 3+ letters).
n_docs = len(clean_text)
cnt_vectorizer = CountVectorizer(
    strip_accents = 'unicode',
    stop_words = stop_words,
    lowercase = True,
    token_pattern = r'\b[a-zA-Z]{3,}\b',
    # A fractional max_df on one or two documents allows fewer documents than
    # min_df requires (0.85 * 2 = 1.7 < 2) and makes fit_transform raise, so
    # tiny corpora keep every term regardless of document frequency.
    max_df = 1.0 if n_docs <= 2 else 0.85,
    # Require a term in two documents, unless there are fewer documents.
    min_df = min(n_docs, 2),
)

text_cnt = cnt_vectorizer.fit_transform(clean_text)


# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# TF-IDF weights over the cleaned corpus (alphabetic tokens of 3+ letters).
n_docs = len(clean_text)
tfidf_vectorizer = TfidfVectorizer(
    strip_accents = 'unicode',
    stop_words = stop_words,
    lowercase = True,
    token_pattern = r'\b[a-zA-Z]{3,}\b',
    # A fractional max_df on one or two documents allows fewer documents than
    # min_df requires (0.75 * 2 = 1.5 < 2) and makes fit_transform raise, so
    # tiny corpora keep every term regardless of document frequency.
    max_df = 1.0 if n_docs <= 2 else 0.75,
    # Require a term in two documents, unless there are fewer documents.
    min_df = min(n_docs, 2),
)

text_tfidf = tfidf_vectorizer.fit_transform(clean_text)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Number of topics the model extracts.
# NOTE(review): top_documents_topics() hard-codes display names for topics
# 0-2 only; keep that mapping in sync if this value changes.
n_topics = 3

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# Use this line for LDA

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# LDA model with n_topics topics and a fixed random_state of 0.
# NOTE(review): the topic count is passed positionally; the keyword name of
# this first parameter has presumably differed between scikit-learn releases,
# so confirm against the installed version before switching to a keyword.
topics_model = LatentDirichletAllocation(n_topics, random_state=0)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Fit the model on the TF-IDF matrix (not on the raw counts in text_cnt).
topics_model.fit(text_tfidf)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# How many terms to list per topic.
n_top_words = 20

# get_feature_names() was removed in scikit-learn 1.2 in favour of
# get_feature_names_out(); use the replacement whenever it is available and
# fall back to the old accessor on older releases.
if hasattr(tfidf_vectorizer, "get_feature_names_out"):
    feature_names = list(tfidf_vectorizer.get_feature_names_out())
else:
    feature_names = tfidf_vectorizer.get_feature_names()

def get_top_words_topic(topic_idx):
    """Print the ``n_top_words`` highest-weighted terms of topic ``topic_idx``.

    Args:
        topic_idx: Row index into ``topics_model.components_``.
    """
    topic = topics_model.components_[topic_idx]

    # argsort() is ascending, so walk it backwards to list the heaviest
    # terms first.
    top_indices = topic.argsort()[:-n_top_words - 1:-1]
    print([feature_names[i] for i in top_indices])

# Dump the top terms of every fitted topic.
for topic_idx in range(len(topics_model.components_)):
    print("Topic #%d:" % topic_idx)
    get_top_words_topic(topic_idx)
    print("")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build the pyLDAvis view from the fitted model, the TF-IDF matrix and its
# vectorizer, render it to HTML and store it as the 'pyldavis' insight.
viz = pyLDAvis.sklearn.prepare(topics_model, text_tfidf, tfidf_vectorizer)
viz_html = pyLDAvis.prepared_data_to_html(viz)
insights.save_data('pyldavis', viz_html, 'text/html')

# Handle on the output folder; the word-cloud images produced by
# draw_word_cloud() are uploaded to it. No model is written here.
topic_modeling_insights = dataiku.Folder("IKH0PQKV")
# Folder metadata; not read again in the visible part of this file.
topic_modeling_insights_info = topic_modeling_insights.get_info()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def top_documents_topics(topic_name, n_doc = 3, excerpt = True):
    '''Print the n_doc documents that score highest on topic_name.

    topic_name is a column label of the document-topic matrix: one of the
    display names mapped below for topics 0-2, or the integer index of any
    topic without a display name. When excerpt is true only the first 1000
    characters of each cleaned document are printed, otherwise all of it.
    '''
    # Display names for the first three topics; adapt to your own model.
    topic_labels = {0:"Social", 1:"Environmental", 2:'Governance'}

    # Document-topic matrix: one row per document, one column per topic.
    doc_topic = pd.DataFrame(topics_model.transform(text_tfidf))
    doc_topic.columns.name = 'topic'
    doc_topic.rename(columns = topic_labels, inplace = True)

    # Row labels ordered from the strongest to the weakest score.
    ranked_scores = doc_topic[topic_name].sort_values(ascending = False)
    best_documents = list(ranked_scores.index)[:n_doc]

    for rank, doc_id in enumerate(best_documents, 1):
        print("Text for the {}-th most representative document for topic {}:\n".format(rank, topic_name))
        document = clean_text[doc_id]
        print(document[:1000] if excerpt else document)
        print("\n******\n")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Vocabulary of the TF-IDF vectorizer; it is zipped with the rows of
# topics_model.components_ to find the most important words of each topic.
# get_feature_names() was removed in scikit-learn 1.2 in favour of
# get_feature_names_out(); use the replacement whenever it is available.
if hasattr(tfidf_vectorizer, "get_feature_names_out"):
    vocab = list(tfidf_vectorizer.get_feature_names_out())
else:
    vocab = tfidf_vectorizer.get_feature_names()

# Generate a word cloud image for given topic
def draw_word_cloud(index, path = None, folder = topic_modeling_insights):
    """Render the 20 heaviest terms of topic ``index`` as a word cloud.

    The image is uploaded to ``folder`` as ``word_cloud_<index>.jpg`` and the
    current matplotlib figure is saved as the insight
    ``word_cloud_topic_<index>``.

    Args:
        index: Row index into ``topics_model.components_``.
        path: Unused; kept so existing positional calls keep working. It no
            longer defaults to a loop variable leaked from earlier code.
        folder: Folder object receiving the JPEG (defaults to the
            ``topic_modeling_insights`` folder).
    """
    weights = topics_model.components_[index]
    top_terms = sorted(zip(vocab, weights), key=lambda pair: pair[1], reverse=True)[:20]

    # Show the most frequent original word for a stem when one was recorded,
    # otherwise fall back to the vocabulary term itself.
    imp_words_topic = " ".join(
        top_word_token.get(term, term) for term, _weight in top_terms
    )

    wordcloud = WordCloud(width=600, height=400, background_color = "white").generate(imp_words_topic)
    plt.figure()
    plt.axis("off")
    plt.imshow(wordcloud)

    # Render into memory so the bytes can be pushed straight to the folder.
    bs = io.BytesIO()
    plt.savefig(bs, format="jpg")
    folder.upload_stream('word_cloud_' + str(index) + ".jpg", bs.getvalue())

    # NOTE(review): the figure is left open because save_figure() is called
    # without an explicit figure and presumably picks up the current one;
    # consider closing it afterwards if many topics are drawn - confirm.
    insight_id = "word_cloud_topic_" + str(index)
    dataiku.insights.save_figure(insight_id)

# Draw and publish one word cloud per fitted topic.
for topic_idx in range(len(topics_model.components_)):
    draw_word_cloud(topic_idx, folder = topic_modeling_insights)