import logging

from dataiku.base.folder_context import build_folder_context
from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import DEFAULT_PERMUTATION_IMPORTANCE_ITERATIONS
from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelScorer
from dataiku.doctor.timeseries.preparation.preprocessing import TimeseriesPreprocessing, get_windows_list, add_rolling_windows_for_training, get_shift_map
from dataiku.doctor.timeseries.train.training_handler import resample_for_training
from dataiku.doctor.timeseries.utils import filter_dataframe_by_encoded_identifiers
from dataiku.doctor.timeseries.utils.permutation_importance_computer import PermutationImportanceComputer
from dataiku.doctor.utils import get_filtered_features
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.model_io import from_pkl
from dataiku.doctor.utils.split import load_train_set

logger = logging.getLogger(__name__)


def compute_post_train_permutation_importance(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                              modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                              train_split_desc=None, train_split_folder=None):
    """Compute permutation importance for an already-trained time series forecasting model.

    Reloads the preprocessing parameters, modeling parameters and pickled model
    from their folders, rebuilds the "full" training dataframe (resampling,
    rolling windows, restriction to the time series listed in
    ``forecasts.json.gz``) and delegates the computation to
    ``PermutationImportanceComputer``.

    :param job_id: not used by this function.
    :param split_desc: split description; its ``"schema"`` entry is used for resampling.
    :param core_params: core parameters of the ML task.
    :param preprocessing_folder: folder containing ``rpreprocessing_params.json``.
    :param model_folder: folder containing ``rmodeling_params.json``, ``forecasts.json.gz``
        and the pickled model.
    :param split_folder: folder from which the train set is loaded.
    :param fmi: not used by this function.
    :param modellike_folder: not used by this function.
    :param computation_parameters: optional dict of overrides. Recognized keys are
        ``"perIdentifierPermutationImportance"`` and ``"permutationImportanceIterations"``.
        ``None`` is treated as "no overrides".
    :param postcompute_folder: not used by this function.
    :param train_split_desc: not used by this function.
    :param train_split_folder: not used by this function.
    :return: ``"ok"`` when the computation ran, ``"nok"`` when permutation importance is
        not supported for this algorithm/preprocessing or when no external feature is selected.

    NOTE(review): the unused parameters are presumably kept so that this function matches
    a signature shared with other post-train computations - confirm against the caller.
    """
    # The parameter defaults to None, but the override lookups below use `in`,
    # which raises TypeError on None. Normalize once, up front.
    if computation_parameters is None:
        computation_parameters = {}

    model_folder_context = build_folder_context(model_folder)
    preprocessing_folder_context = build_folder_context(preprocessing_folder)
    preprocessing_params = preprocessing_folder_context.read_json("rpreprocessing_params.json")
    listener = ProgressListener()
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    algorithm = TimeseriesForecastingAlgorithm.build(modeling_params["algorithm"])

    if not PermutationImportanceComputer.supports_permutation_importance(algorithm, preprocessing_params):
        logger.info("Permutation importance is not supported for this algorithm and preprocessing, skipping")
        return "nok"

    full_timeseries_preprocessing = TimeseriesPreprocessing(preprocessing_folder_context, core_params, preprocessing_params, modeling_params, listener, True)

    split_folder_context = build_folder_context(split_folder)
    full_df = load_train_set(core_params, preprocessing_params, split_desc, "full",
                             split_folder_context, use_diagnostics=False)

    # Used both for resampling and for selecting the feature roles to permute.
    supports_past_only_features = algorithm.EXTERNAL_FEATURES_COMPATIBILITY.supports_past_only_external_features()

    full_df = resample_for_training(
        full_df,
        split_desc["schema"],
        preprocessing_params[doctor_constants.TIMESERIES_SAMPLING],
        core_params,
        preprocessing_params,
        supports_past_only_features,
        False,
    )

    timeseries_identifier_columns = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]

    # The same flag decides whether rolling windows are generated and whether
    # the computer works on generated features only.
    use_only_generated_features = modeling_params.get("isShiftWindowsCompatible", False)
    windows_list = get_windows_list(preprocessing_params) if use_only_generated_features else []
    full_df = add_rolling_windows_for_training(full_df, core_params, windows_list, preprocessing_params, preprocessing_folder_context)

    # Keep only the time series that have an entry in the stored forecasts.
    trained_models_identifiers = list(model_folder_context.read_json("forecasts.json.gz")["perTimeseries"].keys())
    full_df = filter_dataframe_by_encoded_identifiers(full_df, timeseries_identifier_columns, trained_models_identifiers)

    full_generated_features_mappings = {
        ts_id: pipeline.generated_features_mapping
        for ts_id, pipeline in full_timeseries_preprocessing.pipeline_by_timeseries.items()
    }
    shift_map = get_shift_map(preprocessing_params, full_generated_features_mappings)

    clf = from_pkl(model_folder_context)

    include_roles = ["INPUT", "INPUT_PAST_ONLY"] if supports_past_only_features else ["INPUT"]
    columns = get_filtered_features(preprocessing_params, include_roles=include_roles)

    if not columns:
        logger.error("Cannot compute permutation importance for model without external features")
        return "nok"

    metrics_params = modeling_params["metrics"]
    model_scorer = TimeseriesModelScorer.build(core_params, metrics_params, True)

    # Defaults come from the modeling params / constants; explicit computation
    # parameters, when present, take precedence.
    default_per_identifier = len(timeseries_identifier_columns) > 0 and modeling_params.get("perIdentifierPermutationImportance", False)
    per_identifier = computation_parameters.get("perIdentifierPermutationImportance", default_per_identifier)
    n_iterations = computation_parameters.get("permutationImportanceIterations", DEFAULT_PERMUTATION_IMPORTANCE_ITERATIONS)

    computer = PermutationImportanceComputer(model_scorer, model_folder_context, algorithm, core_params, full_timeseries_preprocessing,
                                             timeseries_identifier_columns, clf, columns, shift_map, use_only_generated_features)
    computer.compute_permutation_importance(full_df, per_identifier, n_iterations)

    return "ok"
