import logging
from abc import ABCMeta
from abc import abstractmethod

import numpy as np
import pandas as pd
import six

from dataiku.core import doctor_constants
from dataiku.core.dku_logging import LogLevelContext
from dataiku.doctor.prediction.explanations.score_to_explain import ScoreToExplain

logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class ExplainingEngine(object):
    """
    Interface of the objects able to compute explanations for the rows of a dataframe.
    """

    @abstractmethod
    def explain(self, df):
        """
        Compute explanations for every row of the given dataframe.

        :param df: the rows to explain
        :type df: pd.DataFrame
        :return: the explanations of the rows of `df`
        :rtype: ExplainingResult
        """

    @abstractmethod
    def get_estimated_peak_number_cells_generated_per_row_explained(self):
        """
        Estimate the worst-case number of cells generated to explain a single row.

        To explain a row, the engine generates several rows *close* to it (the technique depends on the engine), then
        concatenates all those rows and scores them in one go, which may cause large memory peaks. The value returned
        here lets other components, such as batching strategies, decide how many rows can safely be explained at
        once.

        :return: max number of cells created per explained row
        :rtype: int
        """


@six.add_metaclass(ABCMeta)
class ScoreComputer(object):
    """
    Callable interface turning input data into a ScoreToExplain.

    `__call__` is abstract: the class already uses the ABCMeta metaclass like the other interfaces of this module, but
    without the decorator a subclass forgetting to implement `__call__` would silently return None instead of failing
    at instantiation.
    """

    @abstractmethod
    def __call__(self, data, other_score_to_align_with=None, matching_indices_in_other=None):
        """
        Compute the score of `data`.

        :param data: the data to score
        :type data: pd.DataFrame or np.ndarray
        :param other_score_to_align_with: [Optional] score to align computed score on (Only relevant for multiclass)
        :type other_score_to_align_with: ScoreToExplain or None
        :param matching_indices_in_other: [Optional] corresponding indices of data in the other score, when relevant
        :type matching_indices_in_other: np.ndarray or None
        :rtype: ScoreToExplain
        """


@six.add_metaclass(ABCMeta)
class ExplainingResult(object):
    """
    Base class of the objects returned by an ExplainingEngine.

    Several results of the same concrete class can be merged into one with `concat` (e.g. the results of several
    batches).
    """

    @staticmethod
    def concat(explaining_result_list):
        """
        Merge a list of explaining results into a single one.

        Merging is delegated to the `_concat` implementation of the class of the first element; every element must be
        an instance of that class.

        :type explaining_result_list: list[ExplainingResult]
        :rtype: ExplainingResult
        :raises ValueError: if the list is empty or mixes result classes
        """
        if not explaining_result_list:
            raise ValueError("cannot concat empty list")
        reference_class = explaining_result_list[0].__class__
        if any(not isinstance(result, reference_class) for result in explaining_result_list):
            raise ValueError("All results should be of same class")
        return reference_class._concat(explaining_result_list)

    @staticmethod
    @abstractmethod
    def _concat(explaining_result_list):
        """
        Class-specific merge, called by `concat` once the list has been validated (non-empty, homogeneous).

        :type explaining_result_list: list[ExplainingResult]
        :rtype: ExplainingResult
        """


class SimpleExplanationResult(ExplainingResult):
    """
    Explaining result backed by a single dataframe of explanations.
    """

    def __init__(self, explanations_df):
        """
        :param explanations_df: the explanations; results are stacked row-wise (axis 0) when concatenated
        :type explanations_df: pd.DataFrame
        """
        self.explanations_df = explanations_df

    @staticmethod
    def _concat(explaining_result_list):
        """
        Stack the explanation dataframes of all the results along the rows.

        :type explaining_result_list: list[SimpleExplanationResult]
        :rtype: SimpleExplanationResult
        """
        frames = [result.explanations_df for result in explaining_result_list]
        stacked_df = pd.concat(frames, axis=0)
        return SimpleExplanationResult(stacked_df)


class GlobalExplanationResult(ExplainingResult):
    """
    Explaining result holding several SimpleExplanationResult, each stored under a key.

    NOTE(review): given the multiclass averaging in `to_dicts`, keys are presumably class labels — confirm with callers.
    """

    def __init__(self):
        # key -> SimpleExplanationResult
        self.explanations = {}

    def add_explanations(self, key, explanation_result):
        """
        Store (or replace) the explanations associated with `key`.

        :type key: str
        :type explanation_result: SimpleExplanationResult
        """
        self.explanations[key] = explanation_result

    @staticmethod
    def _concat(explaining_result_list):
        """
        Concatenate, key by key, the explanations of all the results.

        :type explaining_result_list: list[GlobalExplanationResult]
        :rtype: GlobalExplanationResult
        :raises ValueError: if the results do not all hold the same set of keys
        """
        concat_result = GlobalExplanationResult()
        explanations_keys = list(explaining_result_list[0].explanations.keys())
        # Compare as sets: under Python 2 `keys()` returns lists whose order is not guaranteed, so a direct
        # comparison could report a discrepancy for identical key sets
        expected_keys = set(explanations_keys)
        for r in explaining_result_list[1:]:
            if set(r.explanations.keys()) != expected_keys:
                raise ValueError("Discrepancies in the explanations dictionaries")

        for explanation_key in explanations_keys:
            concat_result_key = ExplainingResult.concat([r.explanations[explanation_key]
                                                         for r in explaining_result_list])
            assert isinstance(concat_result_key, SimpleExplanationResult)
            concat_result.add_explanations(explanation_key, concat_result_key)

        return concat_result

    def to_dicts(self):
        """
        Convert the explanations to plain dicts.

        :return: a tuple (abs_avg_explanations, explanations):
            - abs_avg_explanations: {column: mean over rows of the absolute explanation values, themselves first
              averaged across all keys}
            - explanations: {key: {column: list of explanation values}}
            Both dicts are empty when no explanations were added.
        :rtype: (dict, dict)
        """
        if not self.explanations:
            # Nothing to average: avoid the 0 / 0 division below
            return {}, {}

        # Average by classes first if multiclass, then average by rows
        # https://github.com/slundberg/shap/issues/367#issuecomment-462073420
        abs_avg_explanations = (sum(
            [v.explanations_df.abs() for v in self.explanations.values()]
        ) / len(self.explanations)).mean(axis=0).to_dict()

        explanations = {
            k: v.explanations_df.to_dict(orient="list")
            for k, v in self.explanations.items()
        }
        return abs_avg_explanations, explanations


@six.add_metaclass(ABCMeta)
class BatchingStrategy(object):
    """
    Interface deciding how many batches a set of rows should be split into.
    """

    @abstractmethod
    def get_num_batches(self, num_rows):
        """
        Compute the number of batches to use for `num_rows` rows.

        :param num_rows: total number of rows to process
        :type num_rows: int
        :return: the number of batches
        :rtype: int
        """


class FixedSizeBatchingStrategy(BatchingStrategy):
    """
    Split the rows into batches of at most `batch_size` rows (or a single batch when no size is defined).
    """

    def __init__(self, batch_size):
        """
        :param batch_size: max number of rows per batch; None means everything is processed in a single batch
        :type batch_size: int or None
        """
        self._batch_size = batch_size

    def get_num_batches(self, num_rows):
        """
        :type num_rows: int
        :return: ceil(num_rows / batch_size), and never less than 1
        :rtype: int
        """
        if self._batch_size is None:
            logger.info("No batch size defined, will run a single batch")
            return 1
        # At least one batch: 0 batches (empty input) would make np.array_split raise downstream
        return max(1, int(np.ceil(1. * num_rows / self._batch_size)))


class PeakCellCountBatchingStrategy(BatchingStrategy):
    """
    Choose the number of batches from the estimated number of cells generated while explaining the rows.

    The number of batches grows with the estimated total number of generated cells (one extra batch per
    `max_singly_processed_cells` cells). It is capped by `max_num_batches` and by the number of rows, and is never
    below 1.
    """

    def __init__(self, max_num_batches, max_singly_processed_cells, peak_num_cells_per_row_explained):
        """
        :param max_num_batches: upper bound on the number of batches
        :param max_singly_processed_cells: number of estimated cells above which an extra batch is added
        :param peak_num_cells_per_row_explained: estimated peak number of cells generated per explained row
            (presumably from ExplainingEngine.get_estimated_peak_number_cells_generated_per_row_explained — confirm)
        """
        self._max_num_batches = max_num_batches
        self._max_singly_processed_cells = max_singly_processed_cells
        self._peak_num_cells_per_row_explained = peak_num_cells_per_row_explained

    def get_num_batches(self, num_rows):
        """
        :type num_rows: int
        :return: number of batches, at least 1
        :rtype: int
        """
        estimated_processed_cells = num_rows * self._peak_num_cells_per_row_explained
        num_batches = min(self._max_num_batches,
                          estimated_processed_cells // self._max_singly_processed_cells + 1,
                          num_rows)
        # Never return 0 (e.g. empty input): np.array_split requires at least one section
        num_batches = max(1, int(num_batches))
        logger.info("Selecting number of batches: %s (num_rows=%s, max_num_batches=%s, "
                    "max_singly_processed_cells=%s, estimated_processed_cells=%s)",
                    num_batches, num_rows, self._max_num_batches,
                    self._max_singly_processed_cells, estimated_processed_cells)
        return num_batches


class BatchingExplainingEngine(ExplainingEngine):
    """
    Explaining engine that splits the rows to explain into batches (as decided by a BatchingStrategy), explains each
    batch with an underlying engine and concatenates the per-batch results.
    """

    def __init__(self, underlying_explaining_engine, batching_strategy, progress, debug_mode):
        """
        :type underlying_explaining_engine: ExplainingEngine
        :type batching_strategy: BatchingStrategy
        :type progress: dataiku.core.percentage_progress.PercentageProgress or None
        :param debug_mode: when False, the preprocessing logger is restricted to CRITICAL while explaining
        :type debug_mode: bool
        """
        self._underlying_explaining_engine = underlying_explaining_engine
        self._batching_strategy = batching_strategy
        self._progress = progress
        self._debug_mode = debug_mode

    def get_estimated_peak_number_cells_generated_per_row_explained(self):
        # Batching does not change the per-row cost: delegate to the underlying engine
        return self._underlying_explaining_engine.get_estimated_peak_number_cells_generated_per_row_explained()

    def explain(self, df):
        """
        Explain all the rows of `df`, batch by batch.

        :type df: pd.DataFrame
        :rtype: ExplainingResult
        """
        num_total_rows = df.shape[0]
        # Defensive lower bound: np.array_split raises on 0 sections, and ExplainingResult.concat raises on an
        # empty list of results
        n_batches = max(1, self._batching_strategy.get_num_batches(num_total_rows))
        explaining_result_list = []
        num_computed_rows = 0
        logger.info("Explaining %s rows in %s batches", num_total_rows, n_batches)
        with LogLevelContext(logging.CRITICAL, [doctor_constants.PREPROCESSING_LOGGER_NAME], disable=self._debug_mode):
            for i, batch_indices in enumerate(np.array_split(np.arange(num_total_rows), n_batches)):
                df_batch = df.iloc[batch_indices]
                logger.info("Explaining %s rows (batch %s/%s)", df_batch.shape[0], i + 1, n_batches)
                explaining_result_list.append(self._underlying_explaining_engine.explain(df_batch))
                num_computed_rows += df_batch.shape[0]
                if self._progress is not None:
                    # Guard against division by zero when df is empty
                    if num_total_rows > 0:
                        percentage = int(100 * num_computed_rows / num_total_rows)
                    else:
                        percentage = 100
                    self._progress.set_percentage(percentage)

        result = ExplainingResult.concat(explaining_result_list)
        logger.info("Done computing explanations")
        return result
