import logging
from typing import Iterator, Generator, List, Optional, Union, Sequence

import pandas as pd

try:
    from langchain_core.documents import Document
    from langchain_core.document_loaders.base import BaseLoader
    from langchain_text_splitters.character import RecursiveCharacterTextSplitter
except ModuleNotFoundError:
    from langchain.docstore.document import Document
    from langchain.document_loaders.base import BaseLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

from dataiku import Dataset
from dataiku.langchain.metadata_generator import MetadataGenerator

# Module-level logger shared by the loaders defined in this file.
logger = logging.getLogger(__name__)

# Default ``chunk_size`` of DatasetLoader: the number of rows requested per dataframe
# from ``iter_dataframes(chunksize=...)``; each dataframe becomes one yielded batch.
DEFAULT_DATASET_BATCH_SIZE = 10000

class VectorStoreLoader(BaseLoader):
    """Base class for loaders that produce documents in batches.

    Unlike a plain document iterator, :meth:`lazy_load` here yields *lists* of
    documents (one list per batch). :meth:`iter_documents` and :meth:`load`
    flatten those batches back into individual documents.
    """

    def lazy_load(self) -> Iterator[List[Document]]:
        """Yield batches (lists) of documents; must be implemented by subclasses."""
        raise NotImplementedError()

    def iter_documents(self) -> Generator[Document, None, None]:
        """Yield the documents of every batch produced by :meth:`lazy_load`, one at a time."""
        for documents in self.lazy_load():
            yield from documents

    def load(self) -> List[Document]:
        """Return the documents of all batches as a single flat list."""
        return list(self.iter_documents())


class DummyLoader(VectorStoreLoader):
    """Simple dummy loader that does no chunking.

    This is useful for small tests, e.g. checking that the embedding
    size is compatible with an existing vector store.

    Every input string becomes one document (no metadata passed), and all
    documents are yielded together as a single batch.
    """
    def __init__(self, input_data: Sequence[str]):
        # page_content is passed by keyword, as elsewhere in this module: the keyword
        # form is accepted by both the langchain_core Document and the legacy
        # langchain.docstore Document fallback (which presumably rejects positional args).
        self.documents = [Document(page_content=text) for text in input_data]

    def lazy_load(self) -> Iterator[List[Document]]:
        """Yield all documents as one single batch."""
        yield self.documents


class DatasetLoader(VectorStoreLoader):
    """Load the rows of a dataset as documents, one batch per dataframe read.

    The value of ``content_column`` (cast to ``str``) becomes the document text;
    the metadata is built by a ``MetadataGenerator`` configured with the source id,
    metadata and security tokens columns.
    """

    def __init__(self,
                 input_data: Union[Dataset, "InlineDataset"],  # not imported to avoid circular import error
                 content_column: str,
                 source_id_column: Optional[str] = None,
                 metadata_columns: Optional[List[str]] = None,
                 security_tokens_column: Optional[str] = None,
                 limit: Optional[int] = None,
                 chunk_size: Optional[int] = DEFAULT_DATASET_BATCH_SIZE):
        """
        :param input_data: dataset to read; must provide ``iter_dataframes(chunksize=...)``.
        :param content_column: column whose value becomes the document text.
        :param source_id_column: optional column forwarded to the metadata generator.
        :param metadata_columns: optional columns forwarded to the metadata generator.
        :param security_tokens_column: optional column forwarded to the metadata generator.
        :param limit: maximum number of rows to read; ``None`` (or the legacy ``-1``) means unlimited.
        :param chunk_size: rows requested per dataframe; each dataframe is yielded as one batch.
        """
        self.input_data = input_data
        self.content_column = content_column
        self.metadata_generator = MetadataGenerator(
            metadata_columns or [],
            source_id_column,
            security_tokens_column
        )
        # For backwards compatibility where unlimited was defined as limit = -1
        self.limit = None if limit == -1 else limit
        self.chunk_size = chunk_size

    def lazy_load(self) -> Iterator[List[Document]]:
        """Yield one list of documents per dataframe read from the dataset.

        Rows whose content is missing (``pd.isna``) count towards the total and the
        limit but produce no document. Once ``limit`` rows have been consumed, the
        current batch is yielded and no further dataframe is read.
        """
        skipped = 0
        total = 0
        for df in self.input_data.iter_dataframes(chunksize=self.chunk_size):
            batch_length = len(df.index)
            if self.limit is None:
                remaining = batch_length
            else:
                # Clamp at 0: a negative limit (other than the legacy -1) loads nothing
                # instead of letting head() drop rows from the end.
                remaining = max(0, min(batch_length, self.limit - total))
            # we intentionally log the remaining count to avoid discrepancy with the total documents count
            logger.info(f"Loading a batch of {remaining} documents from the dataset ({total} documents loaded so far)")
            documents = []
            for _, row in df.head(remaining).iterrows():
                total += 1
                raw_content = row[self.content_column]
                # Skip for empty content, since it cannot be embedded anyway
                # won't be a problem with upserts since if the doc is not passed, the record manager will delete it if it
                # was previously indexed
                if pd.isna(raw_content):
                    skipped += 1
                    continue

                # Casting to string to avoid issues with non-string content
                document = Document(page_content=str(raw_content), metadata=self.metadata_generator.to_metadata(row))
                documents.append(document)
            yield documents
            if self.limit is not None and total >= self.limit:
                # Count the rows actually left unread in *this* dataframe (it may be shorter than
                # chunk_size, and chunk_size may be None).
                if batch_length > remaining:
                    logger.info(f"Maximum number of records reached, skipping the {batch_length - remaining} remaining documents from the dataset")
                # skip the iter_dataframes() loop since we do not need to read more from the dataset
                break
        if skipped > 0:
            logger.warning(f"Skipped {skipped} documents with empty content")
        if total > 0:
            logger.info(f"Loaded {total} documents from the dataset")
        else:
            logger.info("Empty dataset, nothing to embed")


class TextSplitterLoader(VectorStoreLoader):
    """Wrap another loader and split each of its batches with a RecursiveCharacterTextSplitter."""

    def __init__(self, documents_loader: VectorStoreLoader, chunk_size: int, chunk_overlap: int):
        """
        :param documents_loader: loader whose batches are split.
        :param chunk_size: ``chunk_size`` passed to the text splitter.
        :param chunk_overlap: ``chunk_overlap`` passed to the text splitter.
        """
        self.documents_loader = documents_loader
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

    def lazy_load(self) -> Iterator[List[Document]]:
        """Yield one split batch per batch of the wrapped loader, then log the overall expansion."""
        logger.info("Performing splitting")
        total_docs = 0
        total_split = 0
        for documents in self.documents_loader.lazy_load():
            split_documents = self.text_splitter.split_documents(documents)
            total_docs += len(documents)
            total_split += len(split_documents)
            yield split_documents
        if total_split > total_docs:
            logger.info(f"After splitting, expanded {total_docs} documents to {total_split} records to embed")
        elif total_split < total_docs:
            # The splitter can return no chunk for some documents (presumably e.g. whitespace-only
            # content), so fewer records than documents must not be reported as "no splitting".
            logger.info(f"After splitting, {total_docs} documents resulted in {total_split} records to embed")
        else:
            # NOTE(review): equal counts could still hide splits offset by dropped documents.
            logger.info("No splitting was performed")
