# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import re
import dataiku
import base64
import pandas as pd
from langchain.text_splitter import CharacterTextSplitter

# Maximum chunk length and overlap between consecutive chunks, both measured
# in characters (the splitter below is configured with length_function=len)
CHUNK_SIZE = 800
CHUNK_OVERLAP = 100

# Managed folder holding the input documents (referenced by its folder id)
folder = dataiku.Folder("vElSoRUz")
# Raw chunk texts already emitted; shared by every call to extract_content so
# that a chunk appearing in several sections or files is only kept once
already_seen = set()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
pattern_headings = re.compile(r"^(#+) (\S.*)\[¶\]\((.*)\)$")
def get_heading_depth(s):
    """
    Parse a Markdown heading line that ends with a "[¶](...)" permalink

    Returns (depth, title, url): depth is the number of leading "#", title is
    the heading text and url is the link target up to its first space.
    Any other line yields (0, "", "").
    """
    match = pattern_headings.match(s)
    if match is None:
        return 0, "", ""
    hashes, title, link = match.groups()
    # The link part may carry a quoted tooltip after the target: keep the target only
    url = link.split(" ")[0]
    return len(hashes), title, url

def split(text, separator="\n\n", chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    """
    Break a passage into small chunks that may overlap

    The text is cut on `separator`; `chunk_size` and `chunk_overlap` are
    measured with len, i.e. in characters. Returns the splitter's output.
    """
    splitter_options = {
        "separator": separator,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "length_function": len,
    }
    splitter = CharacterTextSplitter(**splitter_options)
    return splitter.split_text(text)

pattern_sequential_new_lines = re.compile(r"\n\s*\n")
def strip(s):
    """
    Trim the text, then collapse each run of blank (or whitespace-only) lines
    into a single line break
    """
    trimmed = s.strip()
    collapsed = re.sub(pattern_sequential_new_lines, "\n", trimmed)
    return collapsed

# No longer used by strip_spaces; the name is kept in case other cells reference it
pattern_spaces = re.compile(r"^\s*(\S.*\S)\s*$|^\s*(\S)\s*$|^\s*()\s*$", re.DOTALL)
def strip_spaces(text):
    """
    Remove leading and trailing white spaces

    Inner white space (including line breaks) is left untouched and a string
    is always returned. The previous regex-based version had its fallback
    `else` attached to the `for` loop instead of the `if`, so a failed match
    would have returned None; str.strip has no such path.
    """
    return text.strip()

def post_process(s):
    """
    Add code delimiters to get proper Markdown syntax

    Lines starting with "§ " are code lines: the marker is removed and every
    run of consecutive code lines is wrapped between two ``` fence lines.
    Code lines with nothing after the marker are dropped. The text is first
    normalised with strip(), and the processed text is returned.
    """
    out = []
    ongoing_code = False
    for line in strip(s).split("\n"):
        if line.startswith("§ "):
            code = line[2:]
            if not code:
                # An empty code line emits nothing, so it must not switch the
                # state either: otherwise a closing fence would later be written
                # without its opening one
                continue
            if not ongoing_code:
                out.append("```")
                ongoing_code = True
            out.append(code)
        else:
            if ongoing_code:
                out.append("```")
                ongoing_code = False
            out.append(line)
    # Close the fence when the text ends inside a code run
    if ongoing_code:
        out.append("```")
    return "\n".join(out)

def _append_chunks(chunks, text, position, url):
    """
    Split the text of one (sub)section and append its unseen chunks to `chunks`
    """
    # Escape square brackets, unescape underscores and drop double backticks.
    # "\\_" is a backslash followed by an underscore, written with a valid escape.
    text = strip_spaces(text).replace("[", "\\[").replace("]", "\\]").replace("\\_", "_").replace("``", "")
    if not text:
        return
    title = " > ".join(position)
    for chunk in split(text):
        # Deduplicate on the raw chunk text, before any post-processing
        if chunk in already_seen:
            continue
        already_seen.add(chunk)
        chunks["href"].append(url)
        chunks["title"].append(title)
        chunks["content"].append(post_process(strip_spaces(chunk)))

def extract_content(content):
    """
    Transform a text passage into a dataframe of text chunks

    The passage is read line by line. Each heading recognised by
    get_heading_depth closes the current (sub)section, whose text is cut into
    chunks. The returned dataframe has three columns:
    - "href": url of the heading the chunk belongs to ("" before the first heading)
    - "title": enclosing headings joined with " > "
    - "content": the chunk itself, with code delimiters added

    Chunks already recorded in the module-level `already_seen` set are skipped.
    Every non-empty section is emitted, and its text never carries over into
    the next section.
    """
    chunks = {"href": [], "title": [], "content": []}

    section_lines = []
    current_url = ""
    position = []
    for line in content.split("\n"):
        depth, heading, heading_url = get_heading_depth(line)
        if depth == 0:
            # Not a heading: the line belongs to the current (sub)section
            section_lines.append(line)
            continue
        # A heading closes the previous (sub)section
        _append_chunks(chunks, "\n".join(section_lines), position, current_url)
        section_lines = []
        # Whether it's a deeper level or not, update the current (sub)section
        if depth > len(position):
            position.append(heading)
        else:
            position = position[:depth]
            position[depth - 1] = heading
        current_url = heading_url
    # Cut the content of the remaining (sub)section in chunks
    _append_chunks(chunks, "\n".join(section_lines), position, current_url)

    return pd.DataFrame.from_dict(chunks)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build one dataframe of chunks per file, then assemble them in a single pass
frames = []

for file in folder.list_paths_in_partition():
    with folder.get_download_stream(file) as f:
        content = f.read().decode()
    frames.append(extract_content(content))

if frames:
    result = pd.concat(frames, axis=0, ignore_index=True)
else:
    # No input file: keep the expected columns instead of failing on None below
    result = pd.DataFrame({"href": [], "title": [], "content": []})

# Sequential identifier of each chunk across all files
result["id"] = range(len(result))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the assembled chunk table to the "chunks" dataset
output_dataset = dataiku.Dataset("chunks")
output_dataset.write_with_schema(result)