import logging
import pandas as pd
import sys

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core import doctor_constants
from dataiku.doctor.causal.evaluate.evaluation_handler import CausalPredictionEvaluationHandler
from dataiku.doctor.exception import EmptyDatasetException
from dataiku.doctor.evaluation.base import load_input_dataframe
from dataiku.doctor.utils import normalize_dataframe
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc, GpuSupportingCapability, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.scoring_recipe_utils import get_input_parameters

logger = logging.getLogger(__name__)


def main(model_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname, recipe_desc,
         script, preparation_output_schema):
    """Evaluate a saved causal model on an input dataset and write predictions and/or metrics.

    Loads the (sampled) input dataset and normalizes it with the model's feature
    preprocessing. Unpickles the causal model, plus the propensity model when it is
    needed and present. Runs ``CausalPredictionEvaluationHandler.evaluate``, then writes
    the scored rows and/or the metrics to their datasets.

    :param model_folder: location of the saved model, passed to ``build_folder_context``
    :param input_dataset_smartname: smart name of the dataset to evaluate
    :param output_dataset_smartname: smart name of the scored output dataset;
        None or empty string to skip writing scored rows
    :param metrics_dataset_smartname: smart name of the metrics dataset;
        None or empty string to skip computing and writing metrics
    :param recipe_desc: recipe description dict. Reads "selection", "metrics",
        "computePropensity", "filterInputColumns", "keptInputColumns" and the GPU config.
    :param script: forwarded to ``get_input_parameters`` (presumably the preparation
        script -- confirm against that helper)
    :param preparation_output_schema: forwarded to ``get_input_parameters``
    :raises EmptyDatasetException: if the loaded input dataframe is empty
    """
    log_nvidia_smi_if_use_gpu(recipe_desc=recipe_desc)
    has_output_dataset = bool(output_dataset_smartname)
    has_metrics_dataset = bool(metrics_dataset_smartname)
    if not has_metrics_dataset:
        logger.info("Will only score and compute statistics")

    model_folder_context = build_folder_context(model_folder)
    input_dataset, core_params, feature_preproc, names, dtypes, parse_date_columns = \
        get_input_parameters(model_folder_context, input_dataset_smartname, preparation_output_schema, script)
    logger.info("Scoring data")

    input_df = load_input_dataframe(
        input_dataset=input_dataset,
        sampling=recipe_desc.get('selection', {"samplingMethod": "FULL"}),
        columns=names,
        dtypes=dtypes,
        parse_date_columns=parse_date_columns,
    )
    if input_df.empty:
        raise EmptyDatasetException("The evaluation dataset can not be empty. Check the input dataset or the recipe sampling configuration.")

    # Keep the raw values: normalize_dataframe mutates input_df in place, but the
    # scored output should carry the original (un-normalized) input columns.
    input_df_copy_unnormalized = input_df.copy()

    # TODO @causal API logs dataset normalization here

    logger.info("Normalizing dataframe of shape: {}".format(input_df.shape))
    normalize_dataframe(input_df, feature_preproc)
    for col in input_df:
        logger.info("Normalized column: {} -> {}".format(col, input_df[col].dtype))

    # TODO @causal ML Diagnostics code here (See PR#21119)

    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    collector_data = model_folder_context.read_json("collector_data.json")
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    # Restrict visible GPUs before any model is unpickled below
    # (presumably so GPU-backed models bind to the configured devices).
    recipe_gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    GpuSupportingCapability.init_cuda_visible_devices(recipe_gpu_config["params"]["gpuList"])

    output_metrics = recipe_desc["metrics"]

    from dataiku.doctor.utils.model_io import from_pkl
    dku_causal_model = from_pkl(model_folder_context, "causal_model.pkl")

    # The propensity model is only loaded when propensity is requested or metrics use
    # inverse-propensity weighting, and only if the model folder actually contains it.
    propensity_model_filename = "propensity_model.pkl"
    propensity_needed = (recipe_desc["computePropensity"]
                         or modeling_params["metrics"].get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY)
    if propensity_needed and model_folder_context.isfile(propensity_model_filename):
        logger.info("Loading propensity model")
        propensity_model = from_pkl(model_folder_context, propensity_model_filename)
    else:
        propensity_model = None

    evaluation_handler = CausalPredictionEvaluationHandler(
        core_params, preprocessing_params, modeling_params, dku_causal_model, propensity_model, collector_data, model_folder_context
    )
    pred_df, metrics_df = evaluation_handler.evaluate(input_df, output_metrics, compute_metrics=has_metrics_dataset)

    # write scored data
    if has_output_dataset:
        if recipe_desc.get("filterInputColumns", False):
            candidate_columns = recipe_desc["keptInputColumns"]
        else:
            candidate_columns = input_df_copy_unnormalized.columns
        # Drop input columns whose names clash with prediction columns in BOTH cases.
        # Otherwise the concat below would emit duplicated column names.
        clean_kept_columns = [c for c in candidate_columns if c not in pred_df.columns]
        # NOTE(review): concat(axis=1) aligns on the index; this assumes pred_df keeps
        # input_df's index -- confirm in CausalPredictionEvaluationHandler.evaluate.
        output_df = pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)

        output_dataset = Dataset(output_dataset_smartname)
        logger.info("writing scored data")
        output_dataset.write_from_dataframe(output_df)

    # write metrics dataset
    if has_metrics_dataset:
        metrics_dataset = Dataset(metrics_dataset_smartname)
        logger.info("writing metrics data")
        metrics_dataset.write_from_dataframe(metrics_df)


if __name__ == "__main__":
    # Process bootstrap: debug handler first, then root logging, then the DKU environment.
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()

    # Argument parsing and JSON loading stay inside the wrapper so their failures are reported too.
    # Positional layout: 1-4 are plain strings, 5-7 are paths to JSON documents.
    with ErrorMonitoringWrapper():
        cli_args = sys.argv
        main(
            model_folder=cli_args[1],
            input_dataset_smartname=cli_args[2],
            output_dataset_smartname=cli_args[3],
            metrics_dataset_smartname=cli_args[4],
            recipe_desc=dkujson.load_from_filepath(cli_args[5]),
            script=dkujson.load_from_filepath(cli_args[6]),
            preparation_output_schema=dkujson.load_from_filepath(cli_args[7]),
        )
