from transformers import AutoConfig, pipeline

from dataiku.huggingface.pipeline_summarization import ModelPipelineSummarization
from dataiku.huggingface.env_collector import extract_model_architecture


class ModelPipelineSummarizationGeneric(ModelPipelineSummarization):
    """Summarization pipeline backed by a plain ``transformers`` pipeline.

    The Hugging Face ``pipeline("summarization", ...)`` object is stored on
    ``self.task`` and is what actually produces the summaries.
    """

    def __init__(self, model_path, model_kwargs, batch_size):
        """Load the summarization pipeline and initialise the parent class.

        :param model_path: Model identifier or path, given both to
            ``pipeline(model=...)`` and to ``AutoConfig.from_pretrained``.
        :param model_kwargs: Forwarded untouched as the pipeline's
            ``model_kwargs``.
        :param batch_size: Forwarded untouched to the parent constructor.
        """
        # Automatic device mapping is not supported by either of the two
        # generic summarisation models, so no device map is requested here.
        summarizer = pipeline(
            "summarization",
            model=model_path,
            model_kwargs=model_kwargs,
        )
        # Stored before the parent constructor runs, as in the original order.
        self.task = summarizer

        model_config = AutoConfig.from_pretrained(model_path)

        # Both lookups fall back to None when the attribute is missing.
        super().__init__(
            getattr(summarizer, "tokenizer", None),
            getattr(model_config, "max_position_embeddings", None),
            batch_size,
        )

        # NOTE(review): model_tracking_data is not created in this class;
        # presumably the parent constructor sets it up -- confirm.
        architecture = extract_model_architecture(model_config)
        self.model_tracking_data["model_architecture"] = architecture

    def _run_summarization(self, texts, **kwargs):
        """Run the stored pipeline on ``texts`` and return its raw result.

        Any extra keyword arguments are passed straight to the pipeline call.
        """
        summarize = self.task
        return summarize(texts, **kwargs)
