from typing import List
from typing import Dict

import logging
import os

from sentence_transformers import CrossEncoder

import torch

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSingleRerankingCommand, ProcessSingleRerankingResponse

# Module-level logger named after this module; used to report reranking inference failures.
logger = logging.getLogger(__name__)

class ModelPipelineRerankingTransformers(ModelPipelineBatching[ProcessSingleRerankingCommand, ProcessSingleRerankingResponse]):
    """Reranking pipeline backed by a sentence-transformers ``CrossEncoder``.

    Each reranking request carries a query (``queryParts``) and a list of
    candidate ``documents``; only parts whose ``type`` is ``"TEXT"`` are used.
    The response lists the documents ordered by relevance, each with its
    ``index`` and ``relevanceScore``.
    """

    def __init__(self, model_name_or_path: str, trust_remote_code: bool, batch_size: int):
        """Load the cross-encoder and place it on the available device.

        Args:
            model_name_or_path: Model identifier or local path handed to ``CrossEncoder``.
            trust_remote_code: Forwarded to ``CrossEncoder`` as ``trust_remote_code``.
            batch_size: Forwarded to the ``ModelPipelineBatching`` constructor.
        """
        super().__init__(batch_size)
        model = CrossEncoder(
            model_name_or_path,
            automodel_args={"torch_dtype": "auto"},
            trust_remote_code=trust_remote_code,
        )

        # Move the model explicitly to the target device: per the original
        # author's note, skipping this step can produce NaN scores.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        self._model = model

    @staticmethod
    def _join_text_parts(parts) -> str:
        """Join the ``text`` of every ``"TEXT"``-typed part with ``os.linesep``."""
        return os.linesep.join(part["text"] for part in parts if part["type"] == "TEXT")

    def _get_inputs(self, requests: List[ProcessSingleRerankingCommand]) -> List:
        """Return the requests unchanged; they are consumed as-is by ``_run_inference``."""
        return requests

    def _get_params(self, request: ProcessSingleRerankingCommand) -> Dict:
        """Return an empty dict: no per-request inference parameters are used."""
        return {}

    def _run_inference(self, inputs: List[ProcessSingleRerankingCommand], params: Dict) -> List[ProcessSingleRerankingResponse]:
        """Rerank the documents of each request and build one response per request.

        Args:
            inputs: Reranking requests, each with ``queryParts`` and ``documents``.
            params: Unused by this implementation.

        Returns:
            One response per request, in the same order as ``inputs``. A
            successful response has ``ok=True`` and the ranked ``documents``;
            a failed one has ``ok=False``, an ``errorMessage`` and an empty
            ``documents`` list. A failure on one request does not prevent the
            remaining requests from being processed.
        """
        responses: List[ProcessSingleRerankingResponse] = []
        # We do inference one request by one as there can be differences in the scoring results
        for reranking_request in inputs:
            response: ProcessSingleRerankingResponse
            try:
                query_text = self._join_text_parts(reranking_request["queryParts"])
                docs = [self._join_text_parts(doc["parts"]) for doc in reranking_request["documents"]]

                # We rank the paired documents
                rankings = self._model.rank(query_text, docs)

                # rankings is already sorted by relevance score, the first ones being the most relevant
                response = {
                    "ok": True,
                    "documents": [
                        {"index": ranked_doc["corpus_id"], "relevanceScore": ranked_doc["score"]}
                        for ranked_doc in rankings
                    ],
                }
            except Exception as e:
                # logger.exception already attaches the active exception and its
                # traceback. Passing `e` as an extra positional argument without a
                # matching %-placeholder makes the logging module fail while
                # formatting the record, so the message would never be logged.
                logger.exception("Error during reranking inference")
                response = {
                    "ok": False,
                    "errorMessage": str(e),
                    "documents": [],
                }
            responses.append(response)

        return responses

    def _parse_response(self, response: ProcessSingleRerankingResponse, request: ProcessSingleRerankingCommand) -> ProcessSingleRerankingResponse:
        """Return the response unchanged; it is already in its final shape."""
        return response