/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ScoringRecipeUtils;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.model.CompatibilityWithReason;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.partitioning.PartitionFactory;
import com.dataiku.dip.scoring.exports.SQLPrediction;
import com.dataiku.dip.scoring.exports.SQLPreprocessing;
import com.dataiku.dip.scoring.exports.ScoringExporter;
import com.dataiku.dip.server.services.SingleWriteTransactionTransactionService;
import com.dataiku.dip.shaker.sql.FinalSchemaCaster;
import com.dataiku.dip.shaker.sql.SQLQueryWithSchema;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.SQLUtils;
import com.dataiku.dip.sql.queries.CombinedSelectQueryBuilder;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.ExpressionUtils;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.utils.JSON;
import com.dataiku.scoring.builders.Build;
import com.dataiku.scoring.pipelines.BinaryProbabilisticPipeline;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import com.dataiku.scoring.pipelines.MulticlassProbabilisticPipeline;
import com.dataiku.scoring.pipelines.NonProbabilisticClassificationPipeline;
import com.dataiku.scoring.pipelines.RegressionPipeline;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * {@link ScoringExporter.ScoringWriter} that exports a prediction model as a SQL scoring query.
 *
 * <p>An instance wraps a callable producing the SQL text and writes it to an output stream as
 * {@value #CONTENT_TYPE}. The static {@code build*} helpers assemble the SQL query itself from the
 * model resources (preprocessing parameters and pipeline metadata).
 */
public class SQLScoring
implements ScoringExporter.ScoringWriter {
    public static final String CONTENT_TYPE = "text/plain";
    private final FullModelId fullModelId;
    String filename;
    SingleWriteTransactionTransactionService.DetransactionalizedCallable<String> queryCallable;

    /**
     * @param fullModelId   identifier of the model whose SQL compatibility is reported
     * @param filename      name returned by {@link #makeFileName()}
     * @param queryCallable callable invoked by {@link #writeTo(OutputStream)} to obtain the SQL text
     */
    public SQLScoring(FullModelId fullModelId, String filename, SingleWriteTransactionTransactionService.DetransactionalizedCallable<String> queryCallable) {
        this.fullModelId = fullModelId;
        this.filename = filename;
        this.queryCallable = queryCallable;
    }

    /**
     * Returns the {@code sqlCompatibility} field of the model details read for this model.
     *
     * @throws IOException if the model details cannot be read
     */
    @Override
    public CompatibilityWithReason getCompatibility() throws IOException {
        return PredictionResultsReader.makeDetails(this.fullModelId).sqlCompatibility;
    }

    /** Returns {@link #CONTENT_TYPE}. */
    @Override
    public String getContentType() {
        return CONTENT_TYPE;
    }

    /**
     * Writes the SQL text produced by the query callable to {@code os}, then closes the stream.
     *
     * @param os destination stream; it is closed once the query has been written (or the write failed)
     * @throws Exception if the callable or the write fails
     */
    @Override
    public void writeTo(OutputStream os) throws Exception {
        // Compute the query before touching the stream: if the callable fails, os is left as-is.
        String query = this.queryCallable.call_NT();
        // try-with-resources guarantees the writer is flushed and closed even when write() throws.
        // NOTE(review): the writer uses the platform default charset - confirm this is intended.
        try (OutputStreamWriter writer = new OutputStreamWriter(os)) {
            writer.write(query);
        }
    }

    /** Returns the file name given at construction time. */
    @Override
    public String makeFileName() {
        return this.filename;
    }

    /**
     * Builds the SQL scoring query for a model.
     *
     * <p>Reads {@code rpreprocessing_params.json} from the model folder, collects the features whose
     * role is {@code INPUT}, loads the pipeline metadata and dispatches to the partitioned or
     * non-partitioned query builder depending on {@code fmi.isPartitionedBaseModel()}.
     *
     * @param forcedClassifierThreshold threshold forwarded to the binary pipeline builder in the
     *                                  non-partitioned case; not used for partitioned models
     * @return the SQL text for the given dialect
     * @throws IOException if the pipeline metadata has no type, or if reading model resources fails
     */
    public static String buildQuery(FullModelId fmi, SavedModel sm, Dataset inputDataset, List<Partition> inputPartitions, Dataset outputDataset, List<String> columnsToKeep, SQLUtils.SQLTable table, SQLDialect dialect, Double forcedClassifierThreshold) throws IOException, CodedException {
        File resources = fmi.getModelFolder().getAbsoluteFile();
        File preprocessingParamsFile = new File(resources, "rpreprocessing_params.json");
        ResolvedClassicalPredictionPreprocessingParams rppp = (ResolvedClassicalPredictionPreprocessingParams)JSON.parseFile(preprocessingParamsFile, ResolvedClassicalPredictionPreprocessingParams.class);

        // Only features with the INPUT role are fed to the preprocessing query.
        List<String> featureColumns = new ArrayList<>();
        for (Map.Entry<?, ?> feature : rppp.per_feature.entrySet()) {
            if (((FeaturePreprocessingParams)feature.getValue()).role == FeaturePreprocessingParams.Role.INPUT) {
                featureColumns.add((String)feature.getKey());
            }
        }

        Build.DssPipelineMeta meta = Build.pipelineMeta(resources.toURI().toURL());
        if (meta.type == null) {
            throw new IOException("Failed to parse a valid type from the dss_pipeline_meta.json");
        }
        if (fmi.isPartitionedBaseModel()) {
            return SQLScoring.getPartitionedDispatchQuery(fmi, sm, inputDataset, outputDataset, columnsToKeep, table, dialect, rppp, featureColumns);
        }
        return SQLScoring.getNonPartitionedDispatchQuery(fmi, inputDataset, inputPartitions, outputDataset, columnsToKeep, table, dialect, resources, rppp, featureColumns, meta, forcedClassifierThreshold);
    }

    /**
     * Builds the scoring query for a non-partitioned model: selects everything from {@code table},
     * optionally restricted to {@code inputPartitions}, scores it and casts the result to the
     * output dataset schema.
     */
    private static String getNonPartitionedDispatchQuery(FullModelId fmi, Dataset inputDataset, List<Partition> inputPartitions, Dataset outputDataset, List<String> columnsToKeep, SQLUtils.SQLTable table, SQLDialect dialect, File resources, ResolvedPredictionPreprocessingParams rppp, List<String> featureColumns, Build.DssPipelineMeta meta, Double forcedClassifierThreshold) throws IOException, CodedException {
        URL resourcesURL = resources.toURI().toURL();

        SQLQueryWithSchema preparedQWS = new SQLQueryWithSchema();
        preparedQWS.setDialect(dialect);
        preparedQWS.select("*");
        preparedQWS.from(table, "data");

        // The partition filter is only added when partitions were requested and the input dataset is partitioned.
        boolean hasRequestedPartitions = inputPartitions != null && !inputPartitions.isEmpty();
        if (hasRequestedPartitions && inputDataset.getPartitioningSchema() != null && inputDataset.getPartitioningSchema().isPartitioned()) {
            preparedQWS.where(ExpressionUtils.getPartitionFilterClause(inputDataset.getPartitioningSchema(), inputDataset, inputPartitions, dialect));
        }
        preparedQWS.initWithSchema(inputDataset.getSchema());

        SQLQueryWithSchema resultQWS = SQLScoring.getResultQWS(meta, preparedQWS, inputDataset, outputDataset.getSchema(), columnsToKeep, dialect, resourcesURL, rppp, featureColumns, fmi, forcedClassifierThreshold);
        return new FinalSchemaCaster().getCasted(resultQWS, outputDataset.getSchema()).applyInsertIntoCasts(outputDataset).toSQL(dialect);
    }

    /**
     * Wraps {@code preprocessedQWS} in the scoring query matching {@code meta.type} and returns it as
     * a query carrying a schema.
     *
     * <p>The resulting schema is a copy of {@code outputSchema} in which every column also present in
     * the preprocessed query's current schema is updated from that schema.
     *
     * @param meta pipeline metadata; {@code meta.type} must be non-null (the switch below would
     *             otherwise throw a {@link NullPointerException})
     */
    private static SQLQueryWithSchema getResultQWS(Build.DssPipelineMeta meta, SQLQueryWithSchema preprocessedQWS, Dataset inputDataset, Schema outputSchema, List<String> columnsToKeep, SQLDialect dialect, URL resourcesURL, ResolvedPredictionPreprocessingParams rppp, List<String> featureColumns, FullModelId fmi, Double forcedClassifierThreshold) throws IOException {
        // Stays an empty builder when meta.type matches none of the cases below.
        SelectQueryBuilder scoredSQB = new SelectQueryBuilder();
        String preprocessedTemporaryTableName = dialect.getSafeRandomTemporaryTableName("preprocessed");
        ClassicalPredictionModelDetails details = PredictionResultsReader.makeDetails(fmi);

        Double missingValue = null;
        Double unrecordedValue = 0.0;
        switch (details.modeling.algorithm) {
            case XGBOOST_CLASSIFICATION:
            case XGBOOST_REGRESSION: {
                // For XGBoost both values are the configured "missing" value when imputation is
                // enabled and the model input is not sparse, and null otherwise.
                unrecordedValue = missingValue = details.modeling.xgboost_grid.impute_missing && !details.iperf.modelInputIsSparse ? Double.valueOf(details.modeling.xgboost_grid.missing) : null;
            }
        }

        switch (meta.type) {
            case REGRESSION: {
                scoredSQB = SQLScoring.buildRegressionQuery(resourcesURL, meta, rppp, preprocessedQWS.subQuery(preprocessedTemporaryTableName), featureColumns, columnsToKeep, dialect, inputDataset, unrecordedValue, missingValue);
                break;
            }
            case CLASSIFICATION_ONLY: {
                scoredSQB = SQLScoring.buildClassificationQuery(resourcesURL, meta, rppp, preprocessedQWS.subQuery(preprocessedTemporaryTableName), featureColumns, columnsToKeep, dialect, inputDataset, unrecordedValue, missingValue);
                break;
            }
            case BINARY_PROBABILISTIC: {
                scoredSQB = SQLScoring.buildBinaryQuery(resourcesURL, meta, rppp, preprocessedQWS.subQuery(preprocessedTemporaryTableName), featureColumns, columnsToKeep, dialect, inputDataset, forcedClassifierThreshold, unrecordedValue, missingValue);
                break;
            }
            case MULTICLASS_PROBABILISTIC: {
                scoredSQB = SQLScoring.buildMulticlassQuery(resourcesURL, meta, rppp, preprocessedQWS.subQuery(preprocessedTemporaryTableName), featureColumns, columnsToKeep, dialect, inputDataset, unrecordedValue, missingValue);
            }
        }

        if (ScoringRecipeUtils.ModelMetadataUtils.schemaIncludesModelMetadata(outputSchema) != null) {
            ScoringRecipeUtils.ModelMetadataUtils.sqlAddModelMetadata(scoredSQB, fmi);
        }

        // Columns shared by the output schema and the preprocessed query take the preprocessed definition.
        Schema resultSchema = outputSchema.getCopy();
        for (SchemaColumn preprocessedColumn : preprocessedQWS.getCurrentSchema().getColumns()) {
            if (outputSchema.hasColumn(preprocessedColumn.getName())) {
                resultSchema.updateOrAddColumn(preprocessedColumn);
            }
        }

        SQLQueryWithSchema resultQWS = new SQLQueryWithSchema();
        resultQWS.setDialect(dialect);
        resultQWS.from(scoredSQB, "result");
        resultQWS.initWithSchema(resultSchema);
        return resultQWS;
    }

    /**
     * Builds the scoring query for a partitioned model as a union of one scored sub-query per
     * partition model, plus one sub-query for the rows excluded by every partition filter.
     *
     * <p>In the latter sub-query, input columns are selected as-is and output columns absent from
     * the input are selected as typed null values.
     *
     * @throws IOException if a partition's pipeline metadata has no type, or if reading model
     *                     resources fails
     */
    private static String getPartitionedDispatchQuery(FullModelId fmi, SavedModel sm, Dataset inputDataset, Dataset outputDataset, List<String> columnsToKeep, SQLUtils.SQLTable table, SQLDialect dialect, ResolvedPredictionPreprocessingParams rppp, List<String> featureColumns) throws IOException, CodedException {
        CombinedSelectQueryBuilder union = CombinedSelectQueryBuilder.newUnion(false);
        Set<Map.Entry<String, URL>> partitions = fmi.getPartitionModelUrls().entrySet();

        SQLQueryWithSchema unmatchedPartitionsQuery = new SQLQueryWithSchema();
        unmatchedPartitionsQuery.setDialect(dialect);
        String temporaryUnmatchedEntriesTableName = dialect.getSafeRandomTemporaryTableName("data");

        List<String> inputColumns = inputDataset.getSchema().columns.stream().map(SchemaColumn::getName).collect(Collectors.toList());
        List<SchemaColumn> missingOutputColumns = outputDataset.getSchema().columns.stream().filter(sc -> !inputColumns.contains(sc.getName())).collect(Collectors.toList());
        for (String inputColumn : inputColumns) {
            unmatchedPartitionsQuery.select(inputColumn);
        }
        for (SchemaColumn missingOutputColumn : missingOutputColumns) {
            unmatchedPartitionsQuery.select(new ExpressionBuilder.ExpressionBuilderFactory().nullValue(missingOutputColumn.getType(), 1), missingOutputColumn.getName());
        }
        unmatchedPartitionsQuery.from(table, temporaryUnmatchedEntriesTableName);

        for (Map.Entry<String, URL> partitionToURL : partitions) {
            String temporaryTableName = dialect.getSafeRandomTemporaryTableName("data");
            String partitionName = partitionToURL.getKey();
            // The entry already carries the partition model URL; no need to rebuild and re-query the map.
            URL partitionModelURL = partitionToURL.getValue();
            Partition partition = PartitionFactory.fromIdentifier(sm.getPartitioningSchema(), partitionName);

            Build.DssPipelineMeta partMeta = Build.pipelineMeta(partitionModelURL);
            if (partMeta.type == null) {
                // Same validation as in buildQuery; avoids a NullPointerException in getResultQWS.
                throw new IOException("Failed to parse a valid type from the dss_pipeline_meta.json for partition " + partitionName);
            }

            SQLQueryWithSchema preprocessedQWS = new SQLQueryWithSchema();
            preprocessedQWS.setDialect(dialect);
            preprocessedQWS.select("*");
            preprocessedQWS.from(table, temporaryTableName);
            ExpressionBuilder condition = ExpressionUtils.getPartitionFilterClause(sm.getPartitioningSchema(), inputDataset, partition, dialect);
            preprocessedQWS.where(condition);
            // Each partition adds its negated filter to the "unmatched" query.
            unmatchedPartitionsQuery.where(condition.not());
            preprocessedQWS.initWithSchema(inputDataset.getSchema());

            SQLQueryWithSchema resultQWS = SQLScoring.getResultQWS(partMeta, preprocessedQWS, inputDataset, outputDataset.getSchema(), columnsToKeep, dialect, partitionModelURL, rppp, featureColumns, fmi.getModelPartition(partitionName), null);
            unmatchedPartitionsQuery.initWithSchema(resultQWS.getCurrentSchema());

            String temporaryCastedTableName = dialect.getSafeRandomTemporaryTableName("casted");
            union.add(new FinalSchemaCaster().getCasted(resultQWS, outputDataset.getSchema()).applyInsertIntoCasts(outputDataset).subQuery(temporaryCastedTableName));
        }

        String temporaryCastedTableName = dialect.getSafeRandomTemporaryTableName("casted");
        union.add(new FinalSchemaCaster().getCasted(unmatchedPartitionsQuery, outputDataset.getSchema()).applyInsertIntoCasts(outputDataset).subQuery(temporaryCastedTableName));
        return union.toSQL(dialect);
    }

    /** Builds the preprocessing + prediction query for a regression pipeline. */
    static SelectQueryBuilder buildRegressionQuery(URL resourcesURL, Build.DssPipelineMeta meta, ResolvedPredictionPreprocessingParams rppp, SelectQueryBuilder data, List<String> featureColumns, List<String> columnsToKeep, SQLDialect dialect, Dataset inputDataset, Double unrecordedValue, Double missingValue) throws IOException {
        RegressionPipeline pipe = Build.regressionPipeline(resourcesURL, meta);
        SQLPreprocessing.SelectWithSchema prep = SQLPreprocessing.preprocessingQuery(dialect, inputDataset, pipe.getPreprocessing(), rppp, data, featureColumns, columnsToKeep, unrecordedValue);
        return SQLPrediction.regression(pipe, prep, missingValue);
    }

    /** Builds the preprocessing + prediction query for a non-probabilistic classification pipeline. */
    static SelectQueryBuilder buildClassificationQuery(URL resourcesURL, Build.DssPipelineMeta meta, ResolvedPredictionPreprocessingParams rppp, SelectQueryBuilder data, List<String> featureColumns, List<String> columnsToKeep, SQLDialect dialect, Dataset inputDataset, Double unrecordedValue, Double missingValue) throws IOException {
        NonProbabilisticClassificationPipeline pipe = Build.nonProbabilisticClassificationPipeline(resourcesURL, meta);
        SQLPreprocessing.SelectWithSchema prep = SQLPreprocessing.preprocessingQuery(dialect, inputDataset, pipe.getPreprocessing(), rppp, data, featureColumns, columnsToKeep, unrecordedValue);
        return SQLPrediction.classification((ClassificationPipeline)pipe, dialect, prep, missingValue);
    }

    /**
     * Builds the preprocessing + prediction query for a binary probabilistic pipeline.
     *
     * @param forcedClassifierThreshold threshold passed through to the pipeline builder; may be null
     *                                  (the partitioned dispatch passes null)
     */
    static SelectQueryBuilder buildBinaryQuery(URL resourcesURL, Build.DssPipelineMeta meta, ResolvedPredictionPreprocessingParams rppp, SelectQueryBuilder data, List<String> featureColumns, List<String> columnsToKeep, SQLDialect dialect, Dataset inputDataset, Double forcedClassifierThreshold, Double unrecordedValue, Double missingValue) throws IOException {
        BinaryProbabilisticPipeline pipe = Build.binaryProbabilisticPipeline(resourcesURL, meta, forcedClassifierThreshold);
        SQLPreprocessing.SelectWithSchema prep = SQLPreprocessing.preprocessingQuery(dialect, inputDataset, pipe.getPreprocessing(), rppp, data, featureColumns, columnsToKeep, unrecordedValue);
        return SQLPrediction.classification((ClassificationPipeline)pipe, dialect, prep, missingValue);
    }

    /**
     * Builds the preprocessing + prediction query for a multiclass probabilistic pipeline, then adds
     * a constant probability column for every target value that has none in the scored query.
     */
    static SelectQueryBuilder buildMulticlassQuery(URL resourcesURL, Build.DssPipelineMeta meta, ResolvedPredictionPreprocessingParams rppp, SelectQueryBuilder data, List<String> featureColumns, List<String> columnsToKeep, SQLDialect dialect, Dataset inputDataset, Double unrecordedValue, Double missingValue) throws IOException {
        MulticlassProbabilisticPipeline pipe = Build.multiclassProbabilisticPipeline(resourcesURL, meta);
        SQLPreprocessing.SelectWithSchema prep = SQLPreprocessing.preprocessingQuery(dialect, inputDataset, pipe.getPreprocessing(), rppp, data, featureColumns, columnsToKeep, unrecordedValue);
        SelectQueryBuilder scoredSQB = SQLPrediction.classification((ClassificationPipeline)pipe, dialect, prep, missingValue);
        SQLScoring.addPotentiallyMissingProbabilityColumns(scoredSQB, rppp.target_remapping);
        return scoredSQB;
    }

    /**
     * For each target value, selects a constant {@code 0.0} as {@code proba_<sourceValue>} unless the
     * query already had a selected column with that name when this method was called.
     */
    private static void addPotentiallyMissingProbabilityColumns(SelectQueryBuilder scoredSQB, List<PredictionPreprocessingParams.MappingValue> targetRemapping) {
        List<String> selectedNames = scoredSQB.getSelectedNames();
        for (PredictionPreprocessingParams.MappingValue targetValue : targetRemapping) {
            String expectedProbaColumn = "proba_" + targetValue.sourceValue;
            if (selectedNames.contains(expectedProbaColumn)) {
                continue;
            }
            ExpressionBuilder.ExpressionBuilderFactory factory = new ExpressionBuilder.ExpressionBuilderFactory();
            scoredSQB.select(factory.cst(0.0), expectedProbaColumn);
        }
    }
}

