import asyncio
import functools
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import dataiku
from graphrag.query.llm.base import BaseTextEmbedding


class QueryDataikuEmbeddingLLM(BaseTextEmbedding):
    """GraphRAG text-embedding adapter backed by a Dataiku LLM Mesh model.

    The embedding model handle is resolved once, at construction time, from
    the default project of the Dataiku API client; every ``embed`` call then
    issues a single-text embedding query against that model.
    """

    def __init__(self, embedding_model_id: str):
        """Resolve the Dataiku embedding model to use.

        Args:
            embedding_model_id: Identifier of the embedding model, passed to
                ``project.get_llm`` on the client's default project.
        """
        self.embedding_model_id = embedding_model_id
        self.client = dataiku.api_client()
        self.project = self.client.get_default_project()
        self.emb_model = self.project.get_llm(self.embedding_model_id)
        self.logger = logging.getLogger(__name__)

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed a single text synchronously.

        Args:
            text: The text to embed.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The embedding vector for ``text``, or an empty list when the
            response contains no embeddings.
        """
        # One query object per call, carrying exactly one text.
        emb_query = self.emb_model.new_embeddings()
        emb_query.add_text(text)

        embeddings = emb_query.execute().get_embeddings()
        return embeddings[0] if embeddings else []

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed a single text without blocking the event loop.

        The blocking ``embed`` call is run in the running loop's default
        executor.

        Args:
            text: The text to embed.
            **kwargs: Forwarded to ``embed``.

        Returns:
            The embedding vector for ``text``, or an empty list when the
            response contains no embeddings.
        """
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, so keyword
        # arguments must be bound up front with functools.partial.
        # The loop's shared default executor is used instead of a per-call
        # pool: a `with`-managed pool would call shutdown(wait=True) on
        # cancellation and block the event loop until the request finished.
        return await loop.run_in_executor(
            None, functools.partial(self.embed, text, **kwargs)
        )
