/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.core;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.HashSet;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * GPU execution settings for a task.
 *
 * <p>This class holds the user-facing {@link GpuParams}. It also holds a set of
 * {@link GpuSupportingCapability} values for which GPU use is switched off even
 * when GPU execution is enabled overall.
 */
public class GpuConfig {
    /** User-facing GPU parameters. */
    private GpuParams params = new GpuParams();

    /**
     * Capabilities for which GPU use is disabled regardless of {@link GpuParams#useGpu}.
     * NOTE(review): nothing in this class adds to this set; presumably it is populated
     * by deserialization or by code elsewhere. Confirm before relying on it.
     */
    private HashSet<GpuSupportingCapability> disabledCapabilities = new HashSet<>();

    /**
     * Narrows the GPU selection to a single GPU.
     *
     * <p>The retained GPU is the lowest-numbered one, because {@code gpuList} is a
     * sorted set and {@code first()} returns its smallest element.
     *
     * <p>If no GPU is selected ({@code gpuList} is null or empty), there is nothing
     * to narrow, and the selection is left unchanged. The "no GPU selected" case is
     * reported by {@link #validate}.
     */
    public void forceSingleGpu() {
        SortedSet<Integer> gpus = this.params.gpuList;
        if (gpus == null || gpus.isEmpty()) {
            // Guard: first() would throw NoSuchElementException (or NPE on null) here.
            return;
        }
        this.params.gpuList = new TreeSet<>(Collections.singletonList(gpus.first()));
    }

    /**
     * Tells whether the GPU should be used for the given capability.
     *
     * @param capability the capability being asked about
     * @return {@code true} iff GPU execution is enabled and {@code capability} is not
     *     in the disabled set
     */
    public boolean shouldUseGpu(GpuSupportingCapability capability) {
        return this.params.useGpu && !this.disabledCapabilities.contains(capability);
    }

    /**
     * Checks that the GPU parameters are consistent.
     *
     * <p>The checks run only when GPU execution is enabled:
     * <ul>
     *   <li>at least one GPU must be selected;</li>
     *   <li>for the KERAS backend only, {@code perGPUMemoryFraction} must lie in
     *       {@code (0, 1]}.</li>
     * </ul>
     *
     * <p>Failures are reported through {@code ErrorContext.check}. NOTE(review): whether
     * that throws or records the error is defined in {@code ErrorContext}, not here.
     *
     * @param backendType the backend the task will run on
     */
    public void validate(MLTask.BackendType backendType) {
        if (!this.params.useGpu) {
            return;
        }
        SortedSet<Integer> gpus = this.params.gpuList;
        // A null list is treated like an empty one so the user sees the intended message
        // instead of a NullPointerException.
        ErrorContext.check(gpus != null && !gpus.isEmpty(),
                "Deactivate GPU execution or select at least one GPU.");
        if (MLTask.BackendType.KERAS.equals(backendType)) {
            float fraction = this.params.perGPUMemoryFraction;
            // Written as two comparisons so that NaN also fails the check.
            ErrorContext.check(fraction > 0.0f && fraction <= 1.0f,
                    "Memory allocation rate per GPU must be between 0 and 1.");
        }
    }

    /** User-facing GPU parameters, with their default values. */
    public static class GpuParams {
        /** Whether GPU execution is enabled (default {@code false}). */
        public boolean useGpu = false;
        /** Selected GPU indices, kept sorted (default: GPU 0 only). */
        public SortedSet<Integer> gpuList = new TreeSet<Integer>(Collections.singletonList(0));
        /** Memory fraction per GPU (default 0.5); validated only for the KERAS backend. */
        public float perGPUMemoryFraction = 0.5f;
        /** Default {@code false}; not read by this class. */
        public boolean gpuAllowGrowth = false;
    }

    /** Capabilities whose GPU use can be disabled individually via {@code disabledCapabilities}. */
    public enum GpuSupportingCapability {
        KERAS,
        GLUONTS,
        DEEP_HUB,
        DEEP_NN,
        SENTENCE_EMBEDDING,
        XGBOOST,
        TABICL
    }
}

