/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.utils.AutoCloseableLock;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.NamedLock;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.springframework.stereotype.Service;

@Service
public class PartitionedExtractService {
    private final ConcurrentMap<String, PartitionedModelExtract> fmiToExtract = new ConcurrentHashMap<String, PartitionedModelExtract>();
    private final ConcurrentMap<String, TrainedFromJobModelInfo> jobToModelInfo = new ConcurrentHashMap<String, TrainedFromJobModelInfo>();
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.prediction.part.cache");

    public void createOrRetrieve(FullModelId baseFmi, List<String> partitions) throws IOException {
        if (baseFmi.getPartFile().exists()) {
            this.read(baseFmi);
        } else {
            PartitionedModelExtract extract = new PartitionedModelExtract();
            extract.states.put(PartitionedModelExtract.PartitionState.PENDING, partitions.size());
            for (String partition : partitions) {
                FullModelId partitionFmi = baseFmi.getModelPartition(partition);
                extract.summaries.put(partition, new PartitionedModelExtract.PartitionedModelSummary(PartitionedModelExtract.PartitionState.PENDING, partitionFmi));
            }
            this.fmiToExtract.put(baseFmi.toString(), extract);
        }
    }

    public PartitionedModelExtract read(FullModelId baseFmi) throws IOException {
        assert (baseFmi.isPartitionedBaseModel());
        PartitionedModelExtract extract = (PartitionedModelExtract)this.fmiToExtract.get(baseFmi.toString());
        if (extract == null) {
            extract = baseFmi.getPartitionedModelExtract();
            this.fmiToExtract.put(baseFmi.toString(), extract);
        }
        return extract;
    }

    public void updateStates(final FullModelId baseFmi, final PartitionedModelExtract.PartitionState newState, final String ... partitionNames) {
        PartitionedExtractService.updateCacheAtomically(baseFmi, new Runnable(){

            @Override
            public void run() {
                PartitionedModelExtract extract = (PartitionedModelExtract)PartitionedExtractService.this.fmiToExtract.get(baseFmi.toString());
                if (extract != null && extract.summaries != null) {
                    for (String partitionName : partitionNames) {
                        PartitionedModelExtract.PartitionedModelSummary summary;
                        if (!extract.summaries.containsKey(partitionName) || (summary = extract.summaries.get(partitionName)) == null || summary.state == newState) continue;
                        extract.decreaseState(summary.state);
                        summary.state = newState;
                        extract.increaseState(newState);
                    }
                    PartitionedExtractService.this.dumpToFile(baseFmi, extract);
                }
            }
        });
    }

    public void updateAllStates(FullModelId baseFmi, PartitionedModelExtract.PartitionState newState) {
        PartitionedModelExtract extract = (PartitionedModelExtract)this.fmiToExtract.get(baseFmi.toString());
        this.updateStates(baseFmi, newState, extract.summaries.keySet().toArray(new String[0]));
    }

    public void updateSnippet(final FullModelId baseFmi, final PredictionModelSnippetData snippet) {
        PartitionedModelExtract tmpExtract;
        try {
            tmpExtract = this.read(baseFmi);
        }
        catch (IOException e) {
            logger.warn((Object)("Failed to read partitioned extract from model '" + String.valueOf(baseFmi) + "' and to update snippet data"), (Throwable)e);
            return;
        }
        final PartitionedModelExtract extract = tmpExtract;
        PartitionedExtractService.updateCacheAtomically(baseFmi, new Runnable(){

            @Override
            public void run() {
                PartitionedModelExtract.PartitionedModelSummary oldSummary = extract.summaries.get(snippet.partitionName);
                PartitionedModelExtract.PartitionedModelSummary newSummary = new PartitionedModelExtract.PartitionedModelSummary(snippet);
                if (oldSummary.state != newSummary.state) {
                    extract.decreaseState(oldSummary.state);
                    extract.increaseState(newSummary.state);
                }
                extract.summaries.put(snippet.partitionName, newSummary);
                PartitionedExtractService.this.dumpToFile(baseFmi, extract);
            }
        });
    }

    public void setPartitionedNonFinalStates(FullModelId baseFmi, PartitionedModelExtract.PartitionState stateIfNotDone) {
        PartitionedModelExtract extract = (PartitionedModelExtract)this.fmiToExtract.get(baseFmi.toString());
        ArrayList<String> notDonePartitions = new ArrayList<String>();
        if (extract != null) {
            for (Map.Entry<String, PartitionedModelExtract.PartitionedModelSummary> elem : extract.summaries.entrySet()) {
                if (elem.getValue().state == PartitionedModelExtract.PartitionState.DONE || elem.getValue().state == PartitionedModelExtract.PartitionState.FAILED) continue;
                notDonePartitions.add(elem.getKey());
            }
        }
        if (!notDonePartitions.isEmpty()) {
            this.updateStates(baseFmi, stateIfNotDone, notDonePartitions.toArray(new String[0]));
        }
    }

    int getNumPartitionsWithStates(FullModelId baseFmi, PartitionedModelExtract.PartitionState ... states) throws IOException {
        int count = 0;
        PartitionedModelExtract extract = this.read(baseFmi);
        if (extract != null) {
            for (PartitionedModelExtract.PartitionState state : states) {
                Integer stateCount = extract.states.get((Object)state);
                if (stateCount == null) continue;
                count += stateCount.intValue();
            }
        }
        return count;
    }

    public void dumpAndClear(final FullModelId baseFmi) {
        PartitionedExtractService.updateCacheAtomically(baseFmi, new Runnable(){

            @Override
            public void run() {
                PartitionedModelExtract extract = (PartitionedModelExtract)PartitionedExtractService.this.fmiToExtract.get(baseFmi.toString());
                PartitionedExtractService.this.dumpToFile(baseFmi, extract);
                PartitionedExtractService.this.fmiToExtract.remove(baseFmi.toString());
            }
        });
    }

    private void dumpToFile(FullModelId baseFmi, PartitionedModelExtract extract) {
        if (extract != null) {
            try {
                baseFmi.savePartFile(extract);
            }
            catch (IOException ex) {
                logger.warn((Object)"Failed to update partitioned train extract: ", (Throwable)ex);
            }
        }
    }

    private static void updateCacheAtomically(FullModelId baseFmi, Runnable runnable) {
        String lockName = "ml.partitioned.cache.fmi." + String.valueOf(baseFmi);
        try (AutoCloseableLock lock = NamedLock.acquire((String)lockName);){
            runnable.run();
        }
    }

    public static String getJobLockName(String jobId, SavedModel sm) {
        return String.format("ml.partitioned.cache.job.%s.%s", jobId, sm.getFullName());
    }

    private static String getJobToModelInfoKey(String jobId, SavedModel sm) {
        return jobId + "_" + sm.getFullId();
    }

    public boolean createIfNeeded(String jobId, SavedModel sm, FullModelId sourceModelFmi) throws IOException {
        String lockName = PartitionedExtractService.getJobLockName(jobId, sm);
        assert (NamedLock.isHeldByCurrentThread((String)lockName));
        TrainedFromJobModelInfo trainedFromJobModelInfo = (TrainedFromJobModelInfo)this.jobToModelInfo.get(PartitionedExtractService.getJobToModelInfoKey(jobId, sm));
        if (trainedFromJobModelInfo == null) {
            PartitionedModelExtract extract = sourceModelFmi != null ? sourceModelFmi.getPartitionedModelExtract() : new PartitionedModelExtract();
            trainedFromJobModelInfo = TrainedFromJobModelInfo.buildInitialModelInfo(extract);
            this.jobToModelInfo.put(PartitionedExtractService.getJobToModelInfoKey(jobId, sm), trainedFromJobModelInfo);
            return true;
        }
        return false;
    }

    public TrainedFromJobModelInfo get(String jobId, SavedModel sm) {
        return (TrainedFromJobModelInfo)this.jobToModelInfo.get(PartitionedExtractService.getJobToModelInfoKey(jobId, sm));
    }

    public void set(String jobId, SavedModel sm, PartitionedModelExtract extract, String jobIdVersion, int jobIdUpdate) {
        this.jobToModelInfo.put(PartitionedExtractService.getJobToModelInfoKey(jobId, sm), new TrainedFromJobModelInfo(extract, jobIdVersion, jobIdUpdate));
    }

    public static class TrainedFromJobModelInfo {
        public final PartitionedModelExtract extract;
        public final int jobIdUpdate;
        public final String jobIdVersion;

        TrainedFromJobModelInfo(PartitionedModelExtract extract, String jobIdVersion, int jobIdUpdate) {
            this.extract = extract;
            this.jobIdUpdate = jobIdUpdate;
            this.jobIdVersion = jobIdVersion;
        }

        static TrainedFromJobModelInfo buildInitialModelInfo(PartitionedModelExtract extract) {
            return new TrainedFromJobModelInfo(extract, null, 0);
        }
    }
}

