/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.ml.llm.LLMStepwiseTrainingMetrics;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.AzureOpenAIConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIChatCompletionSettingsValidator;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.OpenAITextCompletionSettingsValidator;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatChunkResponse;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatChunkResponseAdapter;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQuery;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQueryAdapter;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatResponse;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatResponseAdapter;
import com.dataiku.dip.llm.online.openai.api.OpenAICompletionQuery;
import com.dataiku.dip.llm.online.openai.api.OpenAICompletionQueryAdapter;
import com.dataiku.dip.llm.online.openai.api.OpenAICompletionResponse;
import com.dataiku.dip.llm.online.openai.api.OpenAICompletionResponseAdapter;
import com.dataiku.dip.llm.utils.ImageGenerationUtils;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class RawOpenAIClient {
    // Base URL used by forOpenAI when no (non-blank) custom URL is given.
    private static final String OPENAI_ENDPOINT_BASE = "https://api.openai.com/v1";
    // Network settings; queryTimeoutMS is passed to most POST calls, and the retry strategies are derived from it in the constructor.
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    // API flavour selected by the static factory; drives endpoint selection and mode-specific payload tweaks.
    private OpenAIMode mode;
    // api-version query parameter for Azure OpenAI completion, chat, embedding, file and fine-tuning requests.
    private String azureAPIVersion;
    // api-version query parameter used only for Azure OpenAI image generation.
    private String azureImageAPIVersion;
    // Fallback used by shouldUseMaxCompletionTokens when the model is not an AzureOpenAIModel
    // (set from the caller in forOpenAI, forced to true by the Azure OpenAI factories, left false for AZURE_LLM).
    private boolean useMaxCompletionToken;
    // Underlying JSON/HTTP client; auth and custom headers are added to it in the constructor.
    ExternalJSONAPIClient client;
    // Validators applied to the completion settings before every text / chat completion request.
    private static final OpenAITextCompletionSettingsValidator textCompletionValidator = new OpenAITextCompletionSettingsValidator();
    private static final OpenAIChatCompletionSettingsValidator chatCompletionValidator = new OpenAIChatCompletionSettingsValidator();
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.openai.client");

    /**
     * Creates a client for the OpenAI API, or for an OpenAI-compatible server when {@code url} is set.
     *
     * @param url optional base URL; {@code https://api.openai.com/v1} is used when it is blank
     * @param apiKey sent as a bearer token in the {@code Authorization} header
     * @param organizationId sent as {@code OpenAI-Organization} when non-blank
     * @param customHeaders extra headers; names and values are variable-expanded (see addCustomHeaders)
     * @param projectKey project whose variables are used for expansion; global variables when empty
     * @param useMaxCompletionToken default for chat queries whose model is not an Azure OpenAI model
     * @return a client in {@link OpenAIMode#OPENAI} mode
     */
    public static RawOpenAIClient forOpenAI(@Nullable String url, String apiKey, String organizationId, List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders, @Nullable String projectKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean trustAllSSLCertificates, boolean useMaxCompletionToken) {
        Map<String, String> requestHeaders = new HashMap<>();
        requestHeaders.put("Authorization", "Bearer " + apiKey);
        if (!StringUtils.isBlank(organizationId)) {
            requestHeaders.put("OpenAI-Organization", organizationId);
        }
        addCustomHeaders(requestHeaders, customHeaders, projectKey);
        String endpointBase = StringUtils.isBlank(url) ? OPENAI_ENDPOINT_BASE : url;
        RawOpenAIClient instance = new RawOpenAIClient(endpointBase, requestHeaders, networkSettings, proxySettings, trustAllSSLCertificates);
        instance.mode = OpenAIMode.OPENAI;
        instance.useMaxCompletionToken = useMaxCompletionToken;
        return instance;
    }

    /**
     * Creates an Azure OpenAI client authenticated with an {@code api-key} header.
     *
     * @param azureResourceNameOrURL either a full http(s) URL, or a resource name that is expanded to
     *        {@code https://<name>.openai.azure.com/openai}
     * @param azureAPIKey value of the {@code api-key} header
     * @param customHeaders extra headers; names and values are variable-expanded
     * @param projectKey project whose variables are used for expansion; global variables when empty
     * @return a client in {@link OpenAIMode#AZURE_OPENAI} mode that always uses max_completion_tokens by default
     */
    public static RawOpenAIClient forAzureWithAPIKey(String azureResourceNameOrURL, String azureAPIKey, List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders, @Nullable String projectKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean trustAllSSLCertificates) {
        String baseUrl = isURL(azureResourceNameOrURL)
                ? azureResourceNameOrURL
                : "https://" + azureResourceNameOrURL + ".openai.azure.com/openai";
        Map<String, String> requestHeaders = new HashMap<>();
        requestHeaders.put("api-key", azureAPIKey);
        addCustomHeaders(requestHeaders, customHeaders, projectKey);
        RawOpenAIClient instance = new RawOpenAIClient(baseUrl, requestHeaders, networkSettings, proxySettings, trustAllSSLCertificates);
        instance.azureAPIVersion = "2024-12-01-preview";
        instance.azureImageAPIVersion = "2024-02-01";
        instance.mode = OpenAIMode.AZURE_OPENAI;
        instance.useMaxCompletionToken = true;
        return instance;
    }

    /**
     * Creates an Azure OpenAI client authenticated with an OAuth bearer token.
     *
     * @param azureResourceNameOrURL either a full http(s) URL, or a resource name that is expanded to
     *        {@code https://<name>.openai.azure.com/openai}
     * @param bearerToken sent as {@code Authorization: Bearer <token>}
     * @param customHeaders extra headers; names and values are variable-expanded
     * @param projectKey project whose variables are used for expansion; global variables when empty
     * @return a client in {@link OpenAIMode#AZURE_OPENAI} mode that always uses max_completion_tokens by default
     */
    public static RawOpenAIClient forAzureWithOAuthToken(String azureResourceNameOrURL, String bearerToken, List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders, @Nullable String projectKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean trustAllSSLCertificates) {
        String baseUrl = isURL(azureResourceNameOrURL)
                ? azureResourceNameOrURL
                : "https://" + azureResourceNameOrURL + ".openai.azure.com/openai";
        Map<String, String> requestHeaders = new HashMap<>();
        requestHeaders.put("Authorization", "Bearer " + bearerToken);
        addCustomHeaders(requestHeaders, customHeaders, projectKey);
        RawOpenAIClient instance = new RawOpenAIClient(baseUrl, requestHeaders, networkSettings, proxySettings, trustAllSSLCertificates);
        instance.azureAPIVersion = "2024-12-01-preview";
        instance.azureImageAPIVersion = "2024-02-01";
        instance.mode = OpenAIMode.AZURE_OPENAI;
        instance.useMaxCompletionToken = true;
        return instance;
    }

    /**
     * Creates a client in {@link OpenAIMode#AZURE_LLM} mode for the given base URL.
     *
     * <p>When {@code azureAPIKey} is non-blank it is sent as a bearer token; otherwise no
     * {@code Authorization} header is added. Chunked transfer encoding is enabled on the
     * underlying client.
     *
     * @param url base URL of the endpoint
     * @param azureAPIKey optional key, may be blank
     * @param customHeaders extra headers; names and values are variable-expanded
     * @param projectKey project whose variables are used for expansion; global variables when empty
     */
    public static RawOpenAIClient forAzureLLMWithAPIKey(String url, String azureAPIKey, List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders, @Nullable String projectKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean trustAllSSLCertificates) {
        Map<String, String> requestHeaders = new HashMap<>();
        if (StringUtils.isBlank(azureAPIKey)) {
            logger.debugV("No key specified, not adding Authorization header", new Object[0]);
        } else {
            requestHeaders.put("Authorization", "Bearer " + azureAPIKey);
        }
        addCustomHeaders(requestHeaders, customHeaders, projectKey);
        RawOpenAIClient instance = new RawOpenAIClient(url, requestHeaders, networkSettings, proxySettings, trustAllSSLCertificates);
        instance.mode = OpenAIMode.AZURE_LLM;
        instance.client.allowChunkedEncoding = true;
        return instance;
    }

    /**
     * Adds user-configured headers to {@code headers}, expanding DSS variables in names and values.
     *
     * <p>Variables come from the project's context when {@code projectKey} is non-empty, otherwise from
     * the global context; unresolved variables are left as-is. After expansion, every character of the
     * header name that is not a word character or '-' is stripped. Later entries overwrite earlier
     * ones with the same (sanitized) name.
     *
     * @param headers map to add to (mutated)
     * @param customHeaders configured name/value pairs
     * @param projectKey optional project key selecting the variables context
     */
    private static void addCustomHeaders(Map<String, String> headers, List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders, @Nullable String projectKey) {
        VariablesService variables = (VariablesService)SpringUtils.getBean(VariablesService.class);
        VariablesContext context;
        if (StringUtils.isEmpty(projectKey)) {
            context = variables.getGlobalContext();
        } else {
            context = variables.getContext(projectKey);
        }
        customHeaders.forEach(property -> {
            String expandedName = context.expandAllowUnresolved(property.name);
            String expandedValue = context.expandAllowUnresolved(property.value);
            // '+' strips the same runs as the historical '*' pattern, without the zero-length matches.
            headers.put(expandedName.replaceAll("[^\\w-]+", ""), expandedValue);
        });
    }

    /**
     * Tells whether the given Azure setting is already a full URL rather than a bare resource name.
     *
     * <p>The scheme is matched case-insensitively, since URI schemes are case-insensitive; an
     * input such as {@code HTTPS://host} is therefore treated as a URL instead of being embedded in
     * {@code https://<name>.openai.azure.com/openai}.
     *
     * @param nameOrURL resource name or URL (must not be null)
     * @return true if it starts with {@code http://} or {@code https://}, ignoring case
     */
    private static boolean isURL(String nameOrURL) {
        return nameOrURL.regionMatches(true, 0, "http://", 0, "http://".length())
                || nameOrURL.regionMatches(true, 0, "https://", 0, "https://".length());
    }

    private RawOpenAIClient(String endpointBase, Map<String, String> headers, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean trustAllSSLCertificates) {
        this.networkSettings = networkSettings;
        this.client = new ExternalJSONAPIClient(endpointBase, null, trustAllSSLCertificates, proxySettings, OnlineLLMUtils.getLLMResponseRetryStrategy(networkSettings), builder -> OnlineLLMUtils.add429RetryStrategy(builder, networkSettings)){

            protected HttpPost newPost(String path) throws Exception {
                HttpPost post = super.newPost(path);
                post.addHeader("Content-Type", "application/json");
                return post;
            }
        };
        for (Map.Entry<String, String> entry : headers.entrySet()) {
            this.client.addHeader(entry.getKey(), entry.getValue());
        }
    }

    /** Closes the underlying {@link ExternalJSONAPIClient}. */
    public void close() {
        ExternalJSONAPIClient httpClient = this.client;
        httpClient.close();
    }

    /**
     * Returns the request path for legacy text completions in the current mode.
     *
     * @param model only used in AZURE_OPENAI mode, where it is the deployment name in the path
     * @throws IllegalStateException if the mode is none of OPENAI, AZURE_OPENAI, AZURE_LLM
     */
    private String getCompletionEndpoint(String model) {
        switch (this.mode) {
            case OPENAI:
                return "/completions";
            case AZURE_OPENAI:
                return "deployments/" + model + "/completions?api-version=" + this.azureAPIVersion;
            case AZURE_LLM:
                return "/v1/completions";
            default:
                throw new IllegalStateException("Unexpected mode: " + this.mode);
        }
    }

    /**
     * Returns the request path for chat completions in the current mode.
     *
     * @param model only used in AZURE_OPENAI mode, where it is the deployment name in the path
     * @throws IllegalStateException if the mode is none of OPENAI, AZURE_OPENAI, AZURE_LLM
     */
    private String getChatCompleteEndpoint(String model) {
        switch (this.mode) {
            case OPENAI:
                return "/chat/completions";
            case AZURE_OPENAI:
                return "deployments/" + model + "/chat/completions?api-version=" + this.azureAPIVersion;
            case AZURE_LLM:
                return "/v1/chat/completions";
            default:
                throw new IllegalStateException("Unexpected mode: " + this.mode);
        }
    }

    /**
     * Runs a (non-chat) text completion.
     *
     * <p>Settings are validated for the current mode first; endpoint, query and raw response are
     * logged at TRACE.
     *
     * @param model model id (deployment name in AZURE_OPENAI mode)
     * @param prompt prompt text
     * @param ccs completion settings
     * @return the adapted completion response
     * @throws IOException if the HTTP call fails
     */
    public LLMClient.SimpleCompletionResponse complete(String model, String prompt, CoreCompletionSettings ccs) throws IOException {
        textCompletionValidator.validate(this.mode, ccs);
        String path = getCompletionEndpoint(model);
        OpenAICompletionQuery completionQuery = OpenAICompletionQueryAdapter.adapt(this.mode, model, prompt, ccs);
        logger.trace(() -> String.format("OpenAI completion endpoint (mode: %s): %s", mode.name(), path));
        logger.trace(() -> String.format("OpenAI raw completion query (mode: %s): %s", mode.name(), JSON.pretty((Object)completionQuery)));
        OpenAICompletionResponse rawResponse = (OpenAICompletionResponse)client.postObjectToJSON(path, networkSettings.queryTimeoutMS, OpenAICompletionResponse.class, (Object)completionQuery);
        logger.trace(() -> String.format("OpenAI raw completion response (mode: %s): %s", mode.name(), JSON.pretty((Object)rawResponse)));
        return OpenAICompletionResponseAdapter.adapt(rawResponse);
    }

    /**
     * Decides whether the chat query should use {@code max_completion_tokens}.
     *
     * <p>For Azure OpenAI models the model's own max-tokens API mode decides (MODERN means yes);
     * for any other model (including null) the client-wide default is returned.
     */
    private boolean shouldUseMaxCompletionTokens(LLMModelHandle.Model model) {
        return model instanceof AzureOpenAIConnection.AzureOpenAIModel azureModel
                ? azureModel.maxTokensAPIMode == AzureOpenAIConnection.AzureOpenAIMaxTokensAPIMode.MODERN
                : this.useMaxCompletionToken;
    }

    /**
     * Runs a blocking chat completion.
     *
     * @param model model id (deployment name in AZURE_OPENAI mode)
     * @param messages conversation to send
     * @param ccs completion settings, validated for the current mode first
     * @param modelForCostEstimation used to pick the max-tokens parameter and to estimate the cost of a refusal
     * @return the adapted response when the model did not refuse
     * @throws LLMClient.RefusalException if the response carries a refusal; token counts and estimated
     *         cost are filled in when the response has usage data
     * @throws IOException if the HTTP call fails
     */
    public LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs, LLMModelHandle.Model modelForCostEstimation) throws IOException {
        chatCompletionValidator.validate(this.mode, ccs);
        String path = getChatCompleteEndpoint(model);
        OpenAIChatQuery chatQuery = OpenAIChatQueryAdapter.adapt(this.mode, model, messages, ccs, shouldUseMaxCompletionTokens(modelForCostEstimation));
        logger.trace(() -> String.format("OpenAI chat completion endpoint (mode: %s): %s", mode.name(), path));
        logger.trace(() -> String.format("OpenAI raw chat completion query (mode: %s): %s", mode.name(), JSON.pretty((Object)chatQuery)));
        OpenAIChatResponse chatResponse = (OpenAIChatResponse)client.postObjectToJSON(path, networkSettings.queryTimeoutMS, OpenAIChatResponse.class, (Object)chatQuery);
        logger.trace(() -> String.format("OpenAI raw chat completion response (mode: %s): %s", mode.name(), JSON.pretty((Object)chatResponse)));
        String refusal = OpenAIChatResponseAdapter.getRefusal(chatResponse);
        if (refusal == null) {
            return OpenAIChatResponseAdapter.adapt(chatResponse);
        }
        LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusal);
        if (chatResponse.usage != null) {
            refusalException.completionTokens = chatResponse.usage.completionTokens;
            refusalException.promptTokens = chatResponse.usage.promptTokens;
            refusalException.totalTokens = chatResponse.usage.totalTokens;
            refusalException.estimatedCost = modelForCostEstimation.getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
        }
        throw refusalException;
    }

    /**
     * Runs a streaming chat completion, forwarding chunks to {@code consumer} as server-sent events arrive.
     *
     * <p>Calls {@code consumer.onStreamStarted()} once the request is sent, then
     * {@code onStreamChunk} for every non-empty content chunk, and finally {@code onStreamComplete} with a
     * footer carrying the last usage block seen (tokens and estimated cost) and the last finish reason.
     * The stream ends on end-of-input, an event without data, or the {@code [DONE]} marker. Chunks without
     * choices (for instance usage-only chunks) only contribute usage.
     *
     * <p>If any chunk carries a refusal, no further content chunks are forwarded, and a
     * {@link LLMClient.RefusalException} with the concatenated refusal text is thrown instead of calling
     * {@code onStreamComplete}.
     *
     * @param consumer receiver of stream callbacks
     * @param model model handle; its id selects the endpoint/model and it is used for cost estimation
     * @param messages conversation to send
     * @param ccs completion settings, validated for the current mode first
     * @throws Exception on HTTP, parsing or consumer failures
     */
    public void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, LLMModelHandle.Model model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws Exception {
        chatCompletionValidator.validate(this.mode, ccs);
        String endpoint = this.getChatCompleteEndpoint(model.getId());
        boolean shouldUseMaxCompletionTokens = this.shouldUseMaxCompletionTokens(model);
        OpenAIChatQuery query = OpenAIChatQueryAdapter.adaptForStreaming(this.mode, model.getId(), messages, ccs, shouldUseMaxCompletionTokens);
        logger.trace(() -> String.format("OpenAI chat completion endpoint (mode: %s): %s", this.mode.name(), endpoint));
        logger.trace(() -> String.format("OpenAI raw chat completion streaming query (mode: %s): %s", this.mode.name(), JSON.pretty((Object)query)));
        ExternalJSONAPIClient.EntityAndRequest ear = this.client.postJSONToStreamAndRequest(endpoint, this.networkSettings.queryTimeoutMS, (Object)query);
        // NOTE(review): the entity content stream is never explicitly closed in this method; presumably the
        // connection is released when the stream is exhausted or when the client is closed -- confirm.
        SSEDecoder decoder = new SSEDecoder(ear.entity.getContent());
        consumer.onStreamStarted();
        OpenAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;
        StringBuilder refusalBuilder = null;
        while (true) {
            SSEDecoder.HTTPSSEEvent event = decoder.next();
            if (logger.isTraceEnabled()) {
                logger.trace((Object)("Received raw event from OpenAI: " + JSON.json((Object)event)));
            }
            if (event == null || event.data == null) {
                logger.info((Object)"End of OpenAI stream");
                break;
            }
            if (event.data.equals("[DONE]")) {
                logger.info((Object)"Received explicit end marker from OpenAI stream");
                break;
            }
            OpenAIChatChunkResponse response = (OpenAIChatChunkResponse)JSON.parse((String)event.data, OpenAIChatChunkResponse.class);
            logger.trace(() -> String.format("OpenAI raw streamed chat completion response chunk (mode: %s): %s", this.mode.name(), JSON.pretty((Object)response)));
            if (response.usage != null) {
                usage = response.usage;
            }
            // Some chunks (usage-only, content-filter results) carry no choices; the field may even be
            // absent, in which case the parser leaves it null. Neither case has content to process.
            if (response.choices == null || response.choices.isEmpty()) {
                continue;
            }
            LLMClient.FinishReason reason = OpenAIChatChunkResponseAdapter.extractFinishReason(response);
            if (reason != null) {
                finishReason = reason;
            }
            String refusalChunk = OpenAIChatChunkResponseAdapter.getRefusal(response);
            if (refusalChunk != null) {
                if (refusalBuilder == null) {
                    refusalBuilder = new StringBuilder();
                }
                refusalBuilder.append(refusalChunk);
            }
            if (refusalBuilder != null) {
                // Once a refusal has started, content is no longer forwarded to the consumer.
                continue;
            }
            LLMClient.StreamedCompletionResponseChunk chunk = OpenAIChatChunkResponseAdapter.adapt(response);
            if (!chunk.isEmpty()) {
                consumer.onStreamChunk(chunk);
            }
        }
        if (refusalBuilder != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusalBuilder.toString());
            if (usage != null) {
                refusalException.completionTokens = usage.completionTokens;
                refusalException.promptTokens = usage.promptTokens;
                refusalException.totalTokens = usage.totalTokens;
                refusalException.estimatedCost = model.getEstimatedCompletionCost(refusalException.promptTokens, refusalException.completionTokens);
            }
            throw refusalException;
        }
        LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
            footer.estimatedCost = model.getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
        }
        if (finishReason != null) {
            footer.finishReason = finishReason;
        }
        consumer.onStreamComplete(footer);
    }

    /**
     * Generates images from the query's concatenated prompts.
     *
     * <p>Negative prompts, fidelity and seed are rejected, and DALL-E 3 is limited to one image. Quality
     * is mapped per image-handling mode (gpt-image-1: high/medium/low; DALL-E 3: hd/standard) and
     * ignored with a warning otherwise. Style is only forwarded for DALL-E 3. The size string comes from
     * {@code ImageGenerationUtils.getWxHString} with 1024/1024 as its fallback dimensions. DALL-E 3 is
     * explicitly asked for base64 output.
     *
     * @param imageHandlingMode model family, used to pick supported parameters
     * @param model model id; sent in the body in OPENAI mode and used as deployment name in AZURE_OPENAI mode
     * @param query image generation query
     * @return one base64 image per returned item
     * @throws IllegalArgumentException if the query uses an option that is not supported
     * @throws IOException if the call fails, no images are returned, or an item has no base64 data
     */
    public LLMClient.ImageGenerationResponse generateImage(OpenAIImageHandling imageHandlingMode, String model, LLMClient.ImageGenerationQuery query) throws IOException {
        String prompt = query.getConcatenatedPrompts();
        String size = ImageGenerationUtils.getWxHString(query, 1024, 1024);
        if (OpenAIImageHandling.DALL_E_3 == imageHandlingMode && query.nbImagesToGenerate != null && query.nbImagesToGenerate > 1) {
            throw new IllegalArgumentException("Dall-E 3 can only generate one image at a time");
        }
        if (!query.negativePrompts.isEmpty()) {
            throw new IllegalArgumentException("Negative prompts are not supported by OpenAI models");
        }
        if (query.fidelity != null) {
            throw new IllegalArgumentException("Fidelity parameter is not supported by OpenAI models");
        }
        if (query.seed != null) {
            throw new IllegalArgumentException("Seed parameter is not supported by OpenAI models");
        }
        JF.ObjectBuilder ob = JF.obj();
        if (OpenAIMode.OPENAI == this.mode) {
            // Only sent in OPENAI mode; in AZURE_OPENAI mode the model is addressed through the deployment path.
            ob.with("model", model);
        }
        ob.with("prompt", prompt);
        if (query.nbImagesToGenerate != null) {
            ob.with("n", (Number)query.nbImagesToGenerate);
        }
        if (query.quality != null) {
            if (OpenAIImageHandling.GPT_IMAGE_1 == imageHandlingMode) {
                if (ImageGenerationUtils.isHighQualitySynonym(query.quality)) {
                    ob.with("quality", "high");
                } else if (ImageGenerationUtils.isStandardQualitySynonym(query.quality)) {
                    ob.with("quality", "medium");
                } else if (ImageGenerationUtils.isLowQualitySynonym(query.quality)) {
                    ob.with("quality", "low");
                } else {
                    logger.warn((Object)("Invalid quality setting for gpt-image-1: " + query.quality));
                }
            } else if (OpenAIImageHandling.DALL_E_3 == imageHandlingMode) {
                if (ImageGenerationUtils.isHighQualitySynonym(query.quality)) {
                    ob.with("quality", "hd");
                } else if (ImageGenerationUtils.isStandardQualitySynonym(query.quality)) {
                    ob.with("quality", "standard");
                } else {
                    logger.warn((Object)("Invalid quality setting for DALL-E-3: " + query.quality));
                }
            } else {
                logger.warn((Object)("Ignoring quality setting for " + imageHandlingMode.name()));
            }
        }
        ob.with("size", size);
        if (query.style != null) {
            if (OpenAIImageHandling.DALL_E_3 == imageHandlingMode) {
                ob.with("style", query.style);
            } else {
                logger.warn((Object)"Style parameter is unsupported by the current model and has been ignored. This parameter is only available for DALL-E 3.");
            }
        }
        if (OpenAIImageHandling.DALL_E_3 == imageHandlingMode) {
            ob.with("response_format", "b64_json");
        }
        // The payload contains the user's prompt: log it at TRACE like the other raw-query logs of this class.
        logger.trace(() -> "Raw OpenAI image generation: " + JSON.pretty((Object)ob.get()));
        String endpoint = this.getImageGenerationEndpoint(model);
        RawImageGenerationResponse rcr = (RawImageGenerationResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, RawImageGenerationResponse.class, (Object)ob.get());
        if (rcr.data == null || rcr.data.isEmpty()) {
            throw new IOException("OpenAI did not respond with valid data");
        }
        // An item without base64 data (e.g. a URL-only answer) cannot be turned into an image; fail
        // explicitly instead of returning images with null content.
        if (rcr.data.stream().anyMatch(i -> i.b64_json == null)) {
            throw new IOException("OpenAI did not return base64-encoded image data");
        }
        LLMClient.ImageGenerationResponse ret = new LLMClient.ImageGenerationResponse();
        ret.images = rcr.data.stream().map(i -> i.b64_json).map(LLMClient.ImageGenerationImage::new).collect(Collectors.toList());
        return ret;
    }

    /**
     * Returns the request path for image generation in the current mode.
     *
     * <p>AZURE_OPENAI uses the dedicated image api-version rather than the general one.
     *
     * @param model only used in AZURE_OPENAI mode, where it is the deployment name in the path
     * @throws IllegalStateException if the mode is none of OPENAI, AZURE_OPENAI, AZURE_LLM
     */
    private String getImageGenerationEndpoint(String model) {
        switch (this.mode) {
            case OPENAI:
                return "/images/generations";
            case AZURE_OPENAI:
                return "deployments/" + model + "/images/generations?api-version=" + this.azureImageAPIVersion;
            case AZURE_LLM:
                return "/v1/images/generations";
            default:
                throw new IllegalStateException("Unexpected mode: " + this.mode);
        }
    }

    /**
     * Classifies {@code text} with the {@code omni-moderation-latest} model.
     *
     * <p>Only meant for OPENAI mode (checked by an assert only, so not enforced when assertions are off).
     *
     * @param text input to moderate
     * @return whether the text was flagged and, if so, the names of the categories set to true
     * @throws IOException if the call fails or the response does not contain exactly one result
     */
    public SimpleModerationResponse moderate(String text) throws IOException {
        assert (this.mode == OpenAIMode.OPENAI);
        String endpoint = "/moderations";
        JF.ObjectBuilder ob = JF.obj().with("model", "omni-moderation-latest").with("input", text);
        OpenAIModerationResponse rcr = (OpenAIModerationResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, OpenAIModerationResponse.class, (Object)ob.get());
        logger.info((Object)("raw moderation response: " + JSON.json((Object)rcr)));
        // A missing results list is as invalid as a wrong-sized one: report it the same way instead of NPE-ing.
        if (rcr.results == null || rcr.results.size() != 1) {
            throw new IOException("OpenAI did not respond with valid moderation");
        }
        var result = rcr.results.get(0);
        SimpleModerationResponse ret = new SimpleModerationResponse();
        ret.flagged = result.flagged;
        if (ret.flagged && result.categories != null) {
            // Boolean.TRUE.equals tolerates null category values instead of failing on unboxing.
            ret.flaggedCategories = result.categories.entrySet().stream().filter(e -> Boolean.TRUE.equals(e.getValue())).map(e -> (String)e.getKey()).collect(Collectors.toList());
        }
        return ret;
    }

    /**
     * Embeds a batch of texts in a single request.
     *
     * <p>The response's total token usage is split evenly (integer division) across the batch. When the
     * server returns no usage block, per-text prompt tokens are left unset rather than failing.
     *
     * @param model model id (deployment name in AZURE_OPENAI mode)
     * @param batchTexts texts to embed
     * @return one response per returned embedding, in response order
     * @throws IOException if the call fails or the number of embeddings differs from the number of texts
     */
    public List<LLMClient.SimpleEmbeddingResponse> embed(String model, List<String> batchTexts) throws IOException {
        String endpoint = this.getEmbeddingEndpoint(model);
        JF.ObjectBuilder ob = JF.obj().with("input", batchTexts).with("model", model);
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw OpenAI embedding query: " + JSON.json((Object)ob.get())));
        }
        OpenAIEmbeddingResponse rcr = (OpenAIEmbeddingResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, OpenAIEmbeddingResponse.class, (Object)ob.get());
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw OpenAI embedding response: " + JSON.json((Object)rcr)));
        }
        if (rcr.data == null || rcr.data.size() != batchTexts.size()) {
            throw new IOException("OpenAI did not respond with valid embeddings");
        }
        ArrayList<LLMClient.SimpleEmbeddingResponse> batchResponses = new ArrayList<LLMClient.SimpleEmbeddingResponse>(rcr.data.size());
        for (OpenAIEmbeddingResult singleResult : rcr.data) {
            LLMClient.SimpleEmbeddingResponse singleResponse = new LLMClient.SimpleEmbeddingResponse();
            singleResponse.embedding = singleResult.embedding;
            // Some OpenAI-compatible servers omit the usage block.
            if (rcr.usage != null) {
                singleResponse.promptTokens = rcr.usage.total_tokens / batchTexts.size();
            }
            batchResponses.add(singleResponse);
        }
        return batchResponses;
    }

    /**
     * Returns the request path for embeddings in the current mode.
     *
     * @param model only used in AZURE_OPENAI mode, where it is the deployment name in the path
     * @throws IllegalStateException if the mode is none of OPENAI, AZURE_OPENAI, AZURE_LLM
     */
    @Nonnull
    private String getEmbeddingEndpoint(String model) {
        switch (this.mode) {
            case OPENAI:
                return "/embeddings";
            case AZURE_OPENAI:
                return "deployments/" + model + "/embeddings?api-version=" + this.azureAPIVersion;
            case AZURE_LLM:
                return "/v1/embeddings";
            default:
                throw new IllegalStateException("Unexpected mode: " + this.mode);
        }
    }

    /**
     * Uploads a file as a multipart form with the given {@code purpose} field.
     *
     * <p>Not meant for AZURE_LLM mode (assert only).
     *
     * @return the {@code id} of the uploaded file from the JSON response
     * @throws IOException if the call fails
     */
    public String uploadFile(File f, String filename, String purpose) throws IOException {
        assert (this.mode != OpenAIMode.AZURE_LLM);
        String endpoint = "/files";
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            endpoint += "?api-version=" + this.azureAPIVersion;
        }
        JsonObject uploadResponse = (JsonObject)this.client.postFormAndFileToJSON(endpoint, JsonObject.class, f, filename, new Object[]{"purpose", purpose});
        logger.info((Object)("raw upload response: " + JSON.json((Object)uploadResponse)));
        return uploadResponse.get("id").getAsString();
    }

    /**
     * Deletes a previously uploaded file and logs the raw response.
     *
     * <p>In AZURE_OPENAI mode the api-version is passed as a query parameter. Not meant for AZURE_LLM
     * mode (assert only).
     *
     * @throws IOException if the call fails
     */
    public void deleteFile(String fileId) throws IOException {
        String endpoint = "/files/" + fileId;
        assert (this.mode != OpenAIMode.AZURE_LLM);
        Object[] queryParams = this.mode == OpenAIMode.AZURE_OPENAI
                ? new String[]{"api-version", this.azureAPIVersion}
                : new String[]{};
        JsonObject deleteResponse = (JsonObject)this.client.delete(endpoint, JsonObject.class, queryParams);
        logger.info((Object)("raw delete response: " + JSON.json((Object)deleteResponse)));
    }

    /**
     * Returns the {@code status} field of an uploaded file.
     *
     * <p>Not meant for AZURE_LLM mode (assert only).
     *
     * @throws IOException if the call fails
     */
    public String getFileStatus(String fileId) throws IOException {
        assert (this.mode != OpenAIMode.AZURE_LLM);
        StringBuilder endpoint = new StringBuilder("/files/").append(fileId);
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            endpoint.append("?api-version=").append(this.azureAPIVersion);
        }
        JsonObject statusResponse = (JsonObject)this.client.getToJSON(endpoint.toString(), JsonObject.class, new Object[0]);
        return statusResponse.get("status").getAsString();
    }

    /**
     * Starts a fine-tuning job.
     *
     * <p>Unless {@code useDefaults} is set, the non-null hyperparameters (epochs, batch size, learning-rate
     * multiplier) are sent; otherwise an empty hyperparameters object is sent. Not meant for AZURE_LLM
     * mode (assert only).
     *
     * <p>NOTE(review): hyperparameter values are serialized with {@code toString()}, i.e. as JSON
     * strings -- confirm the target API accepts string-typed numbers.
     *
     * @param trainingFileId id of the uploaded training file
     * @param validationFileId optional id of the uploaded validation file
     * @param model base model to fine-tune
     * @return the {@code id} of the created job
     * @throws IOException if the call fails
     */
    public String fineTuneStart(String trainingFileId, Optional<String> validationFileId, String model, FineTuningRecipePayloadParams.FineTuningHyperparameters hyperparameters, boolean useDefaults) throws IOException {
        assert (this.mode != OpenAIMode.AZURE_LLM);
        String endpoint = this.mode == OpenAIMode.AZURE_OPENAI ? "/fine_tuning/jobs?api-version=" + this.azureAPIVersion : "/fine_tuning/jobs";
        JF.ObjectBuilder hyperparametersObj = JF.obj();
        if (!useDefaults) {
            if (hyperparameters.nbEpochs != null) {
                hyperparametersObj.with("n_epochs", hyperparameters.nbEpochs.toString());
            }
            if (hyperparameters.remoteHyperparameters.batchSize != null) {
                hyperparametersObj.with("batch_size", hyperparameters.remoteHyperparameters.batchSize.toString());
            }
            if (hyperparameters.remoteHyperparameters.learningRateMultiplier != null) {
                hyperparametersObj.with("learning_rate_multiplier", hyperparameters.remoteHyperparameters.learningRateMultiplier.toString());
            }
        }
        JF.ObjectBuilder ob = JF.obj().with("training_file", trainingFileId).with("validation_file", (String)validationFileId.orElse(null)).with("model", model).with("hyperparameters", (JsonElement)hyperparametersObj.get());
        logger.info((Object)("Raw fine tune query: " + JSON.json((Object)ob.get())));
        JsonObject ftResp = (JsonObject)this.client.postObjectToJSON(endpoint, JsonObject.class, (Object)ob.get());
        // This is the fine-tuning job creation response, not a file upload response.
        logger.info((Object)("raw fine tune response: " + JSON.json((Object)ftResp)));
        return ftResp.get("id").getAsString();
    }

    /**
     * Fetches one page of events for a fine-tuning job.
     *
     * <p>Query parameters are passed as a flat name/value sequence: {@code after} first (when present),
     * then {@code api-version} in AZURE_OPENAI mode. Not meant for AZURE_LLM mode (assert only).
     *
     * @param id fine-tuning job id
     * @param after optional pagination cursor (event id)
     * @return the raw JSON page
     * @throws IOException if the call fails
     */
    private JsonObject getJobEvents(String id, Optional<String> after) throws IOException {
        List<String> queryParams = new ArrayList<>();
        after.ifPresent(cursor -> {
            queryParams.add("after");
            queryParams.add(cursor);
        });
        assert (this.mode != OpenAIMode.AZURE_LLM);
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            queryParams.add("api-version");
            queryParams.add(this.azureAPIVersion);
        }
        // NOTE(review): unlike the other fine-tuning paths this one has no leading '/'; presumably the
        // client normalizes the join with the base URL -- confirm before aligning.
        String endpoint = "fine_tuning/jobs/" + id + "/events";
        return (JsonObject)this.client.getToJSON(endpoint, JsonObject.class, queryParams.toArray());
    }

    /**
     * Pages through a fine-tuning job's events and collects its per-step training metrics.
     *
     * <p>Only events of type {@code metrics} are turned into metrics (keyed by step; a later event for the
     * same step overwrites an earlier one). {@code llmSmi.totalSteps} is set to the largest
     * {@code total_steps} seen, or 0 if none. The pagination cursor advances past every event of a page,
     * and paging stops when {@code has_more} is false/absent or when a page does not advance the cursor.
     *
     * @param id fine-tuning job id
     * @param llmStepwiseTrainingMetrics receives the collected metrics map
     * @param llmSmi receives the total step count
     * @throws IOException if fetching a page fails
     */
    public void fillLLMStepwiseTrainingMetrics(String id, LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics, LLMSavedModelInfo llmSmi) throws IOException {
        int totalSteps = 0;
        Optional<String> after = Optional.empty();
        HashMap<Integer, LLMStepwiseTrainingMetrics.FineTuningJobMetric> metrics = new HashMap<Integer, LLMStepwiseTrainingMetrics.FineTuningJobMetric>();
        boolean hasMore = true;
        while (hasMore) {
            Optional<String> previousCursor = after;
            JsonObject currentResponse = this.getJobEvents(id, after);
            JsonArray data = currentResponse.get("data").getAsJsonArray();
            for (JsonElement el : data) {
                JsonObject event = el.getAsJsonObject();
                JsonElement eventId = event.get("id");
                if (eventId != null && !eventId.isJsonNull()) {
                    // Move the cursor past every event, not only metrics ones: otherwise a page holding no
                    // metrics events would be requested again indefinitely.
                    after = Optional.of(eventId.getAsString());
                }
                if (!"metrics".equals(event.get("type").getAsString())) {
                    continue;
                }
                JsonObject metricJsonObject = event.get("data").getAsJsonObject();
                metrics.put(metricJsonObject.get("step").getAsInt(), parseFineTuningMetric(metricJsonObject));
                if (metricJsonObject.has("total_steps")) {
                    totalSteps = Math.max(metricJsonObject.get("total_steps").getAsInt(), totalSteps);
                }
            }
            JsonElement hasMoreElement = currentResponse.get("has_more");
            hasMore = hasMoreElement != null && !hasMoreElement.isJsonNull() && hasMoreElement.getAsBoolean();
            if (hasMore && after.equals(previousCursor)) {
                // Empty page (or events without ids) while has_more is set: re-requesting would loop forever.
                logger.warn((Object)("Fine-tuning events pagination for job " + id + " did not advance, stopping"));
                break;
            }
        }
        llmStepwiseTrainingMetrics.metrics = metrics;
        llmSmi.totalSteps = totalSteps;
    }

    /** Builds a step metric from the {@code data} object of a {@code metrics} fine-tuning event. */
    private static LLMStepwiseTrainingMetrics.FineTuningJobMetric parseFineTuningMetric(JsonObject metricJsonObject) {
        LLMStepwiseTrainingMetrics.FineTuningJobMetric metric = new LLMStepwiseTrainingMetrics.FineTuningJobMetric(metricJsonObject.get("step").getAsInt(), new LLMStepwiseTrainingMetrics.MetricValue(LLMStepwiseTrainingMetrics.MetricType.TRAINING, metricJsonObject.get("train_loss").getAsFloat()));
        if (metricJsonObject.has("valid_loss")) {
            metric.setValidationMetric(metricJsonObject.get("valid_loss").getAsFloat(), Optional.empty());
        }
        if (metricJsonObject.has("full_valid_loss")) {
            metric.setFullValidationMetric(metricJsonObject.get("full_valid_loss").getAsFloat(), Optional.empty());
        }
        return metric;
    }

    /**
     * Picks the checkpoint model with the lowest {@code train_loss} for a fine-tuning job.
     *
     * <p>On ties the first checkpoint listed wins; if no checkpoint beats {@link Float#MAX_VALUE}
     * (e.g. none are returned), {@code finalModelId} is returned unchanged.
     *
     * @param jobId fine-tuning job id
     * @param finalModelId fallback model id
     * @return the selected {@code fine_tuned_model_checkpoint} id, or {@code finalModelId}
     * @throws IOException if fetching checkpoints fails
     */
    public String getBestCheckpointModel(String jobId, String finalModelId) throws IOException {
        JsonArray checkpointList = this.getFineTuningCheckpoints(jobId).get("data").getAsJsonArray();
        String selectedModel = finalModelId;
        float lowestTrainLoss = Float.MAX_VALUE;
        for (JsonElement element : checkpointList) {
            JsonObject checkpoint = element.getAsJsonObject();
            float trainLoss = checkpoint.get("metrics").getAsJsonObject().get("train_loss").getAsFloat();
            if (trainLoss < lowestTrainLoss) {
                lowestTrainLoss = trainLoss;
                selectedModel = checkpoint.get("fine_tuned_model_checkpoint").getAsString();
            }
        }
        return selectedModel;
    }

    /**
     * Retrieves the checkpoints recorded for a fine-tuning job.
     *
     * <p>Performs a GET on {@code /fine_tuning/jobs/{id}/checkpoints}. In Azure OpenAI mode the
     * {@code api-version} query parameter is appended. The raw response is logged at info level.
     * Must not be called in {@code AZURE_LLM} mode (asserted).
     *
     * @param id fine-tuning job id
     * @return the raw JSON response body
     * @throws IOException if the HTTP call fails
     */
    public JsonObject getFineTuningCheckpoints(String id) throws IOException {
        assert this.mode != OpenAIMode.AZURE_LLM;
        StringBuilder path = new StringBuilder("/fine_tuning/jobs/").append(id).append("/checkpoints");
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            path.append("?api-version=").append(this.azureAPIVersion);
        }
        JsonObject checkpointsResponse =
                (JsonObject)this.client.getToJSON(path.toString(), JsonObject.class, new Object[0]);
        logger.info((Object)("raw ftCheckpoints response: " + JSON.json((Object)checkpointsResponse)));
        return checkpointsResponse;
    }

    /**
     * Retrieves the current state of a fine-tuning job.
     *
     * <p>Performs a GET on {@code /fine_tuning/jobs/{id}}. In Azure OpenAI mode the
     * {@code api-version} query parameter is appended. The raw response is logged at info level.
     * Must not be called in {@code AZURE_LLM} mode (asserted).
     *
     * @param id fine-tuning job id
     * @return the raw JSON response body
     * @throws IOException if the HTTP call fails
     */
    public JsonObject fineTuneGet(String id) throws IOException {
        assert this.mode != OpenAIMode.AZURE_LLM;
        String versionSuffix = "";
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            versionSuffix = "?api-version=" + this.azureAPIVersion;
        }
        String jobPath = "/fine_tuning/jobs/" + id + versionSuffix;
        JsonObject jobStatus = (JsonObject)this.client.getToJSON(jobPath, JsonObject.class, new Object[0]);
        logger.info((Object)("raw ftStatus response: " + JSON.json((Object)jobStatus)));
        return jobStatus;
    }

    /**
     * Requests cancellation of a fine-tuning job.
     *
     * <p>POSTs the object produced by {@code JF.obj().get()} (presumably an empty JSON object) to
     * {@code /fine_tuning/jobs/{id}/cancel}. In Azure OpenAI mode the {@code api-version} query
     * parameter is appended. The raw response is logged at info level. Must not be called in
     * {@code AZURE_LLM} mode (asserted).
     *
     * @param id fine-tuning job id
     * @return the raw JSON response body
     * @throws IOException if the HTTP call fails
     */
    public JsonObject fineTuneCancel(String id) throws IOException {
        assert this.mode != OpenAIMode.AZURE_LLM;
        String versionSuffix = "";
        if (this.mode == OpenAIMode.AZURE_OPENAI) {
            versionSuffix = "?api-version=" + this.azureAPIVersion;
        }
        String cancelPath = "/fine_tuning/jobs/" + id + "/cancel" + versionSuffix;
        Object requestBody = JF.obj().get();
        JsonObject cancelResponse = (JsonObject)this.client.postObjectToJSON(cancelPath, JsonObject.class, requestBody);
        logger.info((Object)("raw ft cancel response: " + JSON.json((Object)cancelResponse)));
        return cancelResponse;
    }

    /**
     * Deletes the fine-tuning job identified by {@code jobId}.
     *
     * <p>Issues a DELETE on {@code /fine_tuning/jobs/{jobId}}. In Azure OpenAI mode the
     * {@code api-version} query parameter is passed. The raw response is logged at info level.
     * A 404 from the API is logged as a warning and ignored: the job is already gone, which is
     * the desired end state. Any other client error is rethrown. Previously the 404 was logged as
     * "ignoring" and then rethrown anyway. Must not be called in {@code AZURE_LLM} mode (asserted).
     *
     * @param jobId id of the fine-tuning job to delete
     * @throws IOException if the request fails for any reason other than a 404
     */
    public void deleteFinetunedModel(String jobId) throws IOException {
        assert (this.mode != OpenAIMode.AZURE_LLM);
        String endpoint = "/fine_tuning/jobs/" + jobId;
        Object[] queryParams = this.mode == OpenAIMode.AZURE_OPENAI
                ? new String[]{"api-version", this.azureAPIVersion}
                : new String[]{};
        try {
            JsonObject ftDeleteResp = (JsonObject)this.client.delete(endpoint, JsonObject.class, queryParams);
            logger.info((Object)("raw ft delete response: " + JSON.json((Object)ftDeleteResp)));
        }
        catch (ExternalJSONAPIClient.JSONAPIClientException e) {
            if (e.httpCode != 404) {
                throw e;
            }
            // Already deleted (or never existed): treat as success, as the log message states.
            logger.warn((Object)("Fine-tuned model with job id " + jobId + " not found, ignoring"));
        }
    }

    /**
     * Response holder for an image generation call. It holds the list of generated images under
     * {@code data}. The field name appears to mirror the API's JSON key, so do not rename it.
     */
    private static class RawImageGenerationResponse {
        /** Generated images; left null until populated. */
        List<RawImage> data;
    }

    /**
     * Response holder for a moderation call: one result per moderated input, under
     * {@code results}. The field name appears to mirror the API's JSON key.
     */
    private static class OpenAIModerationResponse {
        /** Per-input moderation results; starts as an empty mutable list. */
        List<OpenAIModerationResult> results = new ArrayList<>();
    }

    /**
     * Simplified moderation outcome exposed to callers. It records whether the input was flagged
     * and which categories triggered the flag.
     */
    public static class SimpleModerationResponse {
        /** True when the input was flagged; defaults to false. */
        public boolean flagged;

        /** Names of the flagged categories; starts as an empty mutable list. */
        public List<String> flaggedCategories = new ArrayList<>();
    }

    /**
     * One moderation result: an overall flag, per-category booleans and per-category scores.
     * Field names (including {@code category_scores}) appear to mirror the API's JSON keys, so
     * do not rename them.
     */
    private static class OpenAIModerationResult {
        /** Overall flag for the input. */
        boolean flagged;
        /** Category name to flagged/not-flagged. */
        Map<String, Boolean> categories;
        /** Category name to score. */
        Map<String, Double> category_scores;
    }

    /**
     * Response holder for an embedding call. It holds the embedding results under {@code data}
     * and token accounting under {@code usage}. Field names appear to mirror the API's JSON keys.
     */
    private static class OpenAIEmbeddingResponse {
        /** Embedding results; starts as an empty mutable list. */
        List<OpenAIEmbeddingResult> data = new ArrayList<>();
        /** Token usage; left null until populated. */
        RawUsageResponse usage;
    }

    /** A single embedding result, holding the embedding vector under {@code embedding}. */
    private static class OpenAIEmbeddingResult {
        /** The embedding vector; left null until populated. */
        double[] embedding;
    }

    /**
     * Token accounting attached to an API response. Field names appear to mirror the API's JSON
     * keys (snake_case), so do not rename them.
     */
    private static class RawUsageResponse {
        /** Total tokens. */
        int total_tokens;
        /** Prompt (input) tokens. */
        int prompt_tokens;
        /** Completion (output) tokens. */
        int completion_tokens;
    }

    /**
     * One generated image. Based on the field name, {@code b64_json} presumably holds the image
     * bytes encoded as base64 (TODO confirm).
     */
    private static class RawImage {
        /** Encoded image payload; left null until populated. */
        String b64_json;
    }
}

