import dataiku
from dataiku.runnables import Runnable, utils

import dku_nvidia.nim_operator as nim_operator
import dku_nvidia.nim_services as nim_services

from dku_nvidia.utils import (
    configure_kubeconfig,
    get_helm_cmd
)
from dku_nvidia.exceptions import (
    ClusterNotRunningError,
    ActionNotDefinedError
)

import os
import time
import shutil
import subprocess
import requests
import json


class MyRunnable(Runnable):
    """
    Dataiku macro that manages NVIDIA components on a Kubernetes cluster.

    Depending on the 'macro_action' parameter it adds, lists or removes the
    GPU/NIM Operators (via dku_nvidia.nim_operator) or NIM Services (via
    dku_nvidia.nim_services), using Helm and kubectl against the cluster
    identified by 'cluster_id'.
    """

    # Values of the 'macro_action' parameter handled by run(); anything else is
    # rejected before the cluster is contacted.
    _SUPPORTED_ACTIONS = frozenset({
        "nim_operator_add",
        "nim_operator_list",
        "nim_operator_rm",
        "nim_services_add",
        "nim_services_list",
        "nim_services_rm",
    })

    def __init__(self, project_key, config, plugin_config):
        """
        :param project_key: the project in which the runnable executes
        :param config: the dict of the configuration of the object
        :param plugin_config: contains the plugin settings
        """
        self.project_key = project_key
        self.config = config
        self.plugin_config = plugin_config

    def get_progress_target(self):
        """
        If the runnable will return some progress info, have this function return a tuple of
        (target, unit) where unit is one of: SIZE, FILES, RECORDS, NONE.

        This macro reports no progress, so it returns None.
        """
        return None

    def run(self, progress_callback):
        """
        Perform the action selected in the 'macro_action' parameter.

        :param progress_callback: progress callback supplied by DSS; unused because
            get_progress_target() returns None
        :returns: str() of the value returned by the selected nim_operator /
            nim_services helper
        :raises ActionNotDefinedError: if 'macro_action' is missing or not a supported action
        :raises ClusterNotRunningError: if `kubectl version` exits with a non-zero status

        TODO: perform validation on the input parameters
        """
        cluster_id = self.config.get("cluster_id", "")
        macro_action = self.config.get("macro_action")

        # Reject a missing/unknown action up front, before touching KUBECONFIG or the
        # cluster, so the user gets this message rather than an unrelated cluster error.
        if macro_action not in self._SUPPORTED_ACTIONS:
            raise ActionNotDefinedError("Macro action not selected; please select a macro action from the dropdown.")

        # Set KUBECONFIG environment variable and retrieve path to Helm command
        configure_kubeconfig(cluster_id)
        helm = get_helm_cmd()

        self._check_cluster_running(cluster_id)

        gpu_operator_namespace = self.config.get("gpu_operator_namespace")
        nim_operator_namespace = self.config.get("nim_operator_namespace")
        nim_services_namespace = self.config.get("nim_services_namespace")

        if macro_action == "nim_operator_add":
            result = nim_operator.add(
                helm,
                gpu_operator_namespace,
                self.config.get("gpu_operator_version"),
                nim_operator_namespace,
                self.config.get("nim_operator_version"),
            )
        elif macro_action == "nim_operator_list":
            # nim_operator.list returns two values; only the second one is reported.
            _, result = nim_operator.list(helm, gpu_operator_namespace, nim_operator_namespace)
        elif macro_action == "nim_operator_rm":
            result = nim_operator.rm(helm, gpu_operator_namespace, nim_operator_namespace)
        elif macro_action == "nim_services_add":
            result = self._add_nim_services(helm, nim_services_namespace)
        elif macro_action == "nim_services_list":
            result = nim_services.list(helm, nim_services_namespace)
        else:  # "nim_services_rm" -- the only remaining supported action
            result = nim_services.rm(helm, nim_services_namespace, self.config.get("nim_service_name"))

        return str(result)

    @staticmethod
    def _check_cluster_running(cluster_id):
        """
        Raise ClusterNotRunningError if `kubectl version` exits with a non-zero status.

        :param cluster_id: only used to build the error message
        """
        try:
            subprocess.run(["kubectl", "version"], capture_output=True, check=True)
        except subprocess.CalledProcessError as err:
            raise ClusterNotRunningError(
                f"The Kubernetes cluster {cluster_id} is unreachable or not running."
            ) from err

    def _add_nim_services(self, helm, nim_services_namespace):
        """
        Read the NIM Services deployment parameters from the macro config and pass
        them to nim_services.add.

        Only called for the 'nim_services_add' action, so these parameters cannot
        break the other actions.

        :returns: whatever nim_services.add returns
        """
        # Nested config sections may be absent or explicitly None; treat both as empty.
        auth = self.config.get("nim_services_authentication") or {}
        environment = self.config.get("nim_environment") or {}

        # Configured as a percentage and passed on as a fraction. Left as None when unset
        # instead of raising TypeError; presumably only relevant with the autoscaler --
        # NOTE(review): confirm how nim_services.add handles None here.
        gpu_cache_usage_perc = self.config.get("gpu_cache_usage_perc")
        if gpu_cache_usage_perc is not None:
            gpu_cache_usage_perc = gpu_cache_usage_perc / 100

        # Positional order must match nim_services.add's signature.
        return nim_services.add(
            helm,
            nim_services_namespace,
            self.config.get("nim_image_tag"),
            self.config.get("storage_class"),
            self.config.get("volume_size"),
            self.config.get("use_autoscaler"),
            self.config.get("replicas"),
            self.config.get("min_replicas"),
            self.config.get("max_replicas"),
            gpu_cache_usage_perc,
            self.config.get("num_gpus"),
            extract_env_vars(environment.get("additional_env_vars") or []),
            self.config.get("node_selector"),
            self.config.get("exposition_mode"),
            self.config.get("ingress_host", ""),
            self.config.get("ingress_path", ""),
            auth.get("nim_container_registry_host"),
            auth.get("nim_container_registry_username"),
            auth.get("nim_container_registry_api_key"),
            auth.get("override_nim_repository", False),
            auth.get("nim_repository_protocol"),
            auth.get("nim_repository_host"),
            auth.get("nim_repository_api_key"),
        )

    
#### ------------------------------------------- ######
####    Parameter parsing helper functions       ######
#### ------------------------------------------- ######
    
def extract_env_vars(raw_additional_env_vars):
    """
    Clean up the raw 'additional_env_vars' parameter into key/value pairs.

    Input is a list of dicts, e.g.
    [{'from': 'my_key', '$touched.it.from': True, 'to': 'my_value', '$touched.it.to': True}, {'$touched.it.from': True}]

    The list can include partial env vars (e.g. a key but no value) and extraneous
    keys such as '$touched.it.*'. Only entries with a non-empty 'from' AND a
    non-empty 'to' are kept; everything else is dropped.

    :param raw_additional_env_vars: the raw list described above; None or an empty
        list yields no env vars
    :returns: a list of dicts, i.e. [{"key": keyval, "value": val}], in input order
    """
    if not raw_additional_env_vars:
        return []

    return [
        {"key": var.get("from"), "value": var.get("to")}
        for var in raw_additional_env_vars
        # Falsy values (missing key, None, "") mark an incomplete row.
        if var.get("from") and var.get("to")
    ]
