import base64
import logging
import io

import numpy as np
import pandas as pd
from PIL import Image
from scipy import ndimage
from scipy.stats import entropy

from dataiku.core.image_loader import ImageLoader
from dataiku.doctor.deep_learning.preprocessing import DummyFileReader
from dataiku.doctor.preprocessing.multimodal_preprocessings.image_embedding_extraction import \
    MAX_CONCURRENT_IMAGES_IN_MEM
from dataiku.modelevaluation.drift.drift_univariate import ks_test, comparative_numerical_histogram
from dataiku.modelevaluation.drift.embedding_drift_settings import ImageDriftSettings

logger = logging.getLogger(__name__)


class DriftImageQuality(object):
    """Calculates a suite of image quality drift metrics between two datasets.

    For every image column, each image is decoded and summarised by a fixed set
    of scalar quality metrics (see ``_calculate_metrics``). For each metric, the
    reference and current distributions are then compared with a
    Kolmogorov-Smirnov test and a comparative histogram.
    """

    def __init__(self, ref_series: pd.DataFrame, cur_series: pd.DataFrame,
                 image_drift_settings: ImageDriftSettings,
                 nb_bins: int = 100,
                 handle_drift_failure_as_error: bool = False):
        """Initializes the DriftImageQuality calculator.

        Args:
            ref_series (pd.DataFrame): The reference dataset.
            cur_series (pd.DataFrame): The current dataset.
            image_drift_settings (ImageDriftSettings): Configuration for the computation.
            nb_bins (int): The number of bins for histogram-based calculations.
            handle_drift_failure_as_error (bool): Flag to control error handling
                (forwarded to ``ks_test``).
        """
        self.nb_bins = nb_bins
        self.ref_series = ref_series
        self.cur_series = cur_series
        self.image_drift_settings = image_drift_settings
        self.handle_drift_failure_as_error = handle_drift_failure_as_error

    def compute_drift(self) -> dict:
        """Computes drift for all image quality metrics across all columns.

        Columns are taken from the current dataset; each of them must also exist
        in the reference dataset. Reference images are read from
        ``managed_folder_smart_id_ref`` and current images from
        ``managed_folder_smart_id_cur``.

        Returns:
            dict: A dictionary containing the nested drift results. Metric keys are
            the camelCase names produced by ``_calculate_metrics`` (``meanRed``,
            ``meanGreen``, ``meanBlue``, ``meanSaturation``, ``rmsContrast``,
            ``laplacianVar``, ``tenengrad``, ``entropy``, ``edgeDensity``,
            ``area``, ``aspectRatio``). The structure is::

            {
                "columns": {
                    "column_name_1": {
                        "meanRed": {
                            "name": "meanRed",
                            "type": "NUMERICAL",
                            "histogram": { ... },
                            "ksTestStatistic": 0.15,
                            "ksTestPvalue": 0.02
                        },
                        "entropy": { ... }
                    },
                    "column_name_2": { ... }
                }
            }
        """
        column_results = {}

        for column in self.cur_series:
            logger.info("Image drift: computing Image quality metrics for column %s", column)
            ref_df = self._get_metrics_dataframe(self.image_drift_settings.managed_folder_smart_id_ref,
                                                 self.ref_series[column])
            cur_df = self._get_metrics_dataframe(self.image_drift_settings.managed_folder_smart_id_cur,
                                                 self.cur_series[column])
            # NOTE(review): metric names come from the reference frame. If the current
            # column yields no images, cur_df has no columns and cur_df[metric] raises
            # KeyError; if the reference yields none, the column is silently left empty.
            column_results[column] = {
                metric: self._compute_metric(metric, ref_df[metric], cur_df[metric])
                for metric in ref_df.columns.values
            }

        return {"columns": column_results}

    def _compute_metric(self, column: str, ref_series: pd.Series, cur_series: pd.Series) -> dict:
        """Compares the reference and current distributions of a single metric.

        Args:
            column (str): Name of the metric (a column of the metrics dataframe).
            ref_series (pd.Series): Metric values of the reference images.
            cur_series (pd.Series): Metric values of the current images.

        Returns:
            dict: The metric name, its type (always ``"NUMERICAL"``), a comparative
            histogram with ``self.nb_bins`` bins, and the KS test statistic / p-value.
        """
        ks_test_statistic, ks_test_pvalue = ks_test(ref_series, cur_series, column, self.handle_drift_failure_as_error)

        histogram = comparative_numerical_histogram(ref_series, cur_series, self.nb_bins)

        return {
            "name": column,
            "type": "NUMERICAL",
            "histogram": histogram,
            "ksTestStatistic": ks_test_statistic,
            "ksTestPvalue": ks_test_pvalue,
        }

    def _get_metrics_dataframe(self, managed_folder_smart_id: str, img_paths: pd.Series) -> pd.DataFrame:
        """Loads the images of one column in batches and computes their quality metrics.

        At most ``MAX_CONCURRENT_IMAGES_IN_MEM`` images are handed to the loader at once.

        Args:
            managed_folder_smart_id (str): Managed folder the images are read from.
            img_paths (pd.Series): One image reference per row (presumably a path
                inside the managed folder — confirm against ``ImageLoader``).

        Returns:
            pd.DataFrame: One row per image returned by the loader, one column per
            metric. Has no columns at all when no image is returned.
        """
        results = []
        file_reader = DummyFileReader(managed_folder_smart_id)

        for batch_start in range(0, len(img_paths), MAX_CONCURRENT_IMAGES_IN_MEM):
            # Select the batch by *position*. ``img_paths[int_array]`` is label-based
            # on integer indexes, so it breaks (KeyError / wrong rows) as soon as the
            # index is not 0..n-1, e.g. after the dataset was sampled or filtered.
            batch = img_paths.iloc[batch_start:batch_start + MAX_CONCURRENT_IMAGES_IN_MEM]
            images = ImageLoader(True, file_reader).load_images(batch)
            results.extend(self._calculate_metrics(image) for image in images)

        return pd.DataFrame(results)

    def _calculate_metrics(self, image_b64: str) -> dict:
        """Computes scalar quality metrics for a single base64-encoded image.

        Metrics:
            meanRed / meanGreen / meanBlue: per-channel mean of the RGB image.
            meanSaturation: mean of the HSV saturation channel.
            rmsContrast: standard deviation of the grayscale intensities.
            laplacianVar: variance of the Laplacian of the grayscale image.
            tenengrad: mean squared Sobel gradient magnitude.
            entropy: Shannon entropy (base 2) of the 256-bin grayscale histogram.
            edgeDensity: fraction of pixels whose gradient magnitude exceeds
                1.5 times the mean gradient magnitude.
            area: width * height, in pixels.
            aspectRatio: width / height (0 when height is 0).

        Args:
            image_b64 (str): Base64-encoded image bytes, in any format PIL can decode.

        Returns:
            dict: Metric name -> value.
        """
        image_data = base64.b64decode(image_b64)
        with Image.open(io.BytesIO(image_data)) as pil_image:
            rgb_image = pil_image.convert('RGB')
            image_gray = np.array(pil_image.convert('L'))
        image_rgb = np.array(rgb_image)
        # Derive HSV from the RGB image so the result does not depend on PIL
        # offering a direct "<source mode> -> HSV" conversion for every input mode.
        image_hsv = np.array(rgb_image.convert('HSV'))

        height, width = image_gray.shape
        aspect_ratio = width / height if height != 0 else 0
        area = width * height

        mean_r, mean_g, mean_b = image_rgb.mean(axis=(0, 1))
        mean_saturation = image_hsv[:, :, 1].mean()
        rms_contrast = image_gray.std()

        # Filters run on float64: with uint8 input their output would keep the
        # uint8 dtype and wrap around.
        gray = image_gray.astype(np.float64)
        laplacian_var = ndimage.laplace(gray).var()

        # axis=1 differentiates along columns (x), axis=0 along rows (y).
        sobel_x = ndimage.sobel(gray, axis=1)
        sobel_y = ndimage.sobel(gray, axis=0)
        tenengrad = np.mean(sobel_x ** 2 + sobel_y ** 2)

        counts = np.histogram(image_gray, bins=256, range=(0, 256))[0]
        probabilities = counts[counts > 0] / counts.sum()
        shannon_entropy = entropy(probabilities, base=2)

        gradient_magnitude = np.hypot(sobel_x, sobel_y)
        # Adaptive threshold: a pixel counts as an edge when its gradient is
        # stronger than 1.5x the image's mean gradient.
        edges = gradient_magnitude > gradient_magnitude.mean() * 1.5
        edge_density = np.sum(edges) / area

        return {
            'meanRed': mean_r,
            'meanGreen': mean_g,
            'meanBlue': mean_b,
            'meanSaturation': mean_saturation,
            'rmsContrast': rms_contrast,
            'laplacianVar': laplacian_var,
            'tenengrad': tenengrad,
            'entropy': shannon_entropy,
            'edgeDensity': edge_density,
            'area': area,
            'aspectRatio': aspect_ratio,
        }
