/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.stabilityai;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.StabilityAIConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.utils.AspectRatioMatcher;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.http.HttpEntity;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpUriRequest;
import com.dataiku.dss.shadelib.org.apache.http.entity.ContentType;
import com.dataiku.dss.shadelib.org.apache.http.entity.mime.HttpMultipartMode;
import com.dataiku.dss.shadelib.org.apache.http.entity.mime.MultipartEntityBuilder;
import com.google.gson.JsonObject;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import org.apache.commons.codec.binary.Base64;

/**
 * Minimal HTTP client for the Stability AI {@code /v2beta/stable-image/...} image generation endpoints.
 *
 * <p>Requests are sent as multipart forms to {@value #DEFAULT_ENDPOINT_BASE}, authenticated with a
 * bearer token, and the JSON response is read for its {@code image} field.
 */
public class RawStabilityAIClient {
    private static final String DEFAULT_ENDPOINT_BASE = "https://api.stability.ai";
    private static final String SD3_PATH = "/v2beta/stable-image/generate/sd3";
    /** Media type of the uploaded source image part (it is always named {@code image.png}). */
    private static final ContentType PNG_CONTENT_TYPE = ContentType.create("image/png");
    /** Network settings this client was constructed with. */
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    private final ExternalJSONAPIClient client;
    /** Candidate aspect ratios; a requested width/height is fuzzy-matched to one of these and sent as {@code aspect_ratio}. */
    private static final AspectRatioMatcher aspectRatioMatcher = new AspectRatioMatcher("1:1", "16:9", "9:16", "21:9", "9:21", "2:3", "3:2", "4:5", "5:4");
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.stabilityai.client");
    /** Maximum number of characters of the raw JSON response written to the log. */
    private static final int MAX_JSON_SAMPLE_LOG_LENGTH = 120;

    /**
     * Creates a client that talks to {@value #DEFAULT_ENDPOINT_BASE}.
     *
     * @param apiKey Stability AI API key, sent as {@code Authorization: Bearer <apiKey>}
     * @param networkSettings network settings passed to the retrying JSON client
     * @param proxySettings proxy settings passed to the retrying JSON client
     * @param forceContentLength when {@code true}, sets {@code forceContentLength} on the underlying client
     */
    public RawStabilityAIClient(String apiKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean forceContentLength) {
        this.networkSettings = networkSettings;
        this.client = OnlineLLMUtils.getExternalJSONClientWithRetryStrategy(DEFAULT_ENDPOINT_BASE, null, false, proxySettings, networkSettings);
        this.client.addHeader("Authorization", "Bearer " + apiKey);
        if (forceContentLength) {
            this.client.forceContentLength = true;
        }
    }

    /** Closes the underlying JSON API client. */
    public void close() {
        this.client.close();
    }

    /**
     * Generates one image for {@code query} with the given Stability AI model.
     *
     * <p>Stable Image Core and Stable Image Ultra are text-to-image only. Any other model id is sent
     * to the SD3 endpoint as the {@code model} form field. When {@code query.originalImage} is set on
     * such a model, the request becomes a ControlNet sketch/structure request or a plain
     * image-to-image request, depending on {@code query.originalImageEditionMode}.
     *
     * @param model Stability AI model id
     * @param query generation query (prompts, optional seed/style/size/source image)
     * @return a response holding a single image built from the {@code image} string of the JSON reply
     * @throws IllegalArgumentException if the query uses a feature the selected model does not support
     * @throws IllegalStateException if the API reply contains no {@code image} field
     * @throws Exception if building or executing the HTTP request fails
     */
    public LLMClient.ImageGenerationResponse generateImages(String model, LLMClient.ImageGenerationQuery query) throws Exception {
        MultipartEntityBuilder mpb = MultipartEntityBuilder.create();
        mpb.setMode(HttpMultipartMode.STRICT);
        addCommonParts(mpb, query);
        String path = resolvePathAndAddModelParts(mpb, model, query);
        if (logger.isTraceEnabled()) {
            // Serialize a separate entity for logging so the one actually sent is freshly built.
            HttpEntity traceEntity = mpb.build();
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            traceEntity.writeTo((OutputStream) baos);
            logger.trace((Object) ("Raw StabilityAI request:\n" + baos.toString(StandardCharsets.UTF_8)));
        }
        HttpEntity entity = mpb.build();
        HttpPost post = this.client.newAnyPost(path, -1, entity);
        // Request a JSON reply (image returned as a string field) rather than raw image bytes.
        post.addHeader("Accept", "application/json");
        JsonObject ret = (JsonObject) this.client.executeToJSON((HttpUriRequest) post, JsonObject.class);
        logger.info((Object) ("RAW RET" + JSON.sampleJson((Object) ret, MAX_JSON_SAMPLE_LOG_LENGTH)));
        if (ret == null || !ret.has("image") || ret.get("image").isJsonNull()) {
            throw new IllegalStateException("StabilityAI response did not contain an image");
        }
        LLMClient.ImageGenerationResponse resp = new LLMClient.ImageGenerationResponse();
        resp.images.add(new LLMClient.ImageGenerationImage(ret.get("image").getAsString()));
        return resp;
    }

    /** Adds the form fields shared by all endpoints: prompt, optional negative prompt/seed/style/aspect ratio, and PNG output. */
    private static void addCommonParts(MultipartEntityBuilder mpb, LLMClient.ImageGenerationQuery query) {
        mpb.addTextBody("prompt", query.getConcatenatedPrompts());
        if (!query.negativePrompts.isEmpty()) {
            mpb.addTextBody("negative_prompt", query.getConcatenatedNegativePrompts());
        }
        if (query.seed != null) {
            mpb.addTextBody("seed", String.valueOf(query.seed));
        }
        if (query.style != null) {
            mpb.addTextBody("style_preset", query.style);
        }
        if (query.height != null && query.width != null) {
            mpb.addTextBody("aspect_ratio", aspectRatioMatcher.fuzzyMatch(query.width, query.height));
        }
        mpb.addTextBody("output_format", "png");
    }

    /**
     * Picks the endpoint path for {@code model}, adds model-specific form fields, and rejects
     * unsupported feature combinations.
     */
    private static String resolvePathAndAddModelParts(MultipartEntityBuilder mpb, String model, LLMClient.ImageGenerationQuery query) {
        if (StabilityAIConnection.HardcodedStabilityAIModel.STABLE_IMAGE_CORE.id.equals(model)) {
            if (query.originalImage != null) {
                throw new IllegalArgumentException("Image-to-image mode not supported on Stable Image Core, use Stable Diffusion 3 instead");
            }
            return "/v2beta/stable-image/generate/core";
        }
        if (StabilityAIConnection.HardcodedStabilityAIModel.STABLE_IMAGE_ULTRA.id.equals(model)) {
            if (query.originalImage != null) {
                throw new IllegalArgumentException("Image-to-image mode not supported on Stable Image Ultra, use Stable Diffusion 3 instead");
            }
            return "/v2beta/stable-image/generate/ultra";
        }
        // Every other model id goes through the SD3 family and is named explicitly in the form.
        mpb.addTextBody("model", model);
        if (StabilityAIConnection.HardcodedStabilityAIModel.STABLE_DIFFUSION_30_LARGE_TURBO.id.equals(model) && !query.negativePrompts.isEmpty()) {
            throw new IllegalArgumentException("negative prompts not supported on SD3 Large Turbo");
        }
        if (query.originalImage == null) {
            return SD3_PATH;
        }
        logger.info((Object) "Enabling image-to-image mode");
        mpb.addBinaryBody("image", Base64.decodeBase64((String) query.originalImage), PNG_CONTENT_TYPE, "image.png");
        if (query.originalImageEditionMode == LLMClient.ImageGenerationEditionMode.CONTROLNET_SKETCH) {
            addControlStrength(mpb, query);
            return "/v2beta/stable-image/control/sketch";
        }
        if (query.originalImageEditionMode == LLMClient.ImageGenerationEditionMode.CONTROLNET_STRUCTURE) {
            addControlStrength(mpb, query);
            return "/v2beta/stable-image/control/structure";
        }
        mpb.addTextBody("mode", "image-to-image");
        // "strength" is the inverse of the source-image weight: a higher weight keeps more of the original.
        if (query.originalImageWeight != null) {
            mpb.addTextBody("strength", String.valueOf(1.0 - query.originalImageWeight));
        } else {
            mpb.addTextBody("strength", "0.5");
        }
        return SD3_PATH;
    }

    /** Adds {@code control_strength} from the query's source-image weight, when one is given. */
    private static void addControlStrength(MultipartEntityBuilder mpb, LLMClient.ImageGenerationQuery query) {
        if (query.originalImageWeight != null) {
            mpb.addTextBody("control_strength", String.valueOf(query.originalImageWeight));
        }
    }
}

