/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.promptstudio;

import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.llm.promptstudio.PromptComparisonResponse;
import com.dataiku.dip.llm.promptstudio.PromptExecutionEngine;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.utils.DKULogger;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Runs a list of prompts from one {@link PromptStudio} and collects their responses into a
 * single {@link PromptComparisonResponse}.
 *
 * <p>Every prompt whose mode has {@code canUseDataset} set is forced to run on the same
 * {@link MemTable} of records: the table is obtained once, from the first such prompt, and
 * handed to all the following ones.
 */
public class PromptComparisonEngine {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.promptstudio.engine");

    // Not referenced by the code in this class; presumably populated by the autowire call in
    // the constructor. Kept so the injection contract of the class is unchanged.
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private DatasetsDAO datasetsDAO;

    private final AuthCtx authCtx;
    private final PromptStudio promptStudio;
    private final List<String> promptIds;

    /**
     * Stores the comparison inputs and asks {@link SpringUtils} to autowire this instance.
     *
     * @param authCtx      passed as-is to each {@link PromptExecutionEngine}
     * @param promptStudio the studio whose {@code prompts} list is searched for each id
     * @param promptIds    ids of the prompts to run, in the order they should be executed
     */
    public PromptComparisonEngine(AuthCtx authCtx, PromptStudio promptStudio, List<String> promptIds) {
        this.authCtx = authCtx;
        this.promptStudio = promptStudio;
        this.promptIds = promptIds;
        // The (Object) casts come from the decompiler; they are kept on purpose so overload
        // resolution against classes not visible here stays exactly as it was.
        SpringUtils.getInstance().autowire((Object)this);
        logger.info((Object)"PCE initialized");
    }

    /**
     * Executes every requested prompt in order and aggregates the results.
     *
     * <p>For each prompt a {@link PromptComparisonResponse.PromptComparisonPromptResponse} is
     * appended to {@code promptsResponses}, and the prompt's {@code mainPromptTemplateInputs}
     * are appended to the comparison-level list. If a record table was obtained, one
     * {@link PromptComparisonResponse.PromptComparisonRecord} per row is added to
     * {@code records}.
     *
     * @return the aggregated comparison response
     * @throws IllegalArgumentException if one of the prompt ids is not found in the studio
     * @throws Exception                propagated from the underlying prompt execution
     */
    public PromptComparisonResponse respond() throws Exception {
        PromptComparisonResponse pcr = new PromptComparisonResponse();
        MemTable records = null;
        // Tracks whether the shared table has been requested yet. Previously the table was only
        // fetched when the very first prompt of the list was dataset-capable; if it was not,
        // later dataset-capable prompts were handed a null table and no comparison records
        // were produced.
        boolean recordsFetched = false;
        for (String promptId : this.promptIds) {
            logger.info((Object)("Running comparison on prompt: " + promptId));
            PromptStudio.PromptStudioPrompt prompt = findPrompt(promptId);
            PromptExecutionEngine pee = new PromptExecutionEngine(this.authCtx, this.promptStudio, prompt);
            if (prompt.prompt.promptMode.canUseDataset) {
                if (!recordsFetched) {
                    records = pee.getRecordsForPromptTemplate();
                    recordsFetched = true;
                }
                pee.setForcedRecordsForPromptTemplate(records);
            }
            PromptResponse response = pee.respond();

            PromptComparisonResponse.PromptComparisonPromptResponse pcpr = new PromptComparisonResponse.PromptComparisonPromptResponse();
            pcpr.promptId = promptId;
            pcpr.llmId = prompt.llmId;
            pcpr.responses = response.responses;
            pcpr.stats = response.stats;
            // NOTE(review): inputs of every prompt are appended, so the list can hold the same
            // input several times when prompts share inputs - confirm this is intended.
            pcr.mainPromptTemplateInputs.addAll(response.mainPromptTemplateInputs);
            pcr.promptsResponses.add(pcpr);
        }
        if (records != null) {
            for (MemRow row : records.rows) {
                pcr.records.add(buildRecord(records, row, pcr.mainPromptTemplateInputs));
            }
        }
        return pcr;
    }

    /** Returns the studio prompt whose {@code id} equals {@code promptId}, or throws if none matches. */
    private PromptStudio.PromptStudioPrompt findPrompt(String promptId) {
        return this.promptStudio.prompts.stream()
                .filter(p -> p.id.equals(promptId))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("can't find prompt id"));
    }

    /** Builds one comparison record holding, for each input, the row's value in the column named after it. */
    private static PromptComparisonResponse.PromptComparisonRecord buildRecord(MemTable records, MemRow row, List<PromptStudio.PromptTemplateInput> inputs) {
        PromptComparisonResponse.PromptComparisonRecord record = new PromptComparisonResponse.PromptComparisonRecord();
        for (PromptStudio.PromptTemplateInput input : inputs) {
            MemColumn mc = records.getColumn(input.name);
            record.mainInputs.put(input.name, row.get(mc));
        }
        return record;
    }
}

