import logging
from typing import List
from dataiku.eda.types import Literal

import numpy as np
from sklearn import decomposition, preprocessing

from dataiku.eda.computations.computation import Computation, MultivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NotEnoughDataError, InvalidParams
from dataiku.eda.filtering.and_filter import AndFilter
from dataiku.eda.filtering.missing_filter import MissingFilter
from dataiku.eda.filtering.not_filter import NotFilter
from dataiku.eda.types import PCAModel, PCAResultModel

# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)


class PCA(MultivariateComputation):
    """Principal component analysis over a set of columns.

    The computation standardizes the requested columns, fits a PCA on the rows
    that have no missing value in any of those columns, appends the first
    ``projection_dimension`` projections (named ``pc1``, ``pc2``, ...) to the
    dataframe, and then runs a nested computation on the extended dataframe.
    """

    def __init__(self, columns: List[str], projection_computation: Computation, projection_dimension: int, input_data_prefix: str = "input_"):
        """Store the PCA configuration.

        :param columns: names of the columns the PCA is computed on
        :param projection_computation: computation applied to the dataframe
            once the projections have been added to it
        :param projection_dimension: number of ``pcN`` projection columns to
            produce
        :param input_data_prefix: forwarded as ``prefix`` when extending the
            dataframe (presumably used to rename the original columns so they
            cannot clash with the ``pcN`` columns - confirm against
            ``ImmutableDataFrame.extend``)
        """
        super(PCA, self).__init__(columns)
        self.projection_computation = projection_computation
        self.projection_dimension = projection_dimension
        self.input_data_prefix = input_data_prefix

    @staticmethod
    def get_type() -> Literal["pca"]:
        """Return the type identifier of this computation."""
        return "pca"

    @staticmethod
    def build(params: PCAModel) -> 'PCA':
        """Create a ``PCA`` computation from its serialized parameters.

        ``inputDataPrefix`` is optional: when the key is absent, the same
        default as the constructor (``"input_"``) is used instead of failing
        with a ``KeyError``.
        """
        return PCA(
            params['columns'],
            Computation.build(params['projectionComputation']),
            params['projectionDimension'],
            params.get('inputDataPrefix', "input_"),
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> PCAResultModel:
        """Fit the PCA and run the nested computation on the projected data.

        :param idf: input dataframe
        :param ctx: computation context, forwarded to the nested computation
        :return: a dict holding the computation type, the explained variances
            (``eigenvalues``), the principal axes (``eigenvectors``) and the
            result of the nested computation (``projectionComputationResult``)
        :raises InvalidParams: if no column is given, if the projection
            dimension is negative, or if it exceeds the number of columns
        :raises NotEnoughDataError: if fewer than two rows are free of missing
            values
        """
        nb_columns = len(self.columns)
        if nb_columns < 1:
            raise InvalidParams("Number of columns must be at least 1")

        # A negative dimension would otherwise be interpreted by the slice below
        # as "drop the last components", silently producing a meaningless result
        if self.projection_dimension < 0:
            raise InvalidParams("Projection dimension must not be negative")

        if nb_columns < self.projection_dimension:
            raise InvalidParams("Number of columns must be greater or equal to projection dimension")

        # Drop all rows containing at least a missing value
        filtered_idf = AndFilter([NotFilter(MissingFilter(column)) for column in self.columns]).apply(idf)
        nb_valid_rows = len(filtered_idf)

        if nb_valid_rows < nb_columns:
            logger.info(
                "The number of valid rows (%s) is less than the number of columns (%s) used for PCA: "
                "the number of components generated will be limited by the number of valid rows.",
                nb_valid_rows, nb_columns
            )

        # Fail fast (better than raw sklearn error - but doesn't guarantee that result
        # will be meaningful if input is too small)
        if nb_valid_rows < 2:
            raise NotEnoughDataError()

        # Stack requested columns into a matrix (one row per valid row, one column per requested column)
        data = np.stack([filtered_idf.float_col(column) for column in self.columns], axis=1)

        # Standardize once: the same rescaled matrix is used both to fit the PCA and to project the data
        rescaled_data = preprocessing.StandardScaler().fit(data).transform(data)

        # Fit PCA
        pca = decomposition.PCA().fit(rescaled_data)

        # Project data, keep the requested number of components (one row per component after transposition)
        projections = pca.transform(rescaled_data)[:, :self.projection_dimension].T

        # Aggregate projections on original dataframe
        projections_dict = {"pc{}".format(rank): proj for rank, proj in enumerate(projections, start=1)}

        # PCA may yield fewer components than requested (e.g. fewer valid rows than columns):
        # pad the missing ones with empty columns so that pc1..pcN always exist
        for rank in range(len(projections_dict) + 1, self.projection_dimension + 1):
            projections_dict["pc{}".format(rank)] = [None] * nb_valid_rows

        idf_with_projections = idf.extend(projections_dict, prefix=self.input_data_prefix, align_on=filtered_idf)

        # Perform inner computation on aggregated dataframe
        projection_computation_result = self.projection_computation.apply_safe(idf_with_projections, ctx)

        return {
            "type": self.get_type(),
            "eigenvalues": pca.explained_variance_.tolist(),
            "eigenvectors": pca.components_.tolist(),
            "projectionComputationResult": projection_computation_result
        }
