import logging
import os
from abc import ABCMeta, abstractmethod

import six

import dataiku
from dataiku.base.utils import safe_unicode_str
import numpy as np
import pandas as pd
from six.moves import xrange


from dataiku.doctor.preprocessing.dataframe_preprocessing import Step
from dataiku.doctor.utils.gpu_execution import SentenceEmbeddingGpuCapability

logger = logging.getLogger(__name__)

# Maximum number of records sent to the LLM embedding API in a single query: longer inputs are split into
# consecutive batches of at most this many records (missing/blank texts in a batch are not sent).
QUERY_MAX_LEN = 1000


@six.add_metaclass(ABCMeta)
class AbstractSentenceEmbeddingExtractor(Step):
    """
    Base preprocessing step turning one text column into a block of dense embedding features.

    Concrete subclasses supply the embedding backend by implementing `extract_embeddings`
    and `get_output_dimension`.
    """

    def __init__(self, column_name, model_name):
        self.model_name = model_name
        self.column_name = column_name

    def __str__(self):
        return "Step:%s (%s)" % (type(self).__name__, self.column_name)

    def process(self, input_df, current_mf, output_ppr, generated_features_mapping):
        """
        Embed the text column and append the resulting vectors to the multiframe as one numeric block.

        Once the text column has been turned into vectors, they are handled like the output of
        UnfoldVectorProcessor: every embedding dimension becomes its own feature.

        :param input_df: Contains text data to be processed
        :param current_mf: Multiframe to which to append processed data
        :param output_ppr: Current output preprocessing result (not used by this step)
        :param generated_features_mapping: Holds mappings of block<->column
        """
        column = self.column_name
        texts_series = input_df[column]
        if not texts_series.empty:
            embeddings = self.extract_embeddings(texts_series)
        else:
            # Nothing to embed, but the block must keep the width the model produces for non-empty input
            # (its last hidden dimension) so the pipeline output shape is stable. float64 is the DSS default dtype.
            embeddings = np.empty((0, self.get_output_dimension()), dtype=np.float64)

        label = safe_unicode_str(column)
        # The block was named `sentence_embedding:{feature}` before 12.4
        block_name = u"sentence_vec:{}".format(label)
        # One feature per dimension, e.g. sentence_vec:text_summary:206 for a column named text_summary
        feature_names = [u"sentence_vec:{}:{}".format(label, dim) for dim in xrange(embeddings.shape[1])]

        generated_features_mapping.add_features_to_block(feature_names, block_name)
        generated_features_mapping.add_whole_block_mapping(block_name, [column])
        current_mf.append_np_block(block_name, embeddings, feature_names)

    @abstractmethod
    def get_output_dimension(self):
        """Return the number of values in each embedding vector produced by this extractor."""
        raise NotImplementedError

    @abstractmethod
    def extract_embeddings(self, texts):
        """Return a 2D array holding one embedding row per entry of `texts`."""
        raise NotImplementedError


class LLMApiSentenceEmbeddingExtractor(AbstractSentenceEmbeddingExtractor):
    """
    Vectorize a text column calling the embedding API using LLM connections to compute the text embeddings.
    When selecting a local Huggingface model, the preprocessing is made in a different kernel than the current process.
    """
    def __init__(self, column_name, model_name, model_embedding_size=None):
        """
        :param column_name: Name of a text column to be processed
        :param model_name: Name of the Sentence Transformer Model (legacy) or structure ref id of the model to use.
        :param model_embedding_size: Length of the vectors returned by the model. When None, it is discovered
            by embedding a dummy text once (one extra API call at construction time).
        """
        super(LLMApiSentenceEmbeddingExtractor, self).__init__(column_name, model_name)

        # TODO : This will not work on apinode which doesn't allow public api calls yet, see:
        #  https://app.shortcut.com/dataiku/story/162601
        self._project_handle = dataiku.api_client().get_default_project()
        self.llm_model = self._project_handle.get_llm(model_name)
        if model_embedding_size is None:
            query = self.llm_model.new_embeddings(text_overflow_mode="TRUNCATE")
            query.add_text("This is just a dummy query to get embeddingSize")
            model_embedding_size = len(query.execute().get_embeddings()[0])
        self.model_embedding_size = model_embedding_size

    def get_output_dimension(self):
        """Return the length of the embedding vectors produced by the LLM."""
        return self.model_embedding_size

    @staticmethod
    def _is_missing_text(text):
        """Return True for values that are not sent to the API: None/NaN, empty or whitespace-only strings."""
        # pd.isna must be evaluated first: `not pd.NA` raises and NaN floats have no `.isspace()`
        return pd.isna(text) or not text or text.isspace()

    def extract_embeddings(self, texts):
        """
        Compute one embedding per input text by calling the LLM embedding API in batches of QUERY_MAX_LEN.

        Missing, empty or whitespace-only values are not sent to the API; their output row stays all zeros.

        :param texts: Sequence of texts (typically a pandas Series); values are used positionally, the index is ignored
        :return: float64 array of shape (len(texts), self.get_output_dimension())
        :raises ValueError: if the API does not return exactly one embedding per text sent
        """
        # todo this introduces duplicates of DKUEmbeddings.embed_documents(). +could be a common class with image embedding
        # Materialize the values so batching is purely positional, whatever the Series index looks like
        values = list(texts)
        num_texts = len(values)
        output = np.zeros((num_texts, self.get_output_dimension()), dtype=np.float64)

        for batch_start in range(0, num_texts, QUERY_MAX_LEN):
            batch_end = min(batch_start + QUERY_MAX_LEN, num_texts)
            query = self.llm_model.new_embeddings(text_overflow_mode="TRUNCATE")

            # Positions (in `values` / `output`) of the texts actually added to this query, in order
            batch_positions = []
            for position in range(batch_start, batch_end):
                text = values[position]
                if self._is_missing_text(text):
                    continue
                query.add_text(text)
                batch_positions.append(position)

            if not batch_positions:
                # Only missing values in this batch: nothing to send, rows stay zero
                continue

            batch_embeddings = query.execute().get_embeddings()
            # Guard explicitly: numpy would otherwise silently broadcast a single returned vector to every row
            if len(batch_embeddings) != len(batch_positions):
                raise ValueError("The LLM embedding API returned {} embeddings for {} texts".format(
                    len(batch_embeddings), len(batch_positions)))
            # Write each batch straight into the output instead of keeping every vector as Python lists
            output[batch_positions, :] = batch_embeddings
            logger.info("Retrieved a response from LLM api. Embedded {num_embedded} of {num_texts} texts".format(
                num_embedded=batch_end, num_texts=num_texts))

        logger.info("Retrieved all the responses from LLM api. Embedded {num_texts} texts".format(num_texts=num_texts))
        return output

class CodeEnvResourceSentenceEmbeddingExtractor(AbstractSentenceEmbeddingExtractor):
    """
    Legacy + custom models handling mode:
    Vectorize a text column using text embedding models from SentenceTransformers
    https://www.sbert.net/#usage.
    Preprocessing is made in-process using a model downloaded into code env resources
    """
    def __init__(self, column_name, model_name, max_sequence_length, batch_size=32, gpu_config=None):
        """
        :param column_name: Name of a text column to be processed
        :param model_name: Name of the Sentence Transformer Model (legacy) or structure ref id of the model to use.
        :param max_sequence_length: Maximum input length for each row (the rest will be truncated & ignored by the model)
        https://www.sbert.net/examples/applications/computing-embeddings/README.html#input-sequence-length
        :param batch_size: How many rows are encoded together in one call batch by the model
        :param gpu_config: GPU configuration handed to SentenceEmbeddingGpuCapability.get_device to pick the
            device the model is loaded on (NOTE(review): behavior for None is defined there — confirm)
        """
        super(CodeEnvResourceSentenceEmbeddingExtractor, self).__init__(column_name, model_name)

        # Legacy models from code env resources:
        self.batch_size = batch_size
        device = SentenceEmbeddingGpuCapability.get_device(gpu_config)
        self.model = self.load_sentence_embedding_model(max_sequence_length, device)

    def load_sentence_embedding_model(self, max_seq_length, device):
        """
        Load text embedding models from SENTENCE_TRANSFORMERS_HOME folder (code env resources).

        :param max_seq_length: Value assigned to the loaded model's `max_seq_length`
        :param device: Device passed to the SentenceTransformer constructor
        :return: the loaded SentenceTransformer model
        :raises Exception: if SENTENCE_TRANSFORMERS_HOME is not set or the model folder does not exist
        """
        from sentence_transformers import SentenceTransformer

        # Unlike image embedding extraction we have to support the (legacy) models from the code env resources
        # if the model id is not a structured ref that means it's a model from the code env resources
        logger.info("Using model from code env resources")
        model_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
        if not model_folder:
            raise Exception("The environment variable SENTENCE_TRANSFORMERS_HOME is not defined.")
        # Models are stored with "/" in their name replaced by "_"
        model_path = os.path.join(model_folder, self.model_name.replace("/", "_"))
        if not os.path.exists(model_path):
            raise Exception("The text embedding model '{}' has not been downloaded to the code env resources.".format(self.model_name))

        # The second positional parameter of SentenceTransformer is `modules`, not the cache folder:
        # pass the resources folder by keyword so it lands in `cache_folder` as intended.
        model = SentenceTransformer(model_path, cache_folder=model_folder, device=device)
        model.max_seq_length = max_seq_length
        return model

    def get_output_dimension(self):
        """Return the embedding size reported by the loaded SentenceTransformer model."""
        return int(self.model.get_sentence_embedding_dimension())

    def extract_embeddings(self, texts):
        """
        Encode every text with the SentenceTransformer model.

        :param texts: pandas Series (or array) of texts; every value is converted with astype(str)
        :return: float64 array with one embedding row per text
        """
        texts = texts.astype(str).tolist() # TODO sc 186251: handle empty values, they are currently treated as the "nan" string
        return self.model.encode(texts, batch_size=self.batch_size).astype(np.float64)
