/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance;

import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.llm.online.LLMClient;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.List;
import javax.annotation.Nullable;

public abstract class GuardrailRunner
implements AutoCloseable {
    /**
     * Initializes this runner.
     *
     * <p>NOTE(review): presumably must be called before any {@code process*} method - confirm
     * against the callers; this file does not show the lifecycle.
     *
     * @throws Exception if initialization fails
     */
    public abstract void init() throws Exception;

    /**
     * Runs the guardrail over a single completion query.
     *
     * @param var1 the guardrail context for this invocation
     * @param var2 the completion query to examine
     * @param var3 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return the outcome, whose {@code action} is a {@link QueryGuardrailAction}
     * @throws Exception if the guardrail cannot be evaluated
     */
    public abstract CompletionQueryGuardrailResponse processCompletionQuery(GuardrailContext var1, LLMClient.SingleCompletionQuery var2, LLMClient.LLMMeshTraceSpan var3) throws Exception;

    /**
     * Runs the guardrail over the response (or error) obtained for a completion query.
     *
     * @param var1 the guardrail context for this invocation
     * @param var2 the completion query that produced the response
     * @param var3 the completion response or error to examine
     * @param var4 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return the outcome, whose {@code action} is a {@link ResponseGuardrailAction}
     * @throws Exception if the guardrail cannot be evaluated
     */
    public abstract CompletionResponseGuardrailResponse processCompletionResponse(GuardrailContext var1, LLMClient.SingleCompletionQuery var2, LLMClient.SimpleCompletionResponseOrError var3, LLMClient.LLMMeshTraceSpan var4) throws Exception;

    /**
     * Creates a consumer for a streamed completion response.
     *
     * <p>NOTE(review): presumably the returned consumer applies the guardrail and forwards to
     * {@code var1}; the wrapping behavior is implementation-defined - confirm in subclasses.
     *
     * @param var1 an existing streamed-response consumer
     * @param var2 the guardrail context for this invocation
     * @param var3 the completion query being streamed
     * @param var4 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return a streamed-response consumer
     * @throws Exception if the handler cannot be created
     */
    public abstract LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer var1, GuardrailContext var2, LLMClient.SingleCompletionQuery var3, LLMClient.LLMMeshTraceSpan var4) throws Exception;

    /**
     * Runs the guardrail over an embedding query.
     *
     * @param var1 the guardrail context for this invocation
     * @param var2 the embedding query to examine
     * @param var3 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return the outcome, whose {@code action} is a {@link QueryGuardrailAction}
     * @throws Exception if the guardrail cannot be evaluated
     */
    public abstract EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailContext var1, LLMClient.EmbeddingQuery var2, LLMClient.LLMMeshTraceSpan var3) throws Exception;

    /**
     * Runs the guardrail over an image generation query.
     *
     * @param var1 the guardrail context for this invocation
     * @param var2 the image generation query to examine
     * @param var3 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return the outcome, whose {@code action} is a {@link QueryGuardrailAction}
     * @throws Exception if the guardrail cannot be evaluated
     */
    public abstract ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailContext var1, LLMClient.ImageGenerationQuery var2, LLMClient.LLMMeshTraceSpan var3) throws Exception;

    /**
     * Runs the guardrail over the response (or error) obtained for an image generation query.
     *
     * @param var1 the guardrail context for this invocation
     * @param var2 the image generation query that produced the response
     * @param var3 the image generation response or error to examine
     * @param var4 the LLM Mesh trace span (presumably used to record tracing data - confirm)
     * @return the outcome, whose {@code action} is a {@link ResponseGuardrailAction}
     * @throws Exception if the guardrail cannot be evaluated
     */
    public abstract ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailContext var1, LLMClient.ImageGenerationQuery var2, LLMClient.ImageGenerationResponseOrError var3, LLMClient.LLMMeshTraceSpan var4) throws Exception;

    /**
     * Wraps a single JSON object into a one-element JSON array.
     *
     * @param jsonObject the object to wrap; may be {@code null}
     * @return a new array containing only {@code jsonObject}, or {@code null} when
     *         {@code jsonObject} is {@code null}
     */
    @Nullable
    private static JsonArray toJsonArrayOrNull(@Nullable JsonObject jsonObject) {
        if (jsonObject != null) {
            JsonArray singleton = new JsonArray();
            singleton.add(jsonObject);
            return singleton;
        }
        return null;
    }

    public static class ImageGenerationResponseGuardrailResponse
    extends ResponseGuardrailResponse {
        public LLMClient.ImageGenerationResponseOrError imageGenerationResponse;

        public static ImageGenerationResponseGuardrailResponse pass(GuardrailContext context, LLMClient.ImageGenerationResponseOrError imageGenerationResponse) {
            ImageGenerationResponseGuardrailResponse ret = new ImageGenerationResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.PASS;
            ret.imageGenerationResponse = imageGenerationResponse;
            ret.context = context;
            return ret;
        }

        public static ImageGenerationResponseGuardrailResponse fail(GuardrailContext context, Throwable t) {
            ImageGenerationResponseGuardrailResponse ret = new ImageGenerationResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.FAIL;
            ret.error = new SerializedError(t, false).withRememberedOriginalThrowable(t);
            ret.context = context;
            return ret;
        }

        public static ImageGenerationResponseGuardrailResponse passWithAudit(GuardrailContext context, LLMClient.ImageGenerationResponseOrError imageGenerationResponse, JsonObject auditData) {
            ImageGenerationResponseGuardrailResponse ret = new ImageGenerationResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.PASS_WITH_AUDIT;
            ret.imageGenerationResponse = imageGenerationResponse;
            ret.auditData = GuardrailRunner.toJsonArrayOrNull(auditData);
            ret.context = context;
            return ret;
        }
    }

    public static class ImageGenerationQueryGuardrailResponse
    extends QueryGuardrailResponse {
        public LLMClient.ImageGenerationQuery imageGenerationQuery;

        public static ImageGenerationQueryGuardrailResponse pass(GuardrailContext context, LLMClient.ImageGenerationQuery imageGenerationQuery) {
            ImageGenerationQueryGuardrailResponse ret = new ImageGenerationQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS;
            ret.imageGenerationQuery = imageGenerationQuery;
            ret.context = context;
            return ret;
        }

        public static ImageGenerationQueryGuardrailResponse fail(GuardrailContext context, Throwable t) {
            ImageGenerationQueryGuardrailResponse ret = new ImageGenerationQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.FAIL;
            ret.error = new SerializedError(t, false).withRememberedOriginalThrowable(t);
            ret.context = context;
            return ret;
        }

        public static ImageGenerationQueryGuardrailResponse passWithAudit(GuardrailContext context, LLMClient.ImageGenerationQuery imageGenerationQuery, JsonObject auditData) {
            ImageGenerationQueryGuardrailResponse ret = new ImageGenerationQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS_WITH_AUDIT;
            ret.imageGenerationQuery = imageGenerationQuery;
            ret.auditData = GuardrailRunner.toJsonArrayOrNull(auditData);
            ret.context = context;
            return ret;
        }
    }

    public static class EmbeddingQueryGuardrailResponse
    extends QueryGuardrailResponse {
        public LLMClient.EmbeddingQuery embeddingQuery;

        public static EmbeddingQueryGuardrailResponse pass(GuardrailContext context, LLMClient.EmbeddingQuery embeddingQuery) {
            EmbeddingQueryGuardrailResponse ret = new EmbeddingQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS;
            ret.embeddingQuery = embeddingQuery;
            ret.context = context;
            return ret;
        }

        public static EmbeddingQueryGuardrailResponse fail(GuardrailContext context, Throwable t) {
            EmbeddingQueryGuardrailResponse ret = new EmbeddingQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.FAIL;
            ret.error = new SerializedError(t, false).withRememberedOriginalThrowable(t);
            ret.context = context;
            return ret;
        }

        public static EmbeddingQueryGuardrailResponse passWithAudit(GuardrailContext context, LLMClient.EmbeddingQuery embeddingQuery, JsonObject auditData) {
            EmbeddingQueryGuardrailResponse ret = new EmbeddingQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS_WITH_AUDIT;
            ret.embeddingQuery = embeddingQuery;
            ret.auditData = GuardrailRunner.toJsonArrayOrNull(auditData);
            ret.context = context;
            return ret;
        }
    }

    public static class CompletionResponseGuardrailResponse
    extends ResponseGuardrailResponse {
        public LLMClient.SimpleCompletionResponseOrError completionResponse;

        public static CompletionResponseGuardrailResponse pass(GuardrailContext context, LLMClient.SimpleCompletionResponseOrError completionResponse) {
            CompletionResponseGuardrailResponse ret = new CompletionResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.PASS;
            ret.completionResponse = completionResponse;
            ret.context = context;
            return ret;
        }

        public static CompletionResponseGuardrailResponse fail(GuardrailContext context, Throwable t) {
            CompletionResponseGuardrailResponse ret = new CompletionResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.FAIL;
            ret.error = new SerializedError(t, false).withRememberedOriginalThrowable(t);
            ret.context = context;
            return ret;
        }

        public static CompletionResponseGuardrailResponse passWithAudit(GuardrailContext context, LLMClient.SimpleCompletionResponseOrError completionResponse, JsonObject auditData) {
            CompletionResponseGuardrailResponse ret = new CompletionResponseGuardrailResponse();
            ret.action = ResponseGuardrailAction.PASS_WITH_AUDIT;
            ret.completionResponse = completionResponse;
            ret.auditData = GuardrailRunner.toJsonArrayOrNull(auditData);
            ret.context = context;
            return ret;
        }
    }

    public static class CompletionQueryGuardrailResponse
    extends QueryGuardrailResponse {
        public LLMClient.SingleCompletionQuery completionQuery;

        public static CompletionQueryGuardrailResponse pass(GuardrailContext context, LLMClient.SingleCompletionQuery completionQuery) {
            CompletionQueryGuardrailResponse ret = new CompletionQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS;
            ret.completionQuery = completionQuery;
            ret.context = context;
            return ret;
        }

        public static CompletionQueryGuardrailResponse fail(GuardrailContext context, Throwable t) {
            CompletionQueryGuardrailResponse ret = new CompletionQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.FAIL;
            ret.error = new SerializedError(t, false).withRememberedOriginalThrowable(t);
            ret.context = context;
            return ret;
        }

        public static CompletionQueryGuardrailResponse passWithAudit(GuardrailContext context, LLMClient.SingleCompletionQuery completionQuery, JsonObject auditData) {
            CompletionQueryGuardrailResponse ret = new CompletionQueryGuardrailResponse();
            ret.action = QueryGuardrailAction.PASS_WITH_AUDIT;
            ret.completionQuery = completionQuery;
            ret.context = context;
            ret.auditData = GuardrailRunner.toJsonArrayOrNull(auditData);
            return ret;
        }
    }

    /**
     * Base type for outcomes of guardrails applied to a response; adds the response-side
     * action to the fields shared through {@code GuardrailResponse}.
     */
    public static class ResponseGuardrailResponse
    extends GuardrailResponse {
        /** Decision taken by the guardrail for the response. */
        public ResponseGuardrailAction action;
        // NOTE(review): presumably the messages to resend when action is RETRY - nothing in
        // this file assigns it; confirm against the guardrail implementations.
        public List<LLMClient.ChatMessage> updatedMessagesForRetry;
    }

    /**
     * Base type for outcomes of guardrails applied to a query; adds the query-side action
     * to the fields shared through {@code GuardrailResponse}.
     */
    public static class QueryGuardrailResponse
    extends GuardrailResponse {
        /** Decision taken by the guardrail for the query. */
        public QueryGuardrailAction action;
        // NOTE(review): presumably the text returned to the caller when action is RESPOND -
        // nothing in this file assigns it; confirm against the guardrail implementations.
        public String overriddenResponseText;
    }

    /**
     * State shared by every guardrail outcome: the guardrail context, optional audit data
     * and, for failures, the serialized error.
     */
    static abstract class GuardrailResponse {
        /** Guardrail context attached by the factory that built this outcome. */
        public GuardrailContext context;
        /** Audit entries attached to this outcome; {@code null} when none were provided. */
        public JsonArray auditData;
        /** Serialized failure; assigned by the {@code fail} factories, otherwise {@code null}. */
        public SerializedError error;

        GuardrailResponse() {
        }

        /**
         * Converts the recorded error into an exception.
         *
         * <ul>
         *   <li>No error recorded: a generic {@code "Unknown error"} exception.</li>
         *   <li>The remembered original throwable is an {@link Exception}: that same instance.</li>
         *   <li>The remembered original throwable is any other {@link Throwable} (for example an
         *       {@link Error}): a new exception carrying the detailed message, with the original
         *       attached as its cause so the stack trace is not lost.</li>
         *   <li>No original throwable remembered: a new exception carrying the detailed message.</li>
         * </ul>
         *
         * @return an exception describing the failure; never {@code null}
         */
        public Exception toException() {
            if (this.error == null) {
                return new Exception("Unknown error");
            }
            Throwable original = this.error.originalThrowable;
            // instanceof is already false for null, so no separate null check is needed here.
            if (original instanceof Exception) {
                return (Exception)original;
            }
            if (original != null) {
                // Cannot be returned as an Exception; wrap it but keep it as the cause.
                return new Exception(this.error.detailedMessage, original);
            }
            return new Exception(this.error.detailedMessage);
        }
    }

    /**
     * Decisions a guardrail can take on a response.
     *
     * <p>{@code PASS}, {@code FAIL} and {@code PASS_WITH_AUDIT} are assigned by the response
     * factories in this file. NOTE(review): {@code RETRY} and {@code RESPOND} are not assigned
     * anywhere in this file - confirm their handling with the consumers of this enum.
     */
    public enum ResponseGuardrailAction {
        PASS,
        FAIL,
        PASS_WITH_AUDIT,
        RETRY,
        RESPOND
    }

    /**
     * Decisions a guardrail can take on a query.
     *
     * <p>{@code PASS}, {@code FAIL} and {@code PASS_WITH_AUDIT} are assigned by the query
     * factories in this file. NOTE(review): {@code RESPOND} is not assigned anywhere in this
     * file - confirm its handling with the consumers of this enum.
     */
    public enum QueryGuardrailAction {
        PASS,
        FAIL,
        PASS_WITH_AUDIT,
        RESPOND
    }

    /**
     * Holder for the JSON context object attached to guardrail outcomes.
     */
    public static class GuardrailContext {
        /** Backing JSON object; starts out empty for every new instance. */
        JsonObject context;

        /** Creates a context backed by a fresh, empty JSON object. */
        public GuardrailContext() {
            this.context = new JsonObject();
        }
    }
}

