/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.utils.AspectRatioMatcher;
import com.dataiku.dip.llm.utils.ImageGenerationUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.google.common.collect.Iterables;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import org.apache.commons.lang.StringUtils;

public interface GenericImageGenerationLLMMarshall {
    /** Logger shared by every image-generation marshaller declared in this interface. */
    DKULogger logger = DKULogger.getLogger((String)"dku.llm.aws.marshalling");

    /**
     * Builds the model-specific JSON payload for a generic image generation query.
     *
     * @param query the generic image generation query to translate
     * @return the JSON payload in the format expected by the model family
     */
    JsonObject prepareImageGenerationQuery(LLMClient.ImageGenerationQuery query);

    /**
     * Converts a model-specific JSON response into a generic image generation response.
     *
     * @param response the raw JSON returned for a payload built by {@link #prepareImageGenerationQuery}
     * @param query the query the response corresponds to
     * @return the parsed response, carrying the generated images
     * @throws IOException when an implementation detects a failed generation or an unexpected response shape
     */
    LLMClient.ImageGenerationResponse parseImageGenerationResponse(JsonElement response, LLMClient.ImageGenerationQuery query) throws IOException;

    /**
     * Returns a new image-generation marshaller for the given model family.
     *
     * @param family the model family whose request/response format should be used
     * @return a fresh marshaller instance for {@code family}
     * @throws Error if {@code family} has no image-generation marshaller
     */
    public static GenericImageGenerationLLMMarshall get(GenericLLMHandling family) {
        GenericImageGenerationLLMMarshall marshall;
        switch (family) {
            case AMAZON_TITAN:
                marshall = new AmazonTitanImageGenerationMarshall();
                break;
            case STABILITYAI_STABLE_DIFFUSION_10:
                marshall = new StabilityAIStableDiffusion10ImageGenerationMarshall();
                break;
            case STABILITYAI_STABLE_IMAGE_CORE:
                // Stable Image Core has no image-to-image mode.
                marshall = new StabilityAIStableImageGenerationMarshall("Stable Image Core", false);
                break;
            case STABILITYAI_STABLE_DIFFUSION_3:
                marshall = new StabilityAIStableImageGenerationMarshall("Stable Diffusion 3", true);
                break;
            case STABILITYAI_STABLE_IMAGE_ULTRA:
                marshall = new StabilityAIStableImageGenerationMarshall("Stable Image Ultra", false);
                break;
            default:
                // Any other family is not an image-generation family.
                throw new Error("Unknown GenericLLMHandling family for image generation: " + String.valueOf((Object)family));
        }
        return marshall;
    }

    /**
     * Marshaller for Amazon Titan image generation models.
     *
     * <p>Queries without an original image become {@code TEXT_IMAGE} tasks. Queries with an original
     * image become {@code IMAGE_VARIATION} tasks (VARY / MASK_FREE edition modes) or
     * {@code INPAINTING} tasks. Outpainting is not implemented, and any other edition mode is rejected.
     */
    public static class AmazonTitanImageGenerationMarshall
    implements GenericImageGenerationLLMMarshall {
        /**
         * Builds the Titan request body for {@code query}.
         *
         * @throws IllegalArgumentException if the edition mode or mask settings are unsupported or incomplete
         * @throws NotImplementedException for the OUTPAINTING edition mode
         */
        @Override
        public JsonObject prepareImageGenerationQuery(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder ob = JF.obj();
            if (query.originalImage == null) {
                ob.with("taskType", "TEXT_IMAGE");
                ob.with("textToImageParams", (JsonElement)buildTextToImageParams(query));
            } else {
                query.throwIfNullOriginalImageEditionMode();
                switch (query.originalImageEditionMode) {
                    case VARY: 
                    case MASK_FREE: {
                        ob.with("taskType", "IMAGE_VARIATION");
                        ob.with("imageVariationParams", (JsonElement)buildImageVariationParams(query));
                        break;
                    }
                    case INPAINTING: {
                        JsonObject inPaintingParams = buildInPaintingParams(query);
                        ob.with("taskType", "INPAINTING");
                        ob.with("inPaintingParams", (JsonElement)inPaintingParams);
                        break;
                    }
                    case OUTPAINTING: {
                        throw new NotImplementedException("outpainting");
                    }
                    default: {
                        // Reject remaining modes (e.g. ControlNet) explicitly. Otherwise the request would
                        // carry no taskType and the original image would be silently dropped.
                        throw new IllegalArgumentException("Image edition mode not supported on Titan: " + String.valueOf((Object)query.originalImageEditionMode));
                    }
                }
            }
            ob.with("imageGenerationConfig", (JsonElement)buildImageGenerationConfig(query));
            // Only log pure text requests, so that base64 image payloads are not written to the logs.
            if (query.originalImage == null && query.maskImage == null) {
                logger.info((Object)("Titan request: " + JSON.json((Object)ob.get())));
            }
            return ob.get();
        }

        /** Builds {@code textToImageParams}: the prompt text, plus negative text when provided. */
        private static JsonObject buildTextToImageParams(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder params = JF.obj();
            params.with("text", query.getConcatenatedPrompts());
            if (!query.negativePrompts.isEmpty()) {
                params.with("negativeText", query.getConcatenatedNegativePrompts());
            }
            return params.get();
        }

        /**
         * Builds {@code imageVariationParams}: the original image, optional prompt and negative text,
         * and a similarityStrength taken from originalImageWeight and clamped to [0.2, 1.0].
         */
        private static JsonObject buildImageVariationParams(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder params = JF.obj();
            JsonArray images = new JsonArray();
            images.add(query.originalImage);
            params.with("images", (JsonElement)images);
            if (!query.prompts.isEmpty()) {
                params.with("text", query.getConcatenatedPrompts());
            }
            if (!query.negativePrompts.isEmpty()) {
                params.with("negativeText", query.getConcatenatedNegativePrompts());
            }
            if (query.originalImageWeight != null) {
                double similarityStrength = query.originalImageWeight;
                if (similarityStrength < 0.2) {
                    logger.warn((Object)String.format("Ignoring originalImageWeight '%f' below 0.2 on Titan, using 0.2.", similarityStrength));
                    similarityStrength = 0.2;
                }
                if (similarityStrength > 1.0) {
                    logger.warn((Object)String.format("Ignoring originalImageWeight '%f' above 1.0 on Titan, using 1.0.", similarityStrength));
                    similarityStrength = 1.0;
                }
                params.with("similarityStrength", (Number)similarityStrength);
            }
            return params.get();
        }

        /**
         * Builds {@code inPaintingParams} from the mask settings, the original image and the prompts.
         * Prompt and negative text are sent only when non-empty, as in the other task types.
         *
         * @throws IllegalArgumentException if the mask mode is missing or unsupported, or if its required input is absent
         */
        private static JsonObject buildInPaintingParams(LLMClient.ImageGenerationQuery query) {
            if (query.maskMode == null) {
                throw new IllegalArgumentException("Invalid mask mode while using inpainting.");
            }
            JF.ObjectBuilder params = JF.obj();
            switch (query.maskMode) {
                case MASK_IMAGE_ALPHA: {
                    throw new IllegalArgumentException("INPAINTING image editing mode does not support MASK_IMAGE_ALPHA mask mode.");
                }
                case MASK_IMAGE_BLACK: {
                    if (query.maskImage == null) {
                        throw new IllegalArgumentException("When using MASK_IMAGE_BLACK mask mode a mask image must be specified.");
                    }
                    params.with("maskImage", query.maskImage);
                    break;
                }
                case ORIGINAL_IMAGE_ALPHA: {
                    throw new IllegalArgumentException("INPAINTING image editing mode does not support ORIGINAL_IMAGE_ALPHA mask mode.");
                }
                case TEXT: {
                    if (query.maskPrompt == null) {
                        throw new IllegalArgumentException("When using TEXT mask mode a mask prompt must be specified to generate a mask.");
                    }
                    params.with("maskPrompt", query.maskPrompt);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unsupported mask mode on Titan: " + String.valueOf((Object)query.maskMode));
                }
            }
            params.with("image", query.originalImage);
            if (!query.prompts.isEmpty()) {
                params.with("text", query.getConcatenatedPrompts());
            }
            if (!query.negativePrompts.isEmpty()) {
                params.with("negativeText", query.getConcatenatedNegativePrompts());
            }
            return params.get();
        }

        /** Builds {@code imageGenerationConfig} from the optional generic generation settings. */
        private static JsonObject buildImageGenerationConfig(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder config = JF.obj();
            if (query.nbImagesToGenerate != null) {
                config.with("numberOfImages", (Number)query.nbImagesToGenerate);
            }
            if (query.height != null) {
                config.with("height", (Number)query.height);
            }
            if (query.width != null) {
                config.with("width", (Number)query.width);
            }
            if (query.seed != null) {
                if (query.originalImage != null) {
                    logger.warn((Object)"Passing a seed is not supported with an original image when using Titan, ignoring seed.");
                } else {
                    config.with("seed", (Number)query.seed);
                }
            }
            if (query.quality != null) {
                if (ImageGenerationUtils.isHighQualitySynonym(query.quality)) {
                    config.with("quality", "premium");
                } else if (!"standard".equals(query.quality)) {
                    logger.warn((Object)("Invalid quality setting for Titan: " + query.quality));
                }
            }
            if (query.fidelity != null) {
                // Linear mapping: fidelity 0 -> cfgScale 1.1, fidelity 1 -> cfgScale 10.0.
                double cfgScale = 1.1 + query.fidelity * 8.9;
                config.with("cfgScale", (Number)cfgScale);
            }
            if (query.style != null) {
                logger.warn((Object)"style specified but ignored for Titan.");
            }
            return config.get();
        }

        /**
         * Parses a Titan response. The base64 strings in its {@code images} array become generated images.
         *
         * @throws IOException if the response carries a primitive {@code error} field
         */
        @Override
        public LLMClient.ImageGenerationResponse parseImageGenerationResponse(JsonElement response, LLMClient.ImageGenerationQuery query) throws IOException {
            JsonObject jo = (JsonObject)response;
            LLMClient.ImageGenerationResponse ret = new LLMClient.ImageGenerationResponse();
            if (jo.has("error") && jo.get("error").isJsonPrimitive()) {
                throw new IOException("Titan generation failed: " + jo.get("error").getAsString());
            }
            if (jo.has("images") && jo.get("images").isJsonArray()) {
                for (JsonElement elt : jo.getAsJsonArray("images")) {
                    ret.images.add(new LLMClient.ImageGenerationImage(elt.getAsString()));
                }
            }
            return ret;
        }
    }

    /**
     * Marshaller for Stability AI SDXL 1.0 on Bedrock.
     *
     * <p>Prompts map to weighted {@code text_prompts} entries, with negative prompts sent as negative
     * weights. An original image supports the MASK_FREE and INPAINTING edition modes.
     */
    public static class StabilityAIStableDiffusion10ImageGenerationMarshall
    implements GenericImageGenerationLLMMarshall {
        /**
         * Builds the SDXL request body for {@code query}.
         *
         * @throws IllegalArgumentException if the edition mode or mask settings are unsupported or incomplete
         */
        @Override
        public JsonObject prepareImageGenerationQuery(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder ob = JF.obj();
            // Fill the array before attaching it, so the payload does not rely on the builder keeping a live reference.
            ob.with("text_prompts", (JsonElement)buildTextPrompts(query));
            if (query.originalImage != null) {
                addImageEditionFields(ob, query);
            }
            if (query.nbImagesToGenerate != null) {
                ob.with("samples", (Number)query.nbImagesToGenerate);
            }
            if (query.height != null) {
                ob.with("height", (Number)query.height);
            }
            if (query.width != null) {
                ob.with("width", (Number)query.width);
            }
            if (query.seed != null) {
                ob.with("seed", (Number)query.seed);
            }
            if (query.quality != null) {
                Integer steps = qualityToSteps(query.quality);
                if (steps != null) {
                    ob.with("steps", (Number)steps);
                }
            }
            if (query.fidelity != null) {
                // Linear mapping: fidelity 0 -> cfg_scale 0, fidelity 1 -> cfg_scale 35.
                double cfgScale = query.fidelity * 35.0;
                ob.with("cfg_scale", (Number)cfgScale);
            }
            if (query.style != null) {
                ob.with("style_preset", query.style);
            }
            // Only log pure text requests, so that base64 image payloads are not written to the logs.
            if (query.originalImage == null && query.maskImage == null) {
                logger.info((Object)("SDXL request: " + JSON.json((Object)ob.get())));
            }
            return ob.get();
        }

        /**
         * Builds the {@code text_prompts} array. Each prompt keeps its weight, if it has one. Each
         * negative prompt gets the negated weight, or -1.0 when it has none.
         */
        private static JsonArray buildTextPrompts(LLMClient.ImageGenerationQuery query) {
            JsonArray textPrompts = new JsonArray();
            query.prompts.stream().map(p -> {
                JsonObject o = new JsonObject();
                o.addProperty("text", p.prompt);
                if (p.weight != null) {
                    o.addProperty("weight", (Number)p.weight);
                }
                return o;
            }).forEach(textPrompts::add);
            query.negativePrompts.stream().map(p -> {
                JsonObject o = new JsonObject();
                o.addProperty("text", p.prompt);
                if (p.weight != null) {
                    o.addProperty("weight", (Number)(-p.weight.doubleValue()));
                } else {
                    o.addProperty("weight", (Number)-1.0);
                }
                return o;
            }).forEach(textPrompts::add);
            return textPrompts;
        }

        /**
         * Adds the original image, mask fields and {@code image_strength} for image-to-image requests.
         *
         * @throws IllegalArgumentException for edition modes other than MASK_FREE and INPAINTING
         */
        private static void addImageEditionFields(JF.ObjectBuilder ob, LLMClient.ImageGenerationQuery query) {
            query.throwIfNullOriginalImageEditionMode();
            switch (query.originalImageEditionMode) {
                case OUTPAINTING: {
                    throw new IllegalArgumentException("Outpainting not supported on SDXL");
                }
                case VARY: {
                    throw new IllegalArgumentException("No-prompt vary mode not supported on SDXL");
                }
                case CONTROLNET_SKETCH: 
                case CONTROLNET_STRUCTURE: {
                    throw new IllegalArgumentException("Controlnet mode not supported on Bedrock SDXL, use the StabilityAI provider instead.");
                }
                case MASK_FREE: {
                    break;
                }
                case INPAINTING: {
                    addInpaintingMask(ob, query);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Image edition mode not supported: " + String.valueOf((Object)query.originalImageEditionMode));
                }
            }
            ob.with("init_image", query.originalImage);
            if (query.originalImageWeight != null) {
                double imageStrength = query.originalImageWeight;
                if (imageStrength < 0.0) {
                    logger.warn((Object)"Ignoring negative `originalImageWeight`, using 0 for `image_strength` instead.");
                    imageStrength = 0.0;
                }
                if (imageStrength > 1.0) {
                    logger.warn((Object)"Ignoring `originalImageWeight` above 1, using 1 for `image_strength` instead.");
                    imageStrength = 1.0;
                }
                ob.with("image_strength", (Number)imageStrength);
            }
        }

        /**
         * Adds {@code mask_source}, and {@code mask_image} if needed, for inpainting. Only
         * MASK_IMAGE_BLACK (with a mask image) and ORIGINAL_IMAGE_ALPHA are supported.
         */
        private static void addInpaintingMask(JF.ObjectBuilder ob, LLMClient.ImageGenerationQuery query) {
            if (query.maskMode == null) {
                throw new IllegalArgumentException("Mask mode invalid or not specified for inpainting on SDXL");
            }
            switch (query.maskMode) {
                case MASK_IMAGE_ALPHA: {
                    throw new IllegalArgumentException("Mask alpha not supported on SDXL");
                }
                case MASK_IMAGE_BLACK: {
                    ob.with("mask_source", "MASK_IMAGE_BLACK");
                    if (query.maskImage == null) {
                        throw new IllegalArgumentException("Missing mask image");
                    }
                    ob.with("mask_image", query.maskImage);
                    return;
                }
                case ORIGINAL_IMAGE_ALPHA: {
                    ob.with("mask_source", "INIT_IMAGE_ALPHA");
                    return;
                }
                case TEXT: {
                    throw new IllegalArgumentException("Text mask not supported on SDXL");
                }
                default: {
                    throw new IllegalArgumentException("Unsupported mask mode on SDXL: " + String.valueOf((Object)query.maskMode));
                }
            }
        }

        /**
         * Maps the generic quality setting to an SDXL {@code steps} value. A high-quality synonym gives 50,
         * "low" gives 10, and a numeric string is clamped to [10, 150].
         *
         * @param quality a non-null quality setting
         * @return the number of steps to send, or null to send none ("standard" or an invalid value)
         */
        private static Integer qualityToSteps(String quality) {
            if (ImageGenerationUtils.isHighQualitySynonym(quality)) {
                return 50;
            }
            if ("standard".equals(quality)) {
                return null;
            }
            if ("low".equals(quality)) {
                return 10;
            }
            // commons-lang 2.x StringUtils.isNumeric("") is true, so reject the empty string explicitly.
            if (quality.isEmpty() || !StringUtils.isNumeric((String)quality)) {
                logger.warn((Object)("Invalid quality setting for SDXL 1.0: " + quality));
                return null;
            }
            int steps;
            try {
                steps = Integer.parseInt(quality);
            } catch (NumberFormatException e) {
                // All digits but too large for an int: clamp it like any other value above the maximum.
                logger.warn((Object)String.format("Ignoring steps '%s' above 150, using 150.", quality));
                return 150;
            }
            if (steps < 10) {
                logger.warn((Object)String.format("Ignoring steps '%d' below 10, using 10.", steps));
                steps = 10;
            } else if (steps > 150) {
                logger.warn((Object)String.format("Ignoring steps '%d' above 150, using 150.", steps));
                steps = 150;
            }
            return steps;
        }

        /**
         * Parses an SDXL response. Each {@code artifacts} entry contributes its {@code base64} image.
         *
         * @throws IOException if {@code result} is present and not "success", or if an artifact has no base64 data
         */
        @Override
        public LLMClient.ImageGenerationResponse parseImageGenerationResponse(JsonElement response, LLMClient.ImageGenerationQuery query) throws IOException {
            JsonObject jo = (JsonObject)response;
            LLMClient.ImageGenerationResponse ret = new LLMClient.ImageGenerationResponse();
            if (jo.has("result") && jo.get("result").isJsonPrimitive() && !"success".equals(jo.get("result").getAsString())) {
                throw new IOException("Stable Diffusion generation failed: " + JSON.json((Object)jo));
            }
            if (jo.has("artifacts") && jo.get("artifacts").isJsonArray()) {
                for (JsonElement elt : jo.getAsJsonArray("artifacts")) {
                    JsonElement base64 = elt.isJsonObject() ? elt.getAsJsonObject().get("base64") : null;
                    if (base64 == null || !base64.isJsonPrimitive()) {
                        // Report a malformed artifact as an IOException rather than an NPE or ClassCastException.
                        throw new IOException("Stable Diffusion generation returned an artifact without 'base64' image data: " + JSON.json((Object)elt));
                    }
                    ret.images.add(new LLMClient.ImageGenerationImage(base64.getAsString()));
                }
            }
            return ret;
        }
    }

    /**
     * Marshaller shared by the Stability AI "Stable Image" family on Bedrock (Stable Image Core,
     * Stable Diffusion 3 and Stable Image Ultra, per the factory above).
     *
     * <p>Forwards the prompt, the negative prompt, the seed, and an aspect ratio fuzzy-matched from
     * width/height. Image-to-image is allowed only when enabled at construction, and only in MASK_FREE
     * edition mode. Style and fidelity settings are ignored, with a warning.
     */
    public static class StabilityAIStableImageGenerationMarshall
    implements GenericImageGenerationLLMMarshall {
        /** Aspect ratios onto which the requested width/height are fuzzy-matched. */
        private static final AspectRatioMatcher SUPPORTED_ASPECT_RATIOS = new AspectRatioMatcher("1:1", "16:9", "9:16", "21:9", "9:21", "2:3", "3:2", "4:5", "5:4");

        /** Human-readable model name, used in error and log messages. */
        private final String displayModelName;
        /** Whether an original image (image-to-image mode) is accepted. */
        private final boolean supportsImage2Image;

        /**
         * @param displayModelName model name shown in messages
         * @param supportsImage2Image whether queries with an original image are accepted
         */
        public StabilityAIStableImageGenerationMarshall(String displayModelName, boolean supportsImage2Image) {
            this.displayModelName = displayModelName;
            this.supportsImage2Image = supportsImage2Image;
        }

        /**
         * Builds the request body for {@code query}.
         *
         * @throws IllegalArgumentException if an original image is given but unsupported, or the edition mode is not MASK_FREE
         */
        @Override
        public JsonObject prepareImageGenerationQuery(LLMClient.ImageGenerationQuery query) {
            JF.ObjectBuilder request = JF.obj();
            request.with("prompt", query.getConcatenatedPrompts());
            request.with("negative_prompt", query.getConcatenatedNegativePrompts());
            if (query.originalImage != null) {
                this.appendImageToImageFields(request, query);
            }
            boolean hasBothDimensions = query.width != null && query.height != null;
            if (hasBothDimensions) {
                request.with("aspect_ratio", SUPPORTED_ASPECT_RATIOS.fuzzyMatch(query.width, query.height));
            }
            if (query.seed != null) {
                request.with("seed", (Number)query.seed);
            }
            // These generic settings are not forwarded by this marshaller; warn instead of failing.
            if (query.style != null) {
                logger.warn((Object)String.format("style specified but ignored for Bedrock %s.", this.displayModelName));
            }
            if (query.fidelity != null) {
                logger.warn((Object)String.format("fidelity specified but ignored for Bedrock %s.", this.displayModelName));
            }
            return request.get();
        }

        /** Validates the image-to-image settings, then adds {@code strength}, {@code image} and {@code mode}. */
        private void appendImageToImageFields(JF.ObjectBuilder request, LLMClient.ImageGenerationQuery query) {
            if (!this.supportsImage2Image) {
                throw new IllegalArgumentException(String.format("originalImage specified but Bedrock %s does not support image to image mode.", this.displayModelName));
            }
            query.throwIfNullOriginalImageEditionMode();
            switch (query.originalImageEditionMode) {
                case MASK_FREE:
                    break;
                default:
                    throw new IllegalArgumentException("Not supported edition mode: " + String.valueOf((Object)query.originalImageEditionMode));
            }
            if (query.originalImageWeight != null) {
                request.with("strength", strengthFor(query.originalImageWeight));
            }
            request.with("image", query.originalImage);
            request.with("mode", "image-to-image");
        }

        /**
         * Converts an original-image weight into a strength of {@code 1 - weight}. Weights at or beyond
         * the bounds map to the integers 0 (weight >= 1) and 1 (weight <= 0).
         */
        private static Number strengthFor(double originalImageWeight) {
            if (originalImageWeight >= 1.0) {
                return Integer.valueOf(0);
            }
            if (originalImageWeight <= 0.0) {
                return Integer.valueOf(1);
            }
            return Double.valueOf(1.0 - originalImageWeight);
        }

        /**
         * Parses a response. The response must have a {@code finish_reasons} array in which every
         * entry is null, and an {@code images} array of base64 strings. The seeds and finish reasons
         * are copied into {@code additionalInformation}.
         *
         * @throws IOException if either array is missing, or if any finish reason is non-null
         */
        @Override
        public LLMClient.ImageGenerationResponse parseImageGenerationResponse(JsonElement response, LLMClient.ImageGenerationQuery query) throws IOException {
            JsonObject body = (JsonObject)response;
            LLMClient.ImageGenerationResponse result = new LLMClient.ImageGenerationResponse();

            JsonElement finishReasons = body.get("finish_reasons");
            if (finishReasons == null || !finishReasons.isJsonArray()) {
                throw new IOException(String.format("%s generation failed, bad response schema. Expected 'finish_reasons' array field, got:  %s", this.displayModelName, JSON.json((Object)body)));
            }
            // A non-null finish reason signals that generation did not complete normally.
            for (JsonElement reason : finishReasons.getAsJsonArray()) {
                if (!reason.isJsonNull()) {
                    throw new IOException(String.format("%s generation failed with finished reason '%s'. Got response: %s", this.displayModelName, reason, JSON.json((Object)body)));
                }
            }

            JsonElement imagesElement = body.get("images");
            if (imagesElement == null || !imagesElement.isJsonArray()) {
                throw new IOException(String.format("%s generation failed, bad response schema. Expected 'images' array field, got:  %s", this.displayModelName, JSON.json((Object)body)));
            }
            JsonArray images = imagesElement.getAsJsonArray();
            int imageCount = images.size();
            if (imageCount != 1) {
                logger.warn((Object)("Expected a single image in the response but instead got: " + imageCount));
            }
            for (JsonElement image : images) {
                result.images.add(new LLMClient.ImageGenerationImage(image.getAsString()));
            }

            JsonObject extraInfo = new JsonObject();
            extraInfo.add("seeds", body.get("seeds"));
            extraInfo.add("finish_reasons", body.get("finish_reasons"));
            result.additionalInformation = extraInfo;
            return result;
        }
    }
}

