/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.custom;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.CustomLLMConnection;
import com.dataiku.dip.custom.CustomJavaRuntimeDataProvider;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomJavaLLMDesc;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.custom.CustomLLMClientAdapter;
import com.dataiku.dip.llm.custom.LoadedJavaLLM;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.plugin.tools.PluginClazzLoader;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.google.gson.JsonObject;
import java.util.List;

/**
 * {@link LLMClient} implementation that delegates to a plugin-provided Java {@link CustomLLMClient}.
 *
 * <p>The plugin client is created lazily, on the first call that needs it. It is loaded through the
 * plugin's class loader using the {@code clientClass} declared in the plugin descriptor, then
 * initialized with the expanded plugin settings. Completion, embedding, streaming and image
 * generation calls go through an {@link LLMQueryRunner} built from the client's retry settings.
 *
 * <p>Initialization is guarded by this object's monitor. The client and the query runner are
 * published together, only after every initialization step succeeded. A failed initialization is
 * therefore retried on the next call instead of leaving the wrapper half-initialized.
 */
public class CustomJavaLLMClientWrapper
extends AbstractLLMClient
implements LLMClient {
    protected final LoadedJavaLLM loaded;
    protected final CustomJavaRuntimeDataProvider service;
    private final CustomLLMConnection connection;
    /** Set in {@link #initIfNeeded()} before {@link #client}; valid whenever {@code client != null}. */
    private volatile LLMQueryRunner queryRunner;
    protected final JsonObject config;
    protected final String projectKey;
    protected final AuthCtx authCtx;
    /**
     * The fully initialized plugin client, or {@code null} if initialization has not (yet)
     * succeeded. Volatile so that {@link #close()} sees a client published by another thread.
     */
    private volatile CustomLLMClient client;

    /**
     * Creates a wrapper. No plugin code is loaded or run here; that happens on first use.
     *
     * @param authCtx     authentication context used to expand the plugin settings
     * @param connection  the custom LLM connection this LLM belongs to
     * @param projectKey  project key used to expand the plugin settings
     * @param config      raw plugin configuration to expand
     * @param loaded      descriptor of the loaded Java LLM component
     * @param service     provider of the plugin class loader and expanded settings
     * @param enrichedRef reference of the LLM, passed to {@link AbstractLLMClient}
     */
    public CustomJavaLLMClientWrapper(AuthCtx authCtx, CustomLLMConnection connection, String projectKey, JsonObject config, LoadedJavaLLM loaded, CustomJavaRuntimeDataProvider service, EnrichedLLMStructuredRef enrichedRef) {
        super(enrichedRef);
        this.authCtx = authCtx;
        this.connection = connection;
        this.projectKey = projectKey;
        this.config = config;
        this.loaded = loaded;
        this.service = service;
    }

    /**
     * Builds the provider id used for custom LLMs of the given plugin.
     *
     * @param pluginID id of the owning plugin
     * @return {@code "CustomLLM:" + pluginID}
     */
    public static String getCustomProviderId(String pluginID) {
        return "CustomLLM:" + pluginID;
    }

    /**
     * Loads, initializes and publishes the plugin client and its query runner if this has not
     * already succeeded.
     *
     * <p>Everything is built in local variables. The fields are assigned only once every step
     * succeeded, so an exception leaves the wrapper uninitialized and the next call retries.
     * The previous code left {@code client} set and {@code queryRunner} null after a failure,
     * which made later calls fail with a NullPointerException.
     *
     * @throws RuntimeException wrapping any exception raised during initialization
     */
    private synchronized void initIfNeeded() {
        if (this.client != null) {
            return;
        }
        CustomLLMClient candidate = null;
        try {
            ClassLoader classLoader = this.service.getClassloader(this.loaded.getType());
            candidate = (CustomLLMClient)new PluginClazzLoader(classLoader).loadClazz(((CustomJavaLLMDesc)this.loaded.getDesc()).clientClass);
            PluginSettingsResolver.ResolvedSettings resolvedSettings = this.service.getExpandedPluginSettings(this.loaded.getType(), this.authCtx, this.projectKey, this.config);
            candidate.init(resolvedSettings);
            CustomLLMConnection.LLMModel model = (CustomLLMConnection.LLMModel)this.connection.getLLMModel(this.getEnrichedRef()).getModel();
            // getProviderId() must not call initIfNeeded(); otherwise this would recurse, since
            // client is still null at this point.
            LLMQueryRunner runner = new LLMQueryRunner(this.getProviderId(), this.getEnrichedRef(), model, AbstractLLMConnection.HTTPBasedLLMNetworkSettings.of(candidate.getRetrySettings()), CustomJavaLLMClientWrapper::isRetryableException);
            // Publish the runner first: a non-null client is the "fully initialized" flag.
            this.queryRunner = runner;
            this.client = candidate;
        }
        catch (Exception e) {
            CustomJavaLLMClientWrapper.closeAfterFailedInit(candidate, e);
            throw new RuntimeException("Failed to initialize custom LLM", e);
        }
    }

    /**
     * Best-effort close of a client whose initialization did not complete, so it is not leaked.
     * Any exception thrown while closing is attached to {@code cause} as a suppressed exception.
     */
    private static void closeAfterFailedInit(CustomLLMClient candidate, Exception cause) {
        if (candidate == null) {
            return;
        }
        try {
            candidate.close();
        }
        catch (Exception closeError) {
            cause.addSuppressed(closeError);
        }
    }

    /**
     * Predicate handed to {@link LLMQueryRunner}. Only the plugin's
     * {@link CustomLLMClient.RetryableException} is treated as retryable.
     */
    private static boolean isRetryableException(Throwable t) {
        return t instanceof CustomLLMClient.RetryableException;
    }

    /** Custom Java LLMs never use native batching. */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** Custom Java LLMs always require cost limiting. */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /**
     * Returns {@code "CustomLLM:<owner plugin id>"}.
     *
     * <p>This depends only on the owning plugin id and deliberately does not trigger client
     * initialization; {@link #initIfNeeded()} itself calls this method.
     */
    @Override
    public String getProviderId() {
        return CustomJavaLLMClientWrapper.getCustomProviderId(this.loaded.getOwnerPluginId());
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** Initializes the client if needed and returns its declared maximum parallelism. */
    @Override
    public int getMaxParallelism() {
        this.initIfNeeded();
        return this.client.getMaxParallelism();
    }

    /**
     * Converts the queries to the plugin's legacy format and completes them through the query
     * runner.
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        this.initIfNeeded();
        return this.queryRunner.run(() -> this.client.completeBatch(CustomLLMClientAdapter.toLegacyFormat(queries, settings)));
    }

    /** Embeds the queries through the query runner. */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        this.initIfNeeded();
        return this.queryRunner.run(() -> this.client.embedBatch(queries, settings));
    }

    /**
     * Reranking is not supported.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SimpleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries) throws Exception {
        throw new IllegalArgumentException("Reranking not supported on this LLM");
    }

    /** Initializes the client if needed and reports whether it supports streaming. */
    @Override
    public boolean supportsStream() {
        this.initIfNeeded();
        return this.client.supportsStream();
    }

    /**
     * Streams a single completion to {@code consumer}. The query is converted to the plugin's
     * legacy format, and the call goes through the query runner.
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        this.initIfNeeded();
        this.queryRunner.run(() -> {
            this.client.streamComplete(CustomLLMClientAdapter.toLegacyFormat(query, settings), consumer);
            return null;
        });
    }

    /** Generates images through the query runner. */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        this.initIfNeeded();
        return this.queryRunner.run(() -> this.client.generateImages(query));
    }

    /** Initializes the client if needed and delegates compute-resource-usage reporting to it. */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        this.initIfNeeded();
        return this.client.getTotalCRU(usageType, llmRef);
    }

    /** Returns the messages unchanged; no prompt formatting is applied here. */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        return chatMessages;
    }

    /**
     * Closes the plugin client if it was successfully initialized. Calling this when
     * initialization never happened (or never succeeded) is a no-op.
     */
    @Override
    public void close() throws Exception {
        CustomLLMClient current = this.client;
        if (current != null) {
            current.close();
        }
    }
}

