# -*- coding: utf-8 -*-
import dataiku
import itertools
import os
import tempfile
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from bs4 import BeautifulSoup

# Managed folders used by this script: the FAISS index files are uploaded to
# INDEX_FOLDER and the source documents are read from DOCUMENTS_FOLDER.
INDEX_FOLDER = dataiku.Folder("DQwRivV3")
DOCUMENTS_FOLDER = dataiku.Folder("9v0Qe4fl")

# Copy the secret named "openai_key" into the OPENAI_API_KEY environment
# variable (presumably read by OpenAIEmbeddings -- confirm). If several secrets
# share that key, the last one listed wins.
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
    if secret["key"] != "openai_key":
        continue
    os.environ["OPENAI_API_KEY"] = secret["value"]

def split_long_string(text, max_length=1000):
    """
    Split a long string into chunks of at most `max_length` characters,
    preferring to split at spaces.

    Each chunk is cut at the last space found within the first `max_length`
    characters of the remaining text. When that window contains no space, the
    text is cut at exactly `max_length` characters. Leading whitespace of the
    remaining text is dropped after every cut, and chunks that would be empty
    or whitespace-only are not emitted.

    Parameters:
        text (str): The input string to be split.
        max_length (int): The maximum length for each chunk (default is 1000).
            Must be positive.

    Returns:
        list: A list containing the chunks of the original string. A text no
            longer than `max_length` is returned unchanged as a single chunk.

    Raises:
        ValueError: If `max_length` is not positive.
    """
    # A non-positive limit can never make progress on non-empty text, so
    # reject it up front instead of looping forever.
    if max_length <= 0:
        raise ValueError("max_length must be positive")

    if len(text) <= max_length:
        return [text]

    chunks = []
    while len(text) > max_length:
        # Find the last space before the max_length position
        split_index = text.rfind(" ", 0, max_length)

        # If no space is found, split at the max_length position
        if split_index == -1:
            split_index = max_length

        # Text that starts with whitespace can produce a blank chunk here
        # (the last space in the window lies inside the leading whitespace);
        # skip it rather than return an empty entry.
        chunk = text[:split_index]
        if chunk.strip():
            chunks.append(chunk)

        # Progress is guaranteed: either split_index >= 1, or the text starts
        # with a space that lstrip() removes.
        text = text[split_index:].lstrip()

    # Append the remaining part of the text as the last chunk
    if text:
        chunks.append(text)

    return chunks

def split_html(html_content):
    """
    Split an HTML document into text chunks, each prefixed with its headings.

    Every `<p>` element is split with `split_long_string`. Each resulting piece
    becomes one chunk of the form "<heading chain>\\n<piece>", where the heading
    chain joins the headings found before the paragraph with " > ", from the
    highest level down to the nearest one (for example "Guide > Setup"). The
    chain is empty when no heading precedes the paragraph.

    Parameters:
        html_content (str or bytes): The HTML markup to parse.

    Returns:
        list: The chunks as strings, in document order.
    """
    soup = BeautifulSoup(html_content, "html.parser")
    paragraphs = soup.find_all('p')

    chunks = []

    for p in paragraphs:
        # Walk backwards from the paragraph: first the nearest heading of any
        # level, then repeatedly the nearest preceding heading of a strictly
        # higher level (smaller number).
        title_list = []
        title = p.find_previous(['h1', 'h2', 'h3', 'h4', 'h5', 'h6'])
        while title is not None:
            try:
                title_tag = int(title.name[1])
            except (ValueError, IndexError):
                # Not an h<digit> tag: stop building the chain.
                break
            title_list.append(title.get_text())
            # An h1 has no higher-level heading above it. Stop here instead of
            # searching with an empty list of tag names.
            if title_tag <= 1:
                break
            title = title.find_previous([f"h{i}" for i in range(1, title_tag)])

        # Headings were collected innermost-first; present them outermost-first.
        title_list.reverse()
        title_chain = " > ".join(title_list)

        for s in split_long_string(p.get_text()):
            chunks.append("\n".join((title_chain, s)))
    return chunks
        
def extract_chunks_from_folder(folder):
    """
    Build one `Document` per text chunk found in the HTML files of `folder`.

    Only paths ending in ".html" (case-insensitive) are read. Each file is
    passed to `split_html`, and every chunk it returns is wrapped in a
    `Document` carrying the metadata {"docs": "manual"}.

    Returns:
        list: The `Document` objects, in the order the paths are listed.
    """
    documents = []
    for path in folder.list_paths_in_partition():
        # Skip anything that is not an HTML file.
        if not path.lower().endswith(".html"):
            continue
        with folder.get_download_stream(path) as stream:
            documents.extend(
                Document(page_content=chunk, metadata={"docs": "manual"})
                for chunk in split_html(stream.read())
            )
    return documents

def embed_chunks(chunks):
    """
    Build a FAISS index over `chunks` using OpenAI embeddings.

    Parameters:
        chunks (list): The `Document` objects to index.

    Returns:
        The index produced by `FAISS.from_documents`.
    """
    return FAISS.from_documents(chunks, OpenAIEmbeddings())

# Chunk every HTML document in the documents folder, then embed the chunks
# into a FAISS index.
documents = extract_chunks_from_folder(DOCUMENTS_FOLDER)
index = embed_chunks(documents)

# `save_local` writes into a local directory, so stage the index files in a
# temporary directory and upload each one to the index folder under its own
# file name.
with tempfile.TemporaryDirectory() as temp_dir:
    index.save_local(temp_dir)
    for f in os.listdir(temp_dir):
        local_path = os.path.join(temp_dir, f)
        INDEX_FOLDER.upload_file(f, local_path)