import logging

import enum
import pandas as pd

from dataiku.core.doctor_constants import NUMERIC, TEXT, IMAGE, VECTOR
from dataiku.doctor.deep_learning.preprocessing import DummyFileReader
from dataiku.modelevaluation.drift.embedding_drift_settings import TextDriftSettings, ImageDriftSettings
from dataiku.doctor.exception import DriftException
from dataiku.doctor.preprocessing.multimodal_preprocessings.image_embedding_extraction import ImageEmbeddingExtractor
from dataiku.doctor.preprocessing.multimodal_preprocessings.sentence_embedding_extraction import LLMApiSentenceEmbeddingExtractor
from dataiku.modelevaluation.data_types import cast_as_numeric, cast_as_string
from dataiku.modelevaluation.drift.prepared_features import PreparedFeatures

logger = logging.getLogger(__name__)


class ResolvedColumnHandling(enum.Enum):
    """Final handling applied to a column when preparing data for drift analysis."""

    # Handlings for which the column is actually prepared
    NUMERICAL = 1
    CATEGORICAL = 2
    TEXT = 3
    IMAGE = 4

    # Handlings for which the column is left out of the prepared features
    IGNORED = 5       # column disabled in the drift column params
    UNSUPPORTED = 6   # e.g. vector columns, or text/image columns whose drift is disabled


class DriftPreparator(object):
    """
    Prepare reference & current dataframes by applying (the same) drift column handling parameters
    => Ensure the two dataframes have *exactly* the same schema after preparation
    """

    def __init__(self, original_ref_me, original_cur_me, data_drift_params, can_compute_embedding_drift=False,
                 text_drift_settings=None,
                 image_drift_settings=None):
        """
        :param original_ref_me: reference ME (or ME-like); its `sample_df` and `preprocessing_params` are used
        :param original_cur_me: current ME (or ME-like); its `sample_df` and `preprocessing_params` are used
        :param data_drift_params: drift parameters; `columns` holds the per-column drift settings
        :param can_compute_embedding_drift: whether embedding drift can be computed
                                            (ie, not possible in interactive mode)
        :param text_drift_settings: text drift settings; when None, settings with
                                    `should_be_computed=False` are created for this instance
        :param image_drift_settings: image drift settings; when None, settings with
                                     `should_be_computed=False` are created for this instance
        """
        # Build the "disabled" defaults per instance: objects created in the signature are evaluated once
        # at definition time and would be shared (and possibly mutated) across all preparators.
        if text_drift_settings is None:
            text_drift_settings = TextDriftSettings(should_be_computed=False)
        if image_drift_settings is None:
            image_drift_settings = ImageDriftSettings(should_be_computed=False)

        self.original_ref_df = original_ref_me.sample_df
        self.original_cur_df = original_cur_me.sample_df
        self.ref_preprocessing = original_ref_me.preprocessing_params
        self.cur_preprocessing = original_cur_me.preprocessing_params
        self.data_drift_params = data_drift_params
        self.can_compute_embedding_drift = can_compute_embedding_drift    # Are we able to do embedding drift (ie, not possible in interactive mode)
        self.text_drift_settings = text_drift_settings
        self.image_drift_settings = image_drift_settings

    def _infer_column_handling(self, column):
        """
        Resolve how a column must be handled for drift analysis.

        The default handling is derived from the feature types declared in the reference and current
        preprocessings (and from the text/image drift settings); the drift column params, when defined
        for this column, may then disable the column or force another handling.

        :param column: name of the column to resolve
        :return: tuple (actual_handling, default_handling) of ResolvedColumnHandling
        """
        ref_handling = self.ref_preprocessing["per_feature"].get(column)
        cur_handling = self.cur_preprocessing["per_feature"].get(column)

        def _both_are(feature_type):
            # The current side is only looked at when the reference side already matches
            return ref_handling["type"] == feature_type and cur_handling["type"] == feature_type

        # 1. Default handling, from the preprocessings of both sides
        if _both_are(NUMERIC):
            default_handling = ResolvedColumnHandling.NUMERICAL
        elif _both_are(IMAGE):
            if self.image_drift_settings.should_be_computed:
                default_handling = ResolvedColumnHandling.IMAGE
            else:
                logger.info("Column %s is detected as Image but the Image Drift is disabled : ignored" % column)
                default_handling = ResolvedColumnHandling.UNSUPPORTED
        elif _both_are(VECTOR):
            default_handling = ResolvedColumnHandling.UNSUPPORTED
        elif _both_are(TEXT):
            if self.text_drift_settings.should_be_computed:
                default_handling = ResolvedColumnHandling.TEXT
            else:
                logger.info("Column %s is detected as Text but the Text Drift is disabled : ignored" % column)
                default_handling = ResolvedColumnHandling.UNSUPPORTED
        else:
            default_handling = ResolvedColumnHandling.CATEGORICAL

        # 2. Actual handling, possibly overridden by the drift column params
        column_params = self.data_drift_params.columns.get(column)
        if not column_params:
            return default_handling, default_handling

        if not column_params.get("enabled", False):
            # A disabled column is reported as ignored, unless it was unsupported in the first place
            if default_handling == ResolvedColumnHandling.UNSUPPORTED:
                return ResolvedColumnHandling.UNSUPPORTED, default_handling
            return ResolvedColumnHandling.IGNORED, default_handling

        if "handling" not in column_params or column_params["handling"] == "AUTO":
            return default_handling, default_handling

        forced_handlings = (
            ("NUMERICAL", ResolvedColumnHandling.NUMERICAL),
            ("TEXT", ResolvedColumnHandling.TEXT),
            ("IMAGE", ResolvedColumnHandling.IMAGE),
            ("VECTOR", ResolvedColumnHandling.UNSUPPORTED),
        )
        for keyword, forced_handling in forced_handlings:
            if column_params["handling"] == keyword:
                return forced_handling, default_handling

        # Any other requested handling falls back to categorical
        return ResolvedColumnHandling.CATEGORICAL, default_handling

    def prepare(self):
        """
        Prepare the reference and current features of every available column, according to its
        resolved handling.

        :return: tuple (prepared_features_ref, prepared_features_cur, per_column_settings) where
                 `tabular_series` of both prepared features has been turned into a DataFrame and
                 per_column_settings is a list of dicts (name, actualHandling, defaultHandling and,
                 when a numerical cast failed, errorMessage)
        """
        column_reports = []
        cur_features = PreparedFeatures()
        ref_features = PreparedFeatures()

        for column in self.list_available_columns():
            actual_handling, default_handling = self._infer_column_handling(column)
            logger.info(u"Treating {} as {} for drift analysis".format(column, actual_handling))

            report = dict(name=column,
                          actualHandling=actual_handling.name,
                          defaultHandling=default_handling.name)

            if actual_handling == ResolvedColumnHandling.NUMERICAL:
                try:
                    ref_features.tabular_series[column] = cast_as_numeric(self.original_ref_df[column])
                    cur_features.tabular_series[column] = cast_as_numeric(self.original_cur_df[column])
                except ValueError:
                    failure = u"Failed to cast {} as {} for drift analysis".format(column, actual_handling.name)
                    logger.info(failure)
                    report["errorMessage"] = failure
                    # Drop the column on both sides so that the two schemas stay identical
                    for features in (ref_features, cur_features):
                        features.tabular_series.pop(column, None)

            elif actual_handling == ResolvedColumnHandling.CATEGORICAL:
                # TODO: py2/p3 ok?
                ref_features.tabular_series[column] = cast_as_string(self.original_ref_df[column])
                cur_features.tabular_series[column] = cast_as_string(self.original_cur_df[column])

            elif actual_handling == ResolvedColumnHandling.TEXT:
                # Embeddings are only extracted when the text drift is enabled
                if self.text_drift_settings.should_be_computed:
                    self._extract_text_embeddings(ref_features.text_embeddings_by_column,
                                                  cur_features.text_embeddings_by_column,
                                                  column)

            elif actual_handling == ResolvedColumnHandling.IMAGE:
                # Embeddings are only extracted when the image drift is enabled
                if self.image_drift_settings.should_be_computed:
                    self._extract_image_embeddings(ref_features.image_embeddings_by_column,
                                                   cur_features.image_embeddings_by_column,
                                                   column)

            column_reports.append(report)

        self._validate_inputs_for_drift(column_reports, ref_features)

        cur_features.tabular_series = pd.DataFrame(cur_features.tabular_series)
        ref_features.tabular_series = pd.DataFrame(ref_features.tabular_series)
        return ref_features, cur_features, column_reports

    def _validate_inputs_for_drift(self, per_column_settings, prepared_features_ref):
        """
        Sanity-check what has been prepared for the drift computation.

        Warns when image (resp. text) drift is enabled but no image (resp. text) embeddings were
        computed on the reference side, and checks that at least one column is actually handled.

        :param per_column_settings: list of per-column settings dicts (reads "actualHandling")
        :param prepared_features_ref: prepared features of the reference data
        :raises DriftException: if every column is ignored or unsupported while drift column params are defined
        """
        if self.image_drift_settings.should_be_computed and len(prepared_features_ref.image_embeddings_by_column) == 0:
            logger.warning("Image drift is selected but no image column computed.")

        if self.text_drift_settings.should_be_computed and len(prepared_features_ref.text_embeddings_by_column) == 0:
            logger.warning("Text drift is selected but no text column computed.")

        skipped_handlings = (ResolvedColumnHandling.IGNORED.name, ResolvedColumnHandling.UNSUPPORTED.name)
        is_skipped = [settings["actualHandling"] in skipped_handlings for settings in per_column_settings]
        if not all(is_skipped):
            return

        nothing_to_compute = ("All the input features of the model are either ignored "
                              "or unsupported for this input data drift computation. Skipping input data drift.")
        if len(self.data_drift_params.columns) != 0:
            raise DriftException(nothing_to_compute)

        # The default case, without manual selection by the user : we just don't compute
        logger.warning(nothing_to_compute)

    def _extract_text_embeddings(self, embedding_ref, embedding_cur, column):
        """
        Compute the text embeddings of a column for both the current and the reference data.

        :param embedding_ref: mapping filled in place with the reference embeddings, keyed by column
        :param embedding_cur: mapping filled in place with the current embeddings, keyed by column
        :param column: name of the text column
        :raises DriftException: if embedding drift cannot be computed, if the text drift parameters or
                                the embedding model are missing, or if the extraction itself fails
        """
        if not self.can_compute_embedding_drift:
            raise DriftException(("The column %s is to be handled as Text for drift computation but "
                                  "the Text Drift section is disabled. Please review your recipe configuration.") % column)

        text_drift_params = self.text_drift_settings.params
        if not text_drift_params:
            raise DriftException("The text drift parameters cannot be null with the text drift enabled. "
                                 "Please review your recipe configuration.")

        embedding_model_id = text_drift_params.get('embeddingModelId', None)
        if not embedding_model_id:
            raise DriftException("No text embedding model is selected, whereas the text drift is enabled. "
                                 "Please review your recipe configuration.")

        try:
            sentence_embedding_extractor = LLMApiSentenceEmbeddingExtractor(column, embedding_model_id)
            embedding_cur[column] = sentence_embedding_extractor.extract_embeddings(self.original_cur_df[column].values)
            embedding_ref[column] = sentence_embedding_extractor.extract_embeddings(self.original_ref_df[column].values)
        except Exception:
            # Only catch regular errors (not KeyboardInterrupt / SystemExit) and keep the root cause in the
            # logs, since the exception raised to the caller only carries a generic message.
            logger.exception("Failed to compute text embeddings for column %s", column)
            raise DriftException(("Failed to compute text embeddings for column %s. "
                                  "Please review your recipe configuration.") % column)

    def _extract_image_embeddings(self, embedding_ref, embedding_cur, column):
        """
        Compute the image embeddings of a column for both the current and the reference data.

        Images are read through the managed folders referenced by the image drift settings
        (`managed_folder_smart_id_cur` for the current data, `managed_folder_smart_id_ref` for the reference data).

        :param embedding_ref: mapping filled in place with the reference embeddings, keyed by column
        :param embedding_cur: mapping filled in place with the current embeddings, keyed by column
        :param column: name of the image column
        :raises DriftException: if embedding drift cannot be computed, if the image drift parameters, the
                                embedding model or one of the managed folders are missing, or if the
                                extraction itself fails
        """
        if not self.can_compute_embedding_drift:
            raise DriftException(("The column %s is to be handled as Image for drift computation but the Image Drift "
                                  "section is disabled. Please review your recipe configuration.") % column)

        image_drift_params = self.image_drift_settings.params
        if not image_drift_params:
            raise DriftException("The image drift parameters cannot be null with the image drift enabled. "
                                 "Please review your recipe configuration.")

        embedding_model_id = image_drift_params.get('embeddingModelId', None)
        if not embedding_model_id:
            raise DriftException("No image embedding model is selected, whereas the image drift is enabled. "
                                 "Please review your recipe configuration.")

        if not self.image_drift_settings.managed_folder_smart_id_cur:
            raise DriftException(("The column %s is to be handled as Image for drift computation but the Evaluation "
                                  "Data folder could not be found. Please review your recipe configuration.") % column)

        if not self.image_drift_settings.managed_folder_smart_id_ref:
            raise DriftException(("The column %s is to be handled as Image for drift computation but the "
                                  "Reference Data folder could not be found. "
                                  "Please review the Image location in your train model "
                                  "or recipe configuration.") % column)

        try:
            file_reader_cur = DummyFileReader(self.image_drift_settings.managed_folder_smart_id_cur)
            image_embedding_extractor_cur = ImageEmbeddingExtractor(column, file_reader_cur, embedding_model_id)
            embedding_cur[column] = image_embedding_extractor_cur.extract_embeddings(self.original_cur_df)

            file_reader_ref = DummyFileReader(self.image_drift_settings.managed_folder_smart_id_ref)
            image_embedding_extractor_ref = ImageEmbeddingExtractor(column, file_reader_ref, embedding_model_id)
            embedding_ref[column] = image_embedding_extractor_ref.extract_embeddings(self.original_ref_df)
        except Exception:
            # Only catch regular errors (not KeyboardInterrupt / SystemExit) and keep the root cause in the
            # logs, since the exception raised to the caller only carries a generic message.
            logger.exception("Failed to compute image embeddings for column %s", column)
            raise DriftException(("Failed to compute image embeddings for column %s. "
                                  "Please review your recipe configuration.") % column)

    def list_available_columns(self):
        """
        List the columns usable for drift analysis: columns present in both the reference and the
        current sample dataframes, and declared with the "INPUT" role in both preprocessings.
        """
        shared_columns = self.original_ref_df.columns.intersection(self.original_cur_df.columns)
        per_feature_maps = (self.ref_preprocessing["per_feature"], self.cur_preprocessing["per_feature"])
        for per_feature in per_feature_maps:
            input_features = {name for name, params in per_feature.items() if params.get("role") == "INPUT"}
            shared_columns = shared_columns.intersection(input_features)
        return shared_columns
