import datetime
import dataiku
from dataiku.runnables import Runnable

def get_epochtime_ms():
    """Return the current UTC time as whole milliseconds since 1970-01-01."""
    unix_epoch = datetime.datetime(1970, 1, 1)
    elapsed = datetime.datetime.utcnow() - unix_epoch
    return int(elapsed.total_seconds() * 1000)

class MyRunnable(Runnable):
    """Macro that starts the training queues of ML tasks in a DSS project.

    When the ``runScope`` setting is ``"MLTASK"``, only the queue of the ML
    task identified by the ``analysisId`` / ``mlTaskId`` settings is trained;
    otherwise every queue returned by ``list_mltask_queues()`` is trained.
    With ``simulate`` enabled nothing is trained and the macro only reports
    what would be started. The result is an HTML summary with one link per
    affected ML task.
    """

    def __init__(self, project_key, config, plugin_config):
        """Store the macro settings and open a DSS API client.

        :param project_key: key of the project the macro runs in
        :param config: macro settings; this class reads ``runScope``,
            ``simulate``, ``mlTaskId`` and ``analysisId``
        :param plugin_config: plugin-level settings (not used)
        """
        self.project_key = project_key
        self.config = config
        self.client = dataiku.api_client()
        # One entry per queue that was (or, in dry-run mode, would be) started.
        # Entries are read as dicts with "analysisId", "mlTaskId", "sessionId".
        self.queue_sessions = []

    def get_progress_target(self):
        """Return None: this macro does not report a progress target."""
        return None

    def run(self, progress_callback):
        """Train the selected ML task queues and return an HTML summary.

        :param progress_callback: not used
        :returns: HTML string listing the ML tasks whose queue was (or would
            be) started, or ``"No queues to run."`` if there are none
        """
        run_type = self.config.get("runScope", "")
        dry_run = self.config.get('simulate', False)
        project = self.client.get_project(self.project_key)

        for queue in self._select_queues(project.list_mltask_queues(), run_type):
            loc = queue.get('mlTaskLoc')
            self.train_queue(project, loc.get("analysisId"), loc.get("mlTaskId"), dry_run)

        return self._render_summary(project.list_ml_tasks()["mlTasks"], dry_run)

    def _select_queues(self, all_queues, run_type):
        """Return every queue, or only the configured ML task's queue for "MLTASK"."""
        if run_type != "MLTASK":
            return all_queues
        # Read the target ids once rather than once per queue.
        target_ml_task_id = self.config.get("mlTaskId", "")
        target_analysis_id = self.config.get("analysisId", "")
        return [
            queue for queue in all_queues
            if queue.get('mlTaskLoc').get('mlTaskId') == target_ml_task_id
            and queue.get('mlTaskLoc').get('analysisId') == target_analysis_id
        ]

    def _render_summary(self, ml_tasks, dry_run):
        """Build the HTML report for the ML tasks that have a started session."""
        # (analysisId, mlTaskId) pairs with a session; built once so each ML
        # task is checked by set lookup instead of scanning every session.
        started = {
            (session.get("analysisId"), session.get("mlTaskId"))
            for session in self.queue_sessions
            if session.get("sessionId") is not None
        }

        items = []
        for ml_task in ml_tasks:
            if (ml_task.get("analysisId"), ml_task.get("mlTaskId")) not in started:
                continue
            label = "{task_name} ({analysis_name})".format(
                task_name=ml_task.get("mlTaskName"),
                analysis_name=ml_task.get("analysisName"),
            )
            items.append(
                '<li><a href="/projects/{}/analysis/{}/ml/{}/{}/list/results" target="_blank">{}</a></li>'.format(
                    self.project_key,
                    ml_task.get("analysisId"),
                    "c" if ml_task.get("taskType") == "CLUSTERING" else "p",
                    ml_task.get("mlTaskId"),
                    # Task/analysis names are user-editable: escape them so
                    # they cannot inject markup into the rendered result.
                    self._escape_html(label),
                )
            )

        if not items:
            return "No queues to run."
        description = "Queues in the following tasks "
        description += "will be" if dry_run else "have been"
        return "{} started:<ul>{}</ul>".format(description, "".join(items))

    @staticmethod
    def _escape_html(text):
        """Entity-escape &, <, > and double quotes so ``text`` renders literally."""
        # Hand-rolled rather than html.escape so it also works if the code env
        # is Python 2 (NOTE(review): confirm the target interpreter).
        # '&' must go first so the entities added below are not re-escaped.
        return (text.replace("&", "&amp;")
                    .replace("<", "&lt;")
                    .replace(">", "&gt;")
                    .replace('"', "&quot;"))

    def train_queue(self, project, analysisId, mlTaskId, dry_run):
        """Start (or simulate starting) the training queue of one ML task.

        In dry-run mode no DSS call is made; a placeholder entry with
        ``sessionId`` 0 is recorded instead. Otherwise the ML task's queue is
        trained and the value returned by ``train_queue()`` is recorded when
        it is not None.

        :param project: DSS project handle providing ``get_ml_task``
        :param analysisId: id of the analysis containing the ML task
        :param mlTaskId: id of the ML task whose queue should be trained
        :param dry_run: if true, only record what would be started
        """
        if dry_run:
            self.queue_sessions.append({
                "analysisId": analysisId,
                "mlTaskId": mlTaskId,
                "sessionId": 0
            })
        else:
            ml_task = project.get_ml_task(analysisId, mlTaskId)
            next_session = ml_task.train_queue()
            if next_session is not None:
                self.queue_sessions.append(next_session)