from transformers import pipeline
from typing import List

from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
from dataiku.huggingface.types import ProcessSinglePromptCommand


class ModelPipelineTextGenerationGeneric(ModelPipelineTextGeneration):
    """Generic text-generation model built on the transformers
    ``"text-generation"`` pipeline.

    The model and its tokenizer are both loaded from the same ``model_path``.
    Inputs handed to the pipeline are the ``"prompt"`` strings of the incoming
    requests.
    """

    def _initialize_pipeline(self, model_path, model_kwargs):
        """Create the transformers pipeline and keep it on ``self.task``.

        :param model_path: model location; also used to load the tokenizer.
        :param model_kwargs: passed through as-is to ``pipeline(model_kwargs=...)``.
        """
        text_generator = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=model_path,
            model_kwargs=model_kwargs,
        )
        self.task = text_generator
        # NOTE(review): `_fixup_tokenizer_if_needed` is not defined in this
        # class, so it presumably comes from the base class; it is given the
        # newly stored pipeline — confirm its exact contract there.
        self._fixup_tokenizer_if_needed(self.task)

    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        """Return the ``"prompt"`` entry of each request, in request order.

        :param requests: prompt commands; each must support ``["prompt"]``.
        :return: list of prompt strings, one per request.
        """
        prompts = []
        for command in requests:
            prompts.append(command["prompt"])
        return prompts
