import logging

import torch
from transformers import AutoConfig, RobertaTokenizerFast, EncoderDecoderModel

from dataiku.huggingface.pipeline_summarization import ModelPipelineSummarization
from dataiku.huggingface.env_collector import extract_model_architecture


class ModelPipelineSummarizationRoberta(ModelPipelineSummarization):
    """Summarization pipeline backed by a RoBERTa ``EncoderDecoderModel``.

    Loads a ``RobertaTokenizerFast`` tokenizer and an ``EncoderDecoderModel``
    from the same path, places the model on CUDA when available (CPU
    otherwise), and summarizes batches of raw texts with ``model.generate``.
    """

    # Inputs are truncated and padded to exactly this many tokens before encoding.
    _MAX_INPUT_TOKENS = 512

    # Generation options read from ``_run_summarization``'s kwargs.
    _GENERATION_KEYS = ("min_length", "max_length")

    def __init__(self, model_path, model_kwargs, batch_size):
        """Load tokenizer, model and config, then initialise the parent pipeline.

        :param model_path: location of the pretrained encoder-decoder model,
            passed to every ``from_pretrained`` call below.
        :param model_kwargs: optional dict of extra keyword arguments, passed
            only to ``AutoConfig.from_pretrained`` (not to the tokenizer or
            model); ``None`` is treated as no extra arguments.
        :param batch_size: forwarded unchanged to the parent constructor.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info("Using device: {}".format(self.device))
        self.tokenizer = RobertaTokenizerFast.from_pretrained(model_path)
        self.model = EncoderDecoderModel.from_pretrained(model_path).to(self.device)
        kwargs = {} if model_kwargs is None else model_kwargs
        config = AutoConfig.from_pretrained(model_path, **kwargs)
        # Context length comes from the decoder sub-config; None when that
        # sub-config does not define max_position_embeddings.
        context_length = getattr(config.decoder, "max_position_embeddings", None)
        super().__init__(self.tokenizer, context_length, batch_size)
        # NOTE(review): model_tracking_data is not created here, so it is
        # presumably initialised by the parent constructor — confirm.
        self.model_tracking_data["model_architecture"] = extract_model_architecture(config)

    def _run_summarization(self, texts, **kwargs):
        """Summarize a batch of texts.

        :param texts: input strings to summarize (presumably one batch chosen
            by the parent pipeline — TODO confirm).
        :param kwargs: generation options; only ``min_length`` and
            ``max_length`` are read, and each is forwarded to ``generate`` only
            when it is not ``None``.
        :return: list of ``{"summary_text": str}`` dicts, one per generated
            sequence, decoded without special tokens.
        """
        inputs = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self._MAX_INPUT_TOKENS,
            return_tensors="pt",
        )
        input_ids = inputs.input_ids.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)
        # Forward only options that were actually set. With GenerationConfig-based
        # transformers, an explicit None overwrites the model's own generation
        # settings (e.g. drops its max_length) instead of falling back to them.
        generation_kwargs = {
            key: kwargs[key]
            for key in self._GENERATION_KEYS
            if kwargs.get(key) is not None
        }
        outputs = self.model.generate(
            input_ids, attention_mask=attention_mask, **generation_kwargs
        )
        summaries = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [{"summary_text": summary} for summary in summaries]
