import logging
import sys
import pandas as pd
from tdigest import TDigest

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.doctor.causal.score.scoring_handler import CausalPredictionScoringHandler
from dataiku.doctor.causal.utils.misc import check_causal_prediction_type
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc, GpuSupportingCapability, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.model_io import from_pkl
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import add_fmi_metadata
from dataiku.doctor.utils.scoring_recipe_utils import add_output_model_metadata
from dataiku.doctor.utils.scoring_recipe_utils import add_prediction_time_metadata
from dataiku.doctor.utils.scoring_recipe_utils import get_input_parameters
from dataiku.doctor.utils.scoring_recipe_utils import smmd_colnames

logger = logging.getLogger("causal_scoring_recipe")


def _reorder_metadata_columns(pred_df):
    """Return ``pred_df`` with the columns named in ``smmd_colnames`` moved last, in that order.

    Every name in ``smmd_colnames`` must already be a column of ``pred_df`` (presumably added by the
    ``add_*_metadata`` helpers), otherwise the column selection raises ``KeyError``.
    """
    other_columns = [col for col in pred_df.columns if col not in smmd_colnames]
    return pred_df[other_columns + smmd_colnames]


def _concat_kept_input_columns(input_df_copy_unnormalized, pred_df, recipe_desc):
    """Return the output dataframe: kept input columns followed by the ``pred_df`` columns.

    When ``recipe_desc["filterInputColumns"]`` is truthy only ``recipe_desc["keptInputColumns"]``
    are kept, otherwise every input column is. Input columns whose name already exists in
    ``pred_df`` are skipped so no column name is written twice; call this only once ``pred_df``
    holds every column it will contribute (metadata columns included).
    """
    if recipe_desc.get("filterInputColumns", False):
        candidate_columns = recipe_desc["keptInputColumns"]
    else:
        candidate_columns = input_df_copy_unnormalized.columns
    clean_kept_columns = [c for c in candidate_columns if c not in pred_df.columns]
    # axis=1 concatenation aligns the two frames on their row index
    return pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)


def main(model_folder, input_dataset_smartname, output_dataset_smartname, recipe_desc, script,
         preparation_output_schema, fmi):
    """Score a dataset with a saved causal model and write the predictions to the output dataset.

    Two execution modes:

    * Treatment assignment with ``SAMPLE_RATIO_EXACT``: the whole input is loaded as a single
      dataframe and the scoring handler assigns treatment using ``recipe_desc["treatmentRatio"]``.
    * Otherwise, the input is scored in batches of ``recipe_desc["pythonBatchSize"]`` rows
      (default 100000). If treatment is assigned, a row is recommended when its
      ``predicted_effect`` is strictly greater than a threshold. With ``SAMPLE_RATIO_APPROX`` that
      threshold is the ``100 * (1 - treatmentRatio)`` percentile of ``predicted_effect``, estimated
      with a t-digest in an extra pass over the data. With ``THRESHOLD`` it is
      ``recipe_desc["assignmentThreshold"]``.

    Treatment assignment is disabled for multi-valued treatments (multi-treatment enabled with
    more than two treatment values).

    :param model_folder: location of the saved model, opened with ``build_folder_context``
    :param input_dataset_smartname: name of the dataset to score
    :param output_dataset_smartname: name of the dataset receiving the scored rows
    :param recipe_desc: recipe settings dict (GPU config, treatment assignment options,
        propensity, metadata and input-column filtering flags, batch size)
    :param script: preparation script forwarded to ``get_input_parameters``
    :param preparation_output_schema: schema forwarded to ``get_input_parameters``
    :param fmi: model identifier forwarded to the output-metadata helpers
    """
    model_folder_context = build_folder_context(model_folder)
    input_dataset, core_params, feature_preproc, names, dtypes, parse_date_columns = \
        get_input_parameters(model_folder_context, input_dataset_smartname, preparation_output_schema, script)
    prediction_type = core_params["prediction_type"]
    check_causal_prediction_type(prediction_type)

    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    collector_data = model_folder_context.read_json("collector_data.json")
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")

    recipe_gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    log_nvidia_smi_if_use_gpu(gpu_config=recipe_gpu_config)
    GpuSupportingCapability.init_cuda_visible_devices(recipe_gpu_config["params"]["gpuList"])

    dku_causal_model = from_pkl(model_folder_context, "causal_model.pkl")

    # The propensity model is optional: only loaded when requested AND present in the model folder
    propensity_model_filename = "propensity_model.pkl"
    if recipe_desc["computePropensity"] and model_folder_context.isfile(propensity_model_filename):
        propensity_model = from_pkl(model_folder_context, propensity_model_filename)
    else:
        propensity_model = None
    scoring_handler = CausalPredictionScoringHandler(core_params, preprocessing_params, modeling_params, dku_causal_model,
                                                     propensity_model, collector_data, model_folder_context)
    output_dataset = Dataset(output_dataset_smartname)
    is_multi_value_treatment = core_params["enable_multi_treatment"] and len(core_params["treatment_values"]) > 2
    assign_treatment = recipe_desc["assignTreatment"] and not is_multi_value_treatment
    output_model_metadata = recipe_desc.get("outputModelMetadata", False)

    if assign_treatment and recipe_desc["treatmentAssignmentMode"] == doctor_constants.SAMPLE_RATIO_EXACT:
        # Single-pass scoring (when no batch_size is passed, the whole dataframe is retrieved in the first iteration)
        input_df, input_df_copy_unnormalized = next(
            dataframe_iterator(
                input_dataset, names, dtypes, parse_date_columns, feature_preproc,
                float_precision="round_trip", batch_size=None,
            )
        )
        pred_df = scoring_handler.score(input_df, assign_treatment=assign_treatment, treatment_ratio=recipe_desc["treatmentRatio"])
        if output_model_metadata:
            # add the "prediction time" and reorder the output model metadata columns
            add_output_model_metadata(pred_df, fmi)
            pred_df = _reorder_metadata_columns(pred_df)
        output_df = _concat_kept_input_columns(input_df_copy_unnormalized, pred_df, recipe_desc)

        logger.info("Starting writer, scoring full dataset")
        with output_dataset.get_writer() as writer:
            writer.write_dataframe(output_df)
            logger.info("Output df written")
        return

    # Batch scoring
    batch_size = recipe_desc.get("pythonBatchSize", 100000)
    threshold = None
    if assign_treatment:
        if recipe_desc["treatmentAssignmentMode"] == doctor_constants.SAMPLE_RATIO_APPROX:
            t_digest = TDigest()
            n_scored_rows = 0

            logger.info("Starting predicted effect threshold computation (approximate, using t-digest)")
            for input_df, _ in dataframe_iterator(
                input_dataset, names, dtypes, parse_date_columns, preprocessing_params["per_feature"],
                batch_size=batch_size, float_precision="round_trip"
            ):
                pred_df = scoring_handler.score(input_df, assign_treatment=False)
                n_scored_rows += len(pred_df)
                t_digest.batch_update(pred_df["predicted_effect"])
            logger.info("End of threshold computation")
            if n_scored_rows == 0:
                # A percentile cannot be computed from an empty digest; with no rows there is
                # nothing to recommend, so use a threshold no predicted effect can exceed.
                logger.warning("No rows were scored, no treatment will be recommended")
                threshold = float("inf")
            else:
                threshold = t_digest.percentile(100 * (1 - recipe_desc["treatmentRatio"]))
            logger.info("Computed threshold: {}".format(threshold))
        elif recipe_desc["treatmentAssignmentMode"] == doctor_constants.THRESHOLD:
            threshold = recipe_desc["assignmentThreshold"]
            logger.info("User specified threshold: {}".format(threshold))
    else:
        logger.info("No treatment assignment")

    def output_generator(assign_treatment=False, threshold=None):
        """Yield one output dataframe per input batch (kept input columns + prediction columns)."""
        for input_df, input_df_copy_unnormalized in dataframe_iterator(
            input_dataset, names, dtypes, parse_date_columns, preprocessing_params["per_feature"],
            batch_size=batch_size, float_precision="round_trip"
        ):
            pred_df = scoring_handler.score(input_df, assign_treatment=False)

            if output_model_metadata:
                add_fmi_metadata(pred_df, fmi)

            if assign_treatment:
                pred_df["treatment_recommended"] = pred_df["predicted_effect"] > threshold

            if output_model_metadata:
                # add the "prediction time" and reorder the output model metadata columns
                add_prediction_time_metadata(pred_df)
                pred_df = _reorder_metadata_columns(pred_df)

            # Kept input columns are selected only once every metadata column exists, so an input
            # column clashing with one of them is dropped instead of being written twice.
            yield _concat_kept_input_columns(input_df_copy_unnormalized, pred_df, recipe_desc)

    with output_dataset.get_writer() as writer:
        logger.info("Starting to iterate, scoring by batches")
        # Only these two modes produce a threshold usable for per-batch assignment
        assign_treatment_batch = assign_treatment and recipe_desc["treatmentAssignmentMode"] in {doctor_constants.THRESHOLD, doctor_constants.SAMPLE_RATIO_APPROX}
        for output_df in output_generator(assign_treatment_batch, threshold):
            logger.info("Generator generated a df %s", output_df.shape)
            writer.write_dataframe(output_df)
            logger.info("Output df written")


if __name__ == "__main__":
    # Command-line entry point. Positional arguments, in order:
    #   1 model folder, 2 input dataset name, 3 output dataset name,
    #   4 recipe desc JSON file, 5 preparation script JSON file,
    #   6 preparation output schema JSON file, 7 fmi (passed through as a raw string)
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        # Arguments are read in the same order as they are passed to main, inside the wrapper,
        # so a missing argument is reported by ErrorMonitoringWrapper.
        cli_args = sys.argv
        model_folder_arg = cli_args[1]
        input_smartname_arg = cli_args[2]
        output_smartname_arg = cli_args[3]
        recipe_desc_arg = dkujson.load_from_filepath(cli_args[4])
        script_arg = dkujson.load_from_filepath(cli_args[5])
        preparation_schema_arg = dkujson.load_from_filepath(cli_args[6])
        fmi_arg = cli_args[7]
        main(model_folder_arg, input_smartname_arg, output_smartname_arg,
             recipe_desc_arg, script_arg, preparation_schema_arg, fmi_arg)
