/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.marshall;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.NullChecker;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Validates a {@link CoreCompletionSettings} instance against the completion settings supported by
 * a given LLM completion API.
 *
 * <p>Supported settings are declared with the fluent {@code allowXxx()} methods. During
 * {@link #validate(CoreCompletionSettings)}, a setting that carries a value but was not allowed is
 * "unsupported". The following settings are flagged as must-fail, and an unsupported one causes
 * validation to throw:
 * <ul>
 *   <li>log probs and top log probs;
 *   <li>JSON mode and JSON schema;
 *   <li>tool choice and tools.
 * </ul>
 * Any other unsupported setting is cleared from the settings object and reported in a warning.
 */
public class CoreCompletionSettingsValidator {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.online.marshall.settings_validator");

    /** API name interpolated into error and warning messages. */
    private final String apiQualifier;
    /** Settings explicitly allowed for this API; any other setting carrying a value is unsupported. */
    private final Set<CompletionSettingAccessor> supportedSettings = EnumSet.noneOf(CompletionSettingAccessor.class);

    /**
     * Creates a validator that initially allows no settings.
     *
     * @param apiQualifier name of the completion API, used in error and warning messages
     */
    public CoreCompletionSettingsValidator(String apiQualifier) {
        this.apiQualifier = apiQualifier;
    }

    /**
     * Creates a validator that starts with a copy of the settings allowed by {@code base}.
     * Subsequent {@code allowXxx()} calls on the new instance do not affect {@code base}.
     *
     * @param apiQualifier name of the completion API, used in error and warning messages
     * @param base validator whose allowed settings are copied
     * @throws NullPointerException if {@code base} is null
     */
    public CoreCompletionSettingsValidator(String apiQualifier, CoreCompletionSettingsValidator base) {
        this(apiQualifier);
        this.supportedSettings.addAll(Objects.requireNonNull(base, "base").supportedSettings);
    }

    /** Allows the {@code maxTokens} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowMaxTokens() {
        return allow(CompletionSettingAccessor.MAX_TOKENS);
    }

    /** Allows the {@code temperature} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowTemperature() {
        return allow(CompletionSettingAccessor.TEMPERATURE);
    }

    /** Allows the {@code topP} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowTopP() {
        return allow(CompletionSettingAccessor.TOP_P);
    }

    /** Allows the {@code topK} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowTopK() {
        return allow(CompletionSettingAccessor.TOP_K);
    }

    /** Allows the {@code frequencyPenalty} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowFrequencyPenalty() {
        return allow(CompletionSettingAccessor.FREQUENCY_PENALTY);
    }

    /** Allows the {@code presencePenalty} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowPresencePenalty() {
        return allow(CompletionSettingAccessor.PRESENCE_PENALTY);
    }

    /** Allows both {@code logProbs} and {@code topLogProbs}. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowLogProbs() {
        return allow(CompletionSettingAccessor.LOG_PROBS).allow(CompletionSettingAccessor.TOP_LOG_PROBS);
    }

    /** Allows a JSON {@code responseFormat} (without requiring a schema). @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowJsonMode() {
        return allow(CompletionSettingAccessor.JSON_MODE);
    }

    /** Allows a schema on a JSON {@code responseFormat}. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowJsonSchema() {
        return allow(CompletionSettingAccessor.JSON_SCHEMA);
    }

    /** Allows the {@code logitBias} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowLogitBias() {
        return allow(CompletionSettingAccessor.LOGIT_BIAS);
    }

    /** Allows the {@code stopSequences} setting. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowStopSequences() {
        return allow(CompletionSettingAccessor.STOP_SEQUENCES);
    }

    /** Allows both {@code toolChoice} and {@code tools}. @return this validator, for chaining */
    public CoreCompletionSettingsValidator allowTools() {
        return allow(CompletionSettingAccessor.TOOL_CHOICE).allow(CompletionSettingAccessor.TOOLS);
    }

    /** Adds {@code setting} to the allowed set and returns {@code this} for chaining. */
    private CoreCompletionSettingsValidator allow(CompletionSettingAccessor setting) {
        supportedSettings.add(setting);
        return this;
    }

    /**
     * Checks {@code ccs} against the allowed settings and strips the unsupported ones.
     *
     * <p>The checks run in this order:
     * <ol>
     *   <li>If any unsupported setting is flagged as must-fail, an exception is thrown and
     *       {@code ccs} is left unchanged.
     *   <li>Otherwise, every unsupported setting is cleared from {@code ccs}, and a single warning
     *       listing them is logged.
     *   <li>Finally, when tools or tool choice are allowed and present, they are passed to
     *       {@link NullChecker#checkObject}.
     * </ol>
     *
     * @param ccs the settings to validate; unsupported, non-must-fail values are cleared in place
     * @throws IllegalArgumentException if an unsupported must-fail setting carries a value; the first
     *     such setting in declaration order is reported
     * @throws NullPointerException if {@code ccs} is null
     */
    public void validate(CoreCompletionSettings ccs) {
        Objects.requireNonNull(ccs, "ccs");

        List<CompletionSettingAccessor> unsupportedSettings = new ArrayList<>();
        for (CompletionSettingAccessor setting : CompletionSettingAccessor.values()) {
            if (!supportedSettings.contains(setting) && setting.hasValue.test(ccs)) {
                unsupportedSettings.add(setting);
            }
        }

        // Reject before mutating anything, so a failed validation leaves the caller's settings intact.
        for (CompletionSettingAccessor setting : unsupportedSettings) {
            if (setting.mustFailOnUnsupported) {
                throw new IllegalArgumentException(String.format("The %s completion API does not support the following setting: %s", apiQualifier, setting.displayName));
            }
        }

        if (!unsupportedSettings.isEmpty()) {
            // Only non-must-fail settings reach this point. Each of them owns a distinct field, so
            // clearing them after detection is equivalent to clearing them during it.
            for (CompletionSettingAccessor setting : unsupportedSettings) {
                setting.unsetValue.accept(ccs);
            }
            String names = unsupportedSettings.stream().map(s -> s.displayName).collect(Collectors.joining(", "));
            logger.warnV("The %s completion API does not support the following settings: %s. Unsupported settings are not included in subsequent API calls.", new Object[]{apiQualifier, names});
        }

        if (supportedSettings.contains(CompletionSettingAccessor.TOOL_CHOICE) && Objects.nonNull(ccs.toolChoice)) {
            // The (Object) cast is kept from the original call; it may select a specific NullChecker overload.
            NullChecker.checkObject((Object)ccs.toolChoice);
        }
        if (supportedSettings.contains(CompletionSettingAccessor.TOOLS) && Objects.nonNull(ccs.tools)) {
            NullChecker.checkObject(ccs.tools);
        }
    }

    /**
     * One completion setting. Each constant holds:
     * <ul>
     *   <li>the display name used in messages;
     *   <li>a predicate telling whether the setting carries a value;
     *   <li>an action that clears it;
     *   <li>whether an unsupported value must fail validation instead of being dropped.
     * </ul>
     * Declaration order determines which must-fail setting is reported first.
     */
    private enum CompletionSettingAccessor {
        MAX_TOKENS("maxTokens", s -> Objects.nonNull(s.maxTokens), s -> {
            s.maxTokens = null;
        }),
        TEMPERATURE("temperature", s -> Objects.nonNull(s.temperature), s -> {
            s.temperature = null;
        }),
        TOP_K("topK", s -> Objects.nonNull(s.topK), s -> {
            s.topK = null;
        }),
        TOP_P("topP", s -> Objects.nonNull(s.topP), s -> {
            s.topP = null;
        }),
        FREQUENCY_PENALTY("frequencyPenalty", s -> Objects.nonNull(s.frequencyPenalty), s -> {
            s.frequencyPenalty = null;
        }),
        PRESENCE_PENALTY("presencePenalty", s -> Objects.nonNull(s.presencePenalty), s -> {
            s.presencePenalty = null;
        }),
        LOG_PROBS("logProbs", s -> Objects.nonNull(s.logProbs), s -> {
            s.logProbs = null;
        }, true),
        TOP_LOG_PROBS("topLogProbs", s -> Objects.nonNull(s.topLogProbs), s -> {
            s.topLogProbs = null;
        }, true),
        // An empty map/list is treated as "no value".
        LOGIT_BIAS("logitBias", s -> Objects.nonNull(s.logitBias) && !s.logitBias.isEmpty(), s -> {
            s.logitBias = null;
        }),
        STOP_SEQUENCES("stopSequences", s -> Objects.nonNull(s.stopSequences) && !s.stopSequences.isEmpty(), s -> {
            s.stopSequences = null;
        }),
        // Matches any JSON response format, with or without a schema.
        JSON_MODE("responseFormat (type=json)", s -> s.responseFormat instanceof LLMClient.ResponseFormatJson, s -> {
            s.responseFormat = null;
        }, true),
        // Matches only a JSON response format that carries a schema; clearing drops just the schema.
        JSON_SCHEMA("responseFormat (type=json, schema={...})", s -> s.responseFormat instanceof LLMClient.ResponseFormatJson && ((LLMClient.ResponseFormatJson)s.responseFormat).schema != null, s -> {
            if (s.responseFormat instanceof LLMClient.ResponseFormatJson) {
                ((LLMClient.ResponseFormatJson)s.responseFormat).schema = null;
            }
        }, true),
        TOOL_CHOICE("toolChoice", s -> Objects.nonNull(s.toolChoice), s -> {
            s.toolChoice = null;
        }, true),
        TOOLS("tools", s -> Objects.nonNull(s.tools) && !s.tools.isEmpty(), s -> {
            s.tools = null;
        }, true);

        /** Name shown in error and warning messages (distinct from {@link Enum#name()}). */
        private final String displayName;
        private final Predicate<CoreCompletionSettings> hasValue;
        private final Consumer<CoreCompletionSettings> unsetValue;
        private final boolean mustFailOnUnsupported;

        CompletionSettingAccessor(String displayName, Predicate<CoreCompletionSettings> hasValue, Consumer<CoreCompletionSettings> unsetValue, boolean mustFailOnUnsupported) {
            this.displayName = displayName;
            this.hasValue = hasValue;
            this.unsetValue = unsetValue;
            this.mustFailOnUnsupported = mustFailOnUnsupported;
        }

        /** Creates a setting that is silently dropped (with a warning) when unsupported. */
        CompletionSettingAccessor(String displayName, Predicate<CoreCompletionSettings> hasValue, Consumer<CoreCompletionSettings> unsetValue) {
            this(displayName, hasValue, unsetValue, false);
        }
    }
}

