# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
import regex as re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from finbert.finbert import *
from dku_utils.core import get_current_project_and_variables
from dku_utils.folders.pickles.folder_pickles import read_pickle_from_managed_folder, write_pickle_in_managed_folder

# Project handle and project variables, as returned by the dku_utils helper.
project, variables = get_current_project_and_variables()

# Read recipe inputs
all_documents = dataiku.Folder("cPB8dLy4")
keywords = dataiku.Dataset("keywords")
keywords_df = keywords.get_dataframe()

# Flat, de-duplicated list of every cell of the keywords table; cells whose
# string form is 'nan' (i.e. missing values) are left out.
all_cells = keywords_df.to_numpy().flatten().tolist()
key_word_list = list({cell for cell in all_cells if str(cell) != 'nan'})

# Every column of the keywords table is treated as one category: map the
# column name to the distinct non-null values found in that column.
categories = list(keywords_df.columns)
category_dict = {
    name: list(keywords_df.loc[keywords_df[name].notnull(), name].unique())
    for name in categories
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Pretrained FinBERT checkpoint used for sentiment analysis; it is loaded
# as published, and this script does not fine-tune it.
FINBERT_CHECKPOINT = "ProsusAI/finbert"
tokenizer = AutoTokenizer.from_pretrained(FINBERT_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(FINBERT_CHECKPOINT)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
## Find keyword matches in a text and extract the surrounding context windows

def find_window(text, key_word_list, threshold=400):
    """Collect the text surrounding every occurrence of each keyword.

    Each keyword is handed to ``re.finditer`` unescaped, so it is interpreted
    as a case-insensitive regular-expression pattern. For every match, the
    slice of ``text`` reaching ``threshold`` characters before the match start
    and ``threshold`` characters after the match end is kept, clipped to the
    bounds of ``text``.

    Args:
        text: String to search.
        key_word_list: Iterable of patterns to look for.
        threshold: Number of context characters kept on each side of a match.

    Returns:
        A ``(windows_by_keyword, joined_windows)`` tuple. The first item maps
        every keyword that matched at least once to its list of windows, in
        match order; keywords without a match are omitted. The second item is
        all windows, in the order they were found, joined by single spaces.
    """
    windows_by_keyword = {}
    collected = []
    text_length = len(text)
    for keyword in key_word_list:
        windows = [
            text[max(0, match.start() - threshold):min(match.end() + threshold, text_length)]
            for match in re.finditer(keyword, text, re.IGNORECASE)
        ]
        if not windows:
            # Keywords that never occur get no entry in the result.
            continue
        windows_by_keyword[keyword] = windows
        collected.extend(windows)
    return windows_by_keyword, " ".join(collected)

def find_window_with_category(text, category_dict, threshold=400):
    """Extract and sentiment-score keyword context windows, grouped by category.

    Every word of every category is passed to ``re.finditer`` unescaped, i.e.
    as a case-insensitive regular-expression pattern. For each match, the
    slice of ``text`` spanning ``threshold`` characters on either side of the
    match (clipped to the bounds of ``text``) is scored with
    ``finbert_predict``.

    Args:
        text: String to search.
        category_dict: Mapping of category name to an iterable of patterns.
        threshold: Number of context characters kept on each side of a match.

    Returns:
        A ``(result_by_category, joined_windows)`` tuple. The first item maps
        ``category -> word -> list of (match_start, window_text,
        sentiment_score, top_prediction)`` tuples; words without a match and
        categories in which no word matched are omitted. The second item is
        all windows, in the order they were found, joined by single spaces.
    """
    scored_by_category = {}
    collected = []
    text_length = len(text)
    for category, word_list in category_dict.items():
        hits_by_word = {}
        for word in word_list:
            # Materialise all match spans before scoring any window.
            spans = [(m.start(), m.end()) for m in re.finditer(word, text, re.IGNORECASE)]
            if not spans:
                continue
            scored = []
            for start, end in spans:
                window = text[max(0, start - threshold):min(end + threshold, text_length)]
                score, label = finbert_predict(window)
                scored.append((start, window, score, label))
                collected.append(window)
            hits_by_word[word] = scored
        if hits_by_word:
            # Only categories with at least one matching word are reported.
            scored_by_category[category] = hits_by_word
    return scored_by_category, " ".join(collected)


def finbert_predict(text, model=model, tokenizer=tokenizer):
    """Run FinBERT on ``text`` and reduce the output to one score and one label.

    ``predict`` is not defined in this file; it is presumably provided by the
    ``from finbert.finbert import *`` star import. Its result is expected to
    expose ``'sentiment_score'`` and ``'prediction'`` columns (presumably a
    pandas DataFrame with one row per scored unit of text -- TODO confirm).

    Args:
        text: Text to score.
        model: Sequence-classification model passed through to ``predict``.
            The default is bound, at definition time, to the module-level
            ``model``.
        tokenizer: Tokenizer passed through to ``predict``. The default is
            bound, at definition time, to the module-level ``tokenizer``.

    Returns:
        A ``(avg_sent_score, top_pred)`` tuple: the mean of the
        ``'sentiment_score'`` column and the first entry of the mode of the
        ``'prediction'`` column.
    """
    result = predict(text, model, tokenizer)
    avg_sent_score = result['sentiment_score'].mean()
    # mode() can return several values on a tie; the first one is kept.
    top_pred = result['prediction'].mode()[0]
    return avg_sent_score, top_pred

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Paths listed by the input documents folder; the loop further down reads one
# pickled document per path.
paths = all_documents.list_paths_in_partition()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE

# For every stored document: score the keyword windows found in its text,
# attach the results, drop the raw text and save the document to the
# 'category' folder under the same path.
for path in paths:
    document = read_pickle_from_managed_folder(project, 'all_documents', path)

    category_matches, category_text = find_window_with_category(
        document['text'],
        category_dict,
    )
    document['key_word_category_dict'] = category_matches
    document['key_word_category_text'] = category_text

    # The full text is removed from the document before it is written out.
    del document['text']
    write_pickle_in_managed_folder(project, 'category', document, path)