from transformers import pipeline
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseZeroShotClassification
from dataiku.huggingface.types import ZeroShotClassificationResponse
from dataiku.huggingface.env_collector import extract_model_architecture


class ModelPipelineZeroShotClassif(ModelPipelineBatching[ProcessSinglePromptCommand, ProcessSinglePromptResponseZeroShotClassification]):
    """Batched zero-shot classification backed by a transformers ``pipeline``.

    Implements the ``ModelPipelineBatching`` hooks defined below: extract the
    prompts, build the per-request pipeline keyword arguments, run the
    pipeline, and wrap each raw result into the response payload.
    """

    def __init__(self, model_path, model_kwargs, batch_size):
        """Load the pipeline and record tracking metadata.

        Args:
            model_path: Passed as ``model`` to ``transformers.pipeline``.
            model_kwargs: Passed as ``model_kwargs`` to ``transformers.pipeline``.
            batch_size: Forwarded to the ``ModelPipelineBatching`` constructor.
        """
        super().__init__(batch_size=batch_size)

        task_name = "zero-shot-classification"
        # device_map="auto" delegates device placement to transformers.
        self.task = pipeline(
            task_name,
            model=model_path,
            model_kwargs=model_kwargs,
            device_map="auto",
        )

        self.model_tracking_data["used_engine"] = "transformers"
        self.model_tracking_data["task"] = task_name
        # The architecture is derived from the config of the loaded model.
        loaded_config = self.task.model.config
        self.model_tracking_data["model_architecture"] = extract_model_architecture(loaded_config)

    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        """Return the ``"prompt"`` entry of every request, in request order."""
        return [command["prompt"] for command in requests]

    def _run_inference(self, input_texts: List[str], tf_params: Dict) -> List:
        """Call the pipeline on ``input_texts`` with ``tf_params`` as keyword arguments."""
        classifier = self.task
        return classifier(input_texts, **tf_params)

    def _get_params(self, request: ProcessSinglePromptCommand) -> Dict:
        """Build the pipeline keyword arguments from the request settings.

        ``candidate_labels`` is always set from ``settings["classLabels"]``.
        ``hypothesis_template`` is only added when
        ``settings["hypothesisTemplate"]`` is truthy; otherwise the argument is
        omitted and the pipeline uses its own default.
        """
        settings = request["settings"]
        tf_params = {"candidate_labels": settings["classLabels"]}

        hypothesis_template = settings["hypothesisTemplate"]
        if not hypothesis_template:
            return tf_params

        tf_params["hypothesis_template"] = hypothesis_template
        return tf_params

    def _parse_response(self, response: ZeroShotClassificationResponse, request: ProcessSinglePromptCommand) -> ProcessSinglePromptResponseZeroShotClassification:
        """Wrap the raw pipeline result under the ``'classification'`` key.

        ``request`` is not used here. According to the HF documentation, each
        result is a dictionary with the following keys:

        * ``sequence`` (str): the sequence for which this is the output.
        * ``labels`` (List[str]): the labels sorted by order of likelihood.
        * ``scores`` (List[float]): the probabilities for each of the labels.
        """
        wrapped = {'classification': response}
        return wrapped
