from dataiku.base.folder_context import get_partitions_fmi_folder_contexts
from dataiku.core import doctor_constants
from dataiku.doctor.utils.gpu_execution import GluonTSMXNetGPUCapability
from dataiku.doctor.utils.model_io import from_pkl
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import get_input_parameters
from dataiku.doctor.utils.scoring_recipe_utils import is_partition_dispatch


def _load_model(model_folder_context, recipe_gpu_config):
    """Load a pickled model together with the parameter files stored next to it.

    :param model_folder_context: folder context exposing ``read_json`` and accepted by ``from_pkl``
    :param recipe_gpu_config: recipe GPU settings; ``["params"]["gpuList"]`` is read only for
        models whose ``trainedWithDSSConfVersion`` is at least 12400
    :return: tuple ``(model_folder_context, preprocessing_params, clf, modeling_params,
        resolved_modeling_params, metrics_params)``
    """
    read_json = model_folder_context.read_json

    resolved_params = read_json("actual_params.json")["resolved"]
    preproc_params = read_json("rpreprocessing_params.json")
    core = read_json("core_params.json")
    modeling = read_json("rmodeling_params.json")
    metrics = modeling["metrics"]

    # A missing version entry is treated as version 0, i.e. as an old model.
    trained_conf_version = int(read_json("version_info.json").get("trainedWithDSSConfVersion", 0))

    # GPU selection in DSS is enforced through the CUDA_VISIBLE_DEVICES environment variable:
    # the device string is always cuda:0 and CUDA_VISIBLE_DEVICES=X holds the desired gpu id(s).
    # cuda:0 then runs on the first entry of CUDA_VISIBLE_DEVICES, e.g. with
    # CUDA_VISIBLE_DEVICES=[2,4], cuda:0 is physical device 2 and cuda:1 is physical device 4.
    # Asking for cuda:4 in that setup fails with something like
    # 'RuntimeError: CUDA error: invalid device ordinal' because of this remapping.
    #
    # Time series models trained before DSS 12.4 stored cuda:X, where X was the desired gpu id.
    # Unpickling such a model with X>0 after setting CUDA_VISIBLE_DEVICES=X would fail, so the
    # variable is only set for models trained with 12.4 or later.
    pin_visible_gpus = trained_conf_version >= 12400
    if pin_visible_gpus:
        gpu_ids = recipe_gpu_config["params"]["gpuList"]
        GluonTSMXNetGPUCapability.init_cuda_visible_devices(gpu_ids)

    predictor = from_pkl(model_folder_context)

    # Models from releases 11.2 to 12.0 were not serialized with all their parameters,
    # so they must be (re-)initialized after unpickling.
    predictor.initialize(core, resolved_params)

    return (
        model_folder_context,
        preproc_params,
        predictor,
        modeling,
        resolved_params,
        metrics,
    )


def load_model_partitions(model_folder_context, recipe_gpu_config, base_fmi):
    """Load either the single model or every partition model behind ``base_fmi``.

    :param model_folder_context: folder context of the (base) model
    :param recipe_gpu_config: GPU settings forwarded to ``_load_model``
    :param base_fmi: identifier of the base model, used to enumerate partitions
    :return: tuple ``(partition_dispatch, partitions, partitions_fmis)`` where both dicts are
        keyed by partition name; when not in partition dispatch mode they hold a single
        ``"NP"`` entry for the model itself
    """
    partition_dispatch = is_partition_dispatch(model_folder_context)

    if not partition_dispatch:
        # Not partitioned: the given folder is the model itself.
        single_model = _load_model(model_folder_context, recipe_gpu_config)
        return partition_dispatch, {"NP": single_model}, {"NP": base_fmi}

    # Partition dispatch mode: the given folder is the base model, so load every partition model.
    loaded_models = {}
    fmi_by_partition = {}
    for name, contexts in get_partitions_fmi_folder_contexts(base_fmi).items():
        loaded_models[name] = _load_model(contexts.model_folder_context, recipe_gpu_config)
        fmi_by_partition[name] = contexts.fmi
    return partition_dispatch, loaded_models, fmi_by_partition


def get_input_df_and_parameters(model_folder_context, input_dataset_smartname, recipe_desc, preparation_output_schema, script):
    """Read the whole input dataset and gather the recipe parameters used for scoring.

    :param model_folder_context: folder context forwarded to ``get_input_parameters``
    :param input_dataset_smartname: name of the input dataset, forwarded as-is
    :param recipe_desc: recipe description; ``"pastTimestepsToInclude"`` is required,
        ``"filterInputColumns"`` is optional and ``"keptInputColumns"`` is read when filtering
    :param preparation_output_schema: forwarded to ``get_input_parameters``
    :param script: forwarded to ``get_input_parameters``
    :return: tuple ``(input_df, input_df_copy_unnormalized, core_params, quantiles,
        past_time_steps_to_include, columns_to_drop)``
    """
    (
        input_dataset,
        core_params,
        feature_preproc,
        names,
        dtypes,
        parse_date_columns,
    ) = get_input_parameters(model_folder_context, input_dataset_smartname, preparation_output_schema, script)

    quantiles = core_params[doctor_constants.QUANTILES]
    past_time_steps_to_include = recipe_desc["pastTimestepsToInclude"]

    # With batch_size=None the first iteration already yields the whole dataframe,
    # so a single next() is enough.
    batches = dataframe_iterator(
        input_dataset,
        names,
        dtypes,
        parse_date_columns,
        feature_preproc,
        float_precision="round_trip",
        batch_size=None,
    )
    input_df, input_df_copy_unnormalized = next(batches)

    # When filtering is enabled, everything yielded by iterating input_df (presumably its
    # column names) that is not explicitly kept is marked for removal.
    filter_columns = recipe_desc.get("filterInputColumns", False)
    columns_to_drop = (
        [column for column in input_df if column not in recipe_desc["keptInputColumns"]]
        if filter_columns
        else []
    )

    return (
        input_df,
        input_df_copy_unnormalized,
        core_params,
        quantiles,
        past_time_steps_to_include,
        columns_to_drop,
    )
