import pandas as pd
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from dataiku.llm.types import CompletionSettings
from dataikuapi.dss.llm_tracing import SpanBuilder


def mean_or_none(pd_series):
    """Return the mean of *pd_series* ignoring NaN values, or None if undefined.

    The mean is undefined (NA) when the series is empty or every value is
    missing; in that case None is returned instead of NaN.
    """
    average = pd_series.mean(skipna=True)
    if pd.isna(average):
        return None
    return average


def get_llm_args(completion_settings: CompletionSettings):
    """Translate completion settings into snake_case LLM keyword arguments.

    Only the keys present in ``completion_settings`` are carried over. Their
    values are copied unchanged; a key mapped to None is still forwarded as
    None. Supported mappings:

    - ``temperature``     -> ``temperature``
    - ``maxOutputTokens`` -> ``max_tokens``
    - ``topK``            -> ``top_k``
    - ``topP``            -> ``top_p``

    Returns:
        A new dict containing only the translated keys that were present.
    """
    # (settings key, LLM argument name) pairs, in output insertion order.
    renames = (
        ("temperature", "temperature"),
        ("maxOutputTokens", "max_tokens"),
        ("topK", "top_k"),
        ("topP", "top_p"),
    )
    return {
        arg_name: completion_settings[setting_name]
        for setting_name, arg_name in renames
        if setting_name in completion_settings
    }


class CompletionTraceHandler(BaseCallbackHandler):
    """Callback handler that copies an LLM call's trace into a span.

    When an LLM call finishes and its ``llm_output`` dict carries a
    ``"lastTrace"`` entry, that entry is appended to ``span``. Otherwise
    nothing happens.
    """

    span: SpanBuilder

    def __init__(self, span: SpanBuilder):
        # Span that receives the traces of completed LLM calls.
        self.span = span

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Append ``response.llm_output["lastTrace"]`` to the span, if present."""
        output = response.llm_output
        if output is None or "lastTrace" not in output:
            return
        self.span.append_trace(output["lastTrace"])
