import logging
import numpy as np

from dataiku.base import remoterun
from dataiku.base.gpu_utils import log_nvidia_smi
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType

logger = logging.getLogger(__name__)


# Mirrors the structure of Params in com.dataiku.dip.analysis.model.core.GpuConfig.
# Must be kept up to date with any change made on the Java side.
def get_default_gpu_params():
    """
    Build the default GPU parameters.

    A brand new dict (with a brand new `gpuList`) is created on every call, so the result can be
    mutated by the caller without affecting later calls.

    :return: dict with the keys "useGpu", "gpuList", "perGPUMemoryFraction" and "gpuAllowGrowth"
    :rtype: dict
    """
    defaults = dict(
        # In the Java this defaults to false for new tasks; it is true here so that GPU is enabled
        # by default when the params are missing.
        useGpu=True,
        gpuList=[0],
        perGPUMemoryFraction=0.5,
        gpuAllowGrowth=False,
    )
    return defaults


def get_default_gpu_config():
    """
    Build the default GPU config: the default GPU params and an empty list of disabled capabilities.

    :return: dict with the keys "params" and "disabledCapabilities"
    :rtype: dict
    """
    config = {}
    config["params"] = get_default_gpu_params()
    config["disabledCapabilities"] = []
    return config


def get_gpu_config_from_core_params(core_params):
    """
    Read the GPU config stored under `core_params["executionParams"]["gpuConfig"]`.

    :param dict core_params: core params, possibly without "executionParams" or "gpuConfig"
    :return: the stored GPU config, or the default GPU config when either key is absent
    """
    fallback_config = get_default_gpu_config()
    execution_params = core_params.get("executionParams", {})
    return execution_params.get("gpuConfig", fallback_config)


def get_gpu_config_from_recipe_desc(recipe_desc):
    """
    Read the GPU config stored under `recipe_desc["gpuConfig"]`.

    :param dict recipe_desc: recipe description, possibly without a "gpuConfig" entry
    :return: the stored GPU config, or the default GPU config when the key is absent
    """
    fallback_config = get_default_gpu_config()
    return recipe_desc.get("gpuConfig", fallback_config)


def log_nvidia_smi_if_use_gpu(core_params=None, recipe_desc=None, gpu_config=None):
    """
    Log `nvidia-smi` information when the resolved GPU config has `useGpu` enabled.

    The GPU config is resolved from the first of these that is not None: `gpu_config`, then
    `core_params`, then `recipe_desc`. When all three are None, nothing is done.

    :param dict core_params: core params from which the GPU config can be extracted
    :param dict recipe_desc: recipe description from which the GPU config can be extracted
    :param dict gpu_config: GPU config to use directly; takes precedence over the two others
    """
    if gpu_config is None:
        if core_params is not None:
            gpu_config = get_gpu_config_from_core_params(core_params)
        elif recipe_desc is not None:
            gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
        else:
            # no source to resolve a GPU config from
            return

    if not gpu_config["params"]["useGpu"]:
        logger.info("Not logging `nvidia-smi`: `useGPU` is False")
        return

    # log_nvidia_smi is invoked once with True and once with False as its first argument
    for first_arg in (True, False):
        log_nvidia_smi(first_arg, logger)


def get_single_gpu_id_from_gpu_device(device_string):
    """
    Extract the integer device id from a single-device string such as 'cuda:1'.

    :param str device_string: device string, expected to be 'cuda:<id>'
    :return: the device id that follows the 'cuda:' prefix
    :rtype: int
    :raises Exception: if the string does not start with 'cuda:' or contains a comma (several devices)
    :raises ValueError: if what follows the prefix cannot be parsed by `int`
    """
    prefix = "cuda:"
    if not device_string.startswith(prefix):
        raise Exception("Unexpected device string: {} - String does not start with 'cuda:'".format(device_string))

    has_several_devices = "," in device_string
    if has_several_devices:
        raise Exception("Unexpected device string: {} - String appears to contain more than one device, expected single.".format(device_string))

    raw_id = device_string[len(prefix):]
    return int(raw_id)


class GpuSupportingCapability:
    """
    Base class for a feature ("capability") whose code can optionally run on GPU.

    Subclasses are expected to provide `name()` (the identifier looked up in the "disabledCapabilities" list of
    a gpu config), `is_gpu_available()` and, when `get_device()` is used, `_get_device_ids()`.

    A gpu config is a dict with a "params" dict ("useGpu", "gpuList", ...) and a "disabledCapabilities" list.
    When "params" is missing, `get_default_gpu_params()` is used instead.
    """

    DIAGNOSTIC_ID_RUNTIME_NO_GPU_FOUND = "ML_DIAGNOSTICS_RUNTIME--NO_GPU_FOUND"

    @classmethod
    def is_gpu_available(cls, gpu_config=None):
        """Tell whether a GPU can be used for this capability. Must be implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def name():
        """Identifier of the capability, as found in "disabledCapabilities". Must be implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def display_name():
        """Human readable name of the capability, used in log and error messages."""
        return "Gpu capability"

    @staticmethod
    def _get_gpu_params(gpu_config, store_defaults=False):
        """
        Return the "params" dict of `gpu_config`, falling back to the default gpu params when it is missing or None.

        The defaults are only built when they are needed.

        :param dict gpu_config: gpu config to read
        :param bool store_defaults: when True and the defaults are used, they are also stored in `gpu_config["params"]`
                                    so that later mutations of the returned dict are visible through `gpu_config`
        :rtype: dict
        """
        gpu_params = gpu_config.get("params")
        if gpu_params is None:
            gpu_params = get_default_gpu_params()
            if store_defaults:
                gpu_config["params"] = gpu_params
        return gpu_params

    # This function should be kept in sync with the matching implementation in
    # com.dataiku.dip.analysis.model.core.GpuConfig.shouldUseGpu
    @classmethod
    def should_use_gpu(cls, gpu_config, allow_cpu_fallback=True):
        """
        Tell whether this capability should run on GPU: GPU must be enabled for it in `gpu_config` and available.

        Caution: This function calls dataiku.doctor.utils.gpu_execution.GpuSupportingCapability.is_gpu_available, which will often
        initialise a cuda device - this may be undesirable

        :param dict gpu_config: gpu config to evaluate
        :param bool allow_cpu_fallback: when GPU is requested but not available, whether to fall back to CPU (a runtime
                                        diagnostic and a warning are emitted) rather than raising
        :return: False when GPU is not enabled for this capability, else the result of `is_gpu_available`
        :raises ValueError: if GPU is requested, not available, and `allow_cpu_fallback` is False
        """
        if cls._is_gpu_enabled_for_capability(gpu_config):
            is_available = cls.is_gpu_available(gpu_config)

            if not is_available:  # gpu requested but not accessible
                if not allow_cpu_fallback:
                    raise ValueError("GPU usage was requested for {} but no GPU is accessible. "
                                     "Please contact your administrator or disable GPU usage".format(cls.display_name()))

                diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_RUNTIME,
                                          "Requested GPU but no device is available, using CPU",
                                          diagnostic_id=cls.DIAGNOSTIC_ID_RUNTIME_NO_GPU_FOUND)
                logger.warning("Defaulting to CPU for {}, as no GPU was detected".format(cls.display_name()))
            return is_available
        return False

    @classmethod
    def _is_gpu_enabled_for_capability(cls, gpu_config):
        """
        Tell whether `gpu_config` requests GPU ("useGpu") and does not list this capability as disabled.

        :rtype: bool
        """
        if not cls._get_gpu_params(gpu_config).get("useGpu"):
            return False
        # a missing or null list means that no capability is disabled
        disabled_capabilities = gpu_config.get("disabledCapabilities") or []
        return cls.name() not in disabled_capabilities

    @classmethod
    def configure_predictor_gpu_env(cls, gpu_config, is_api_node):
        """
        Prepare the GPU environment of a predictor: GPU 0 is forced when `is_api_node` is true, otherwise the
        visible cuda devices are set from the gpu list of `gpu_config` (if GPU is enabled for this capability).
        """
        if is_api_node:
            cls.force_gpu_0(gpu_config)
        else:
            cls._set_visible_devices_if_enabled(gpu_config)

    @classmethod
    def _set_visible_devices_if_enabled(cls, gpu_config):
        """Set CUDA_VISIBLE_DEVICES from the gpu list of `gpu_config`, only when GPU is enabled for this capability."""
        if cls._is_gpu_enabled_for_capability(gpu_config):
            gpu_list = cls._get_gpu_params(gpu_config).get("gpuList")
            cls.init_cuda_visible_devices(gpu_list)

    @classmethod
    def get_device(cls, gpu_config):
        """
        Compute the device string to use for this capability.

        :param dict gpu_config: gpu config to evaluate
        :return: "cpu" when GPU should not be used or when the gpu list is missing/empty, else "cuda:<ids>" where
                 <ids> is provided by `_get_device_ids`
        :rtype: str
        """
        use_gpu = cls.should_use_gpu(gpu_config)
        gpu_list = cls._get_gpu_params(gpu_config).get("gpuList")

        # `not gpu_list` covers both an empty list and a "params" dict that has no "gpuList" entry
        if not use_gpu or not gpu_list:
            logger.info("Using 'cpu' device for {}".format(cls.display_name()))
            return "cpu"

        gpu_string = "cuda:{}".format(cls._get_device_ids(gpu_list))
        logger.info("Using the following gpu device(s) for {}: '{}' ".format(cls.display_name(), gpu_string))
        return gpu_string

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """Return what must follow "cuda:" in the device string. Must be implemented by subclasses using `get_device`."""
        raise NotImplementedError()

    @classmethod
    def force_gpu_0(cls, gpu_config):
        """
        Overwrite, in place, the gpu list of `gpu_config` with `[0]`.

        The list is overwritten whether or not GPU ends up being used. If `gpu_config` has no "params", the default
        params are stored in it first.
        """
        cls._get_gpu_params(gpu_config, store_defaults=True)["gpuList"] = [0]
        if cls.should_use_gpu(gpu_config):
            logger.info("Forcing GPU 0 for {}".format(cls.display_name()))

    @classmethod
    def force_single_gpu(cls, gpu_config):
        """
        When GPU should be used, reduce, in place, the gpu list of `gpu_config` to a single device.

        The first item of the list is kept; `[0]` is used when the list is missing or empty. `gpu_config` is left
        untouched when GPU should not be used.
        """
        if cls.should_use_gpu(gpu_config):
            logger.info("Forcing single gpu for {}".format(cls.display_name()))
            gpu_params = cls._get_gpu_params(gpu_config, store_defaults=True)
            gpu_list = gpu_params.get("gpuList")
            if gpu_list:
                # when forcing single gpu, we select the first item in the gpulist, which is pre-sorted by gpu device index
                gpu_params["gpuList"] = [gpu_list[0]]
            else:
                gpu_params["gpuList"] = [0]

    @classmethod
    def init_cuda_visible_devices(cls, gpu_list):
        """
        Remap here the physical device using a CUDA environment variable. This needs to be done prior to any
        call to `torch.cuda.something`, or anything else that initialises a cuda 'instance', otherwise initialization
        code to communicate with GPUs would have already been run.
        Then using device 'cuda:X' in torch code will be using the Xth device of the provided gpu list.

        E.g. Passing gpu_list = [2,4] and then provided a device string of 'cuda:0' will execute the process on gpu 2

        :param list gpu_list: Physical GPU ids to use (0-based).
           - When the task runs locally, it allows the user leverage specific GPU(s), especially if some GPU(s) are
            already taken by other users.
           - When run on a container, users can only define the number of requested GPUs per container, so the list
            will always be [0 .. number_of_gpus -1]
        """

        new_cuda_visible_devices = ",".join(map(str, gpu_list))
        remoterun.set_dku_env_var_and_sys_env_var("CUDA_VISIBLE_DEVICES", new_cuda_visible_devices)
        logger.info("CUDA_VISIBLE_DEVICES set to '{}' for {}".format(new_cuda_visible_devices, cls.display_name()))

    @classmethod
    def disable_all_cuda_devices(cls):
        """Set CUDA_VISIBLE_DEVICES to '-1'."""
        remoterun.set_dku_env_var_and_sys_env_var("CUDA_VISIBLE_DEVICES", "-1")
        logger.info("CUDA_VISIBLE_DEVICES set to '-1' for {}".format(cls.display_name()))


class KerasGPUCapability(GpuSupportingCapability):
    """GPU capability of the Keras based "Deep Learning" feature."""

    @staticmethod
    def name():
        return "KERAS"

    @staticmethod
    def display_name():
        return "Deep Learning"

    @classmethod
    def is_gpu_available(cls, gpu_config=None):
        """
        Tell whether `tfcompat.list_physical_devices()` reports at least one device, after restricting the
        visible cuda devices to the gpu list of `gpu_config`.

        :param dict gpu_config: gpu config holding the gpu list; None is treated as an empty config (default params)
        :rtype: bool
        """
        if gpu_config is None:
            # the signature advertises None as a valid value: fall back on the default params
            gpu_config = {}

        gpu_list = gpu_config.get("params", get_default_gpu_params()).get("gpuList")
        # must be done before tensorflow is imported/initialised
        cls.init_cuda_visible_devices(gpu_list)

        from dataiku.doctor.deep_learning import tfcompat
        gpus = tfcompat.list_physical_devices()
        is_cuda_available = len(gpus) > 0
        logger.info("Keras cuda available: {}".format(is_cuda_available))
        return is_cuda_available

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """
        Not supported for Keras.

        :raises NotImplementedError: always
        """
        # raise (not return) the error, so that it cannot end up formatted into a device string
        raise NotImplementedError("Keras uses a different method of setting device")


class GluonTSMXNetGPUCapability(GpuSupportingCapability):
    """GPU capability of the MXNet based time series algorithms."""

    @staticmethod
    def name():
        # Deliberately the same name as `GluonTSTorchGPUCapability`: mxnet and torch based ts algs then interoperate
        # smoothly, without exposing further options to the end user
        return "GLUONTS"

    @staticmethod
    def display_name():
        return "Time Series with MXNet"

    @classmethod
    def is_gpu_available(cls, gpu_config=None):
        """
        Tell whether MXNet reports at least one GPU.

        :param dict gpu_config: not used by this implementation
        :return: True if `mx.context.num_gpus()` is positive; False if it is not, or if MXNet cannot be imported
        :rtype: bool
        """
        # mxnet code is frozen and still relies on `np.bool`: re-create the alias so that it works with numpy 1.24 and
        # above (see https://app.shortcut.com/dataiku/story/177201/state-of-numpy-in-dss)
        np.bool = bool
        try:
            import mxnet as mx
            has_gpu = mx.context.num_gpus() > 0
            logger.info("MXNet cuda available: {}".format(has_gpu))
        except ModuleNotFoundError as err:
            logger.warning("MXNet not available: {}".format(err.msg))
            has_gpu = False
        return has_gpu

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """Always return 0, whatever `gpu_list` is."""
        # CUDA_VISIBLE_DEVICES must be set for this capability
        return 0


class XGBOOSTGpuCapability(GpuSupportingCapability):
    """GPU capability of XGBoost."""

    @staticmethod
    def name():
        return "XGBOOST"

    @staticmethod
    def display_name():
        return "XGBoost"

    @classmethod
    def is_gpu_available(cls, gpu_config=None):
        """
        Tell whether XGBoost is able to train on GPU, after restricting the visible cuda devices to the gpu list
        of `gpu_config`.

        :param dict gpu_config: gpu config holding the gpu list; None is treated as an empty config (default params)
        :return: True if a tiny cuda-specific training succeeds, False if it raises an `XGBoostError`
        :rtype: bool
        """
        if gpu_config is None:
            # the signature advertises None as a valid value: fall back on the default params
            gpu_config = {}

        gpu_list = gpu_config.get("params", get_default_gpu_params()).get("gpuList")
        cls.init_cuda_visible_devices(gpu_list)

        # xgboost does not expose a method to check devices, so we try to train using a cuda specific tree method and device and see
        from xgboost import XGBClassifier, core
        model = XGBClassifier(device='cuda', tree_method='gpu_hist')

        # smallest possible training set: one row, one label
        X = np.array([[1, 2]])
        y = np.array([0])

        try:
            model.fit(X, y)
        except core.XGBoostError as e:
            # expected when no usable GPU is present; keep the reason at debug level to ease troubleshooting
            logger.debug("XGBoost gpu probe failed: {}".format(e))
            is_gpu_available = False
        else:
            is_gpu_available = True

        logger.info("XGBoost gpu available: {}".format(is_gpu_available))
        return is_gpu_available

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """Always return 0, whatever `gpu_list` is."""
        return 0


class TorchGpuCapability(GpuSupportingCapability):
    """Shared behaviour of the torch based GPU capabilities. `name()` is left to the subclasses."""

    @classmethod
    def is_gpu_available(cls, gpu_config=None):
        """
        Tell whether torch reports cuda as available, after restricting the visible cuda devices to the gpu list
        of `gpu_config`.

        :param dict gpu_config: gpu config holding the gpu list; None is treated as an empty config (default params)
        :return: the result of `torch.cuda.is_available()`
        """
        config = {} if gpu_config is None else gpu_config
        gpu_params = config.get("params", get_default_gpu_params())
        cls.init_cuda_visible_devices(gpu_params.get("gpuList"))

        # imported only now, i.e. after CUDA_VISIBLE_DEVICES has been set
        from torch import cuda
        return cuda.is_available()

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """Return the rank, among the visible devices, of the gpu to use: always 0."""
        # At this point CUDA_VISIBLE_DEVICES should already be set to gpu_list for torch capabilities, so what is
        # needed is the rank of the gpu among the visible devices, not its physical id.
        # By default, the first visible device is used (cuda:0).
        return 0


class GluonTSTorchGPUCapability(TorchGpuCapability):
    """GPU capability of the torch based time series algorithms."""

    @staticmethod
    def name():
        # Deliberately the same name as `GluonTSMXNetGPUCapability`: mxnet and torch based ts algs then interoperate
        # smoothly, without exposing further options to the end user
        return "GLUONTS"

    @staticmethod
    def display_name():
        return "Time Series with Torch"

    @classmethod
    def get_lightning_devices(cls, gpu_config):
        """
        Build a one-item list holding the device id given by `_get_device_ids` for the gpu list of `gpu_config`.

        :param dict gpu_config: gpu config holding the gpu list; None is treated as an empty config (default params)
        :rtype: list
        """
        config = {} if gpu_config is None else gpu_config
        gpu_params = config.get("params", get_default_gpu_params())
        device_id = cls._get_device_ids(gpu_params.get("gpuList"))
        return [device_id]


class DeepHubGpuCapability(TorchGpuCapability):
    """GPU capability of the "Computer Vision" (deephub) feature."""

    @staticmethod
    def name():
        return "DEEP_HUB"

    @staticmethod
    def display_name():
        return "Computer Vision"

    @classmethod
    def set_deephub_ctx_gpu_behaviour(cls, per_node_gpu_list, deephub_context):
        """
        Mark `deephub_context` as cuda based when `per_node_gpu_list` is not empty, and as not cuda based otherwise.

        :param per_node_gpu_list: sized collection of the CUDA devices to use; empty means cpu
        :param deephub_context: object exposing `set_cuda_based(bool)`
        """
        is_cuda_based = len(per_node_gpu_list) > 0
        if is_cuda_based:
            message = "Will use the following CUDA devices for {}: {}".format(cls.display_name(), per_node_gpu_list)
        else:
            message = "Using 'cpu' device for {}".format(cls.display_name())
        logger.info(message)
        deephub_context.set_cuda_based(is_cuda_based)

    @classmethod
    def _get_device_ids(cls, gpu_list):
        """Not implemented for this capability."""
        # not currently in use, deephub has its own logic incorporating a process rank
        raise NotImplementedError()


class DeepNNGpuCapability(TorchGpuCapability):
    """Torch based GPU capability identified as "DEEP_NN" and displayed as "Deep Neural Network"."""

    @staticmethod
    def name():
        return "DEEP_NN"

    @staticmethod
    def display_name():
        return "Deep Neural Network"


class SentenceEmbeddingGpuCapability(TorchGpuCapability):
    """Torch based GPU capability identified as "SENTENCE_EMBEDDING" and displayed as "Text Embedding"."""

    @staticmethod
    def name():
        return "SENTENCE_EMBEDDING"

    @staticmethod
    def display_name():
        return "Text Embedding"
