/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.azuremlgeneric;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AzureLLMConnection;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.online.openai.AzureClientBase;
import com.dataiku.dip.llm.online.openai.OpenAIChatAPI;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import java.io.IOException;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class AzureLLMClient
extends AzureClientBase<AzureLLMConnection.AzureMLModel> {
    private final AzureLLMConnection connection;
    private final RawOpenAIClient raw;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.azure-llm.client");

    public AzureLLMClient(AuthCtx authCtx, AzureLLMConnection connection, LLMModelHandle<AzureLLMConnection.AzureMLModel> modelHandle, @Nullable String projectKey) throws IOException, DKUSecurityException {
        super(modelHandle.getEnrichedRef(), modelHandle.getModel(), connection.params.networkSettings);
        String key;
        this.connection = connection;
        Optional<String> invalidityReason = ((AzureLLMConnection.AzureMLModel)this.model).getInvalidityReason();
        if (invalidityReason.isPresent()) {
            throw new IllegalArgumentException("Model parameters are invalid: " + invalidityReason.get());
        }
        if (StringUtils.isNotBlank((String)((AzureLLMConnection.AzureMLModel)this.model).key)) {
            logger.debugV("Using endpoint specific key for endpoint with id %s and targetURI %s", new Object[]{((AzureLLMConnection.AzureMLModel)this.model).id, ((AzureLLMConnection.AzureMLModel)this.model).targetURI});
            key = ((AzureLLMConnection.AzureMLModel)this.model).key;
        } else {
            logger.debugV("Using default key for endpoint with id %s and targetURI %s", new Object[]{((AzureLLMConnection.AzureMLModel)this.model).id, ((AzureLLMConnection.AzureMLModel)this.model).targetURI});
            key = connection.params.defaultKey;
        }
        boolean trustAllSSLCertificates = ConnectionUtils.getParamsFromProperties(connection.getDkuProperties()).getBoolParam("dku.connection.llm.trustAllSSLCertificates", false);
        this.raw = RawOpenAIClient.forAzureLLMWithAPIKey(((AzureLLMConnection.AzureMLModel)this.model).targetURI, key, ((AzureLLMConnection.AzureMLModel)this.model).customHeaders, projectKey, this.queryRunner.getHttpClientNetworkSettings(), connection.getProxySettings(), trustAllSSLCertificates, authCtx, connection);
    }

    @Override
    public RawOpenAIClient getRaw() {
        return this.raw;
    }

    @Override
    public boolean isChatModel() {
        return ((AzureLLMConnection.AzureMLModel)this.model).modelType.useChatAPI;
    }

    @Override
    public String getLogPrefix() {
        return "AzureLLM";
    }

    @Override
    public String getAzureModelId() {
        return ((AzureLLMConnection.AzureMLModel)this.model).getId();
    }

    @Override
    public OpenAIImageHandling getImageHandlingMode() {
        throw new Error("Image generation not supported for Azure LLM models");
    }

    @Override
    public OpenAIChatAPI getAPI() {
        return OpenAIChatAPI.CHAT_COMPLETIONS;
    }

    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    @Override
    public String getProviderId() {
        return "AzureLLM";
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }
}

