/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.eda.stats;

import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.eda.compute.computations.AvailableResult;
import com.dataiku.dip.eda.compute.computations.Computation;
import com.dataiku.dip.eda.compute.computations.ComputationResult;
import com.dataiku.dip.eda.compute.computations.common.MultiComputation;
import com.dataiku.dip.eda.compute.computations.univariate.PairwiseTTest;
import com.dataiku.dip.eda.compute.stats.AlternativeHypothesis;
import com.dataiku.dip.eda.compute.stats.PValueAdjustmentMethod;
import com.dataiku.dip.eda.compute.stats.VarianceAssumption;
import com.dataiku.dip.eda.worksheets.cards.fragments.BoxPlotFragment;
import com.dataiku.dip.recipes.eda.StatsTestRecipePayloadParams;
import com.dataiku.dip.recipes.eda.StatsTestRecipePayloadParamsWithAlternativeHypothesis;
import com.dataiku.dip.recipes.eda.stats.AbstractPairwiseTestStat;
import com.dataiku.dip.recipes.eda.stats.PairwiseTTestRecipePayloadParams;
import com.dataiku.dip.recipes.eda.stats.PopulationBoxPlot;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.commons.lang3.function.TriConsumer;

/**
 * Recipe statistic that runs a pairwise t-test between the populations of a test column.
 *
 * <p>One {@link PairwiseTTest} computation is issued per alternative hypothesis requested in the
 * payload (see {@link #getTestComputation}), and {@link #extractTestRows} turns each per-alternative
 * result into output rows.
 */
public class PairwiseTTestStat extends AbstractPairwiseTestStat {
    /** Value written to the {@code test} output column to identify this statistic. */
    private static final String PAIRWISE_T_TEST = "T_TEST_PAIRWISE";

    /**
     * Variance assumption forwarded to every {@link PairwiseTTest} and reported in the
     * {@code variance_assumption} column. Defaults to {@link VarianceAssumption#EQUAL}.
     *
     * <p>NOTE(review): public and non-final, presumably so it can be populated after the no-arg
     * constructor (e.g. by deserialization) — confirm before tightening its visibility.
     */
    public VarianceAssumption varianceAssumption = VarianceAssumption.EQUAL;

    /**
     * Creates a configured pairwise t-test statistic.
     *
     * <p>All parameters except {@code varianceAssumption} are forwarded unchanged to the
     * {@link AbstractPairwiseTestStat} constructor.
     *
     * @param testColumn column whose values are tested
     * @param groupByColumn column used by the superclass to split rows into populations
     * @param groupValues optional explicit list of group values (may be {@code null})
     * @param maxPopulations population limit handled by the superclass
     * @param referenceGroup optional reference group; when non-null, {@link PairwiseTTest} is told
     *     a reference group is in use
     * @param varianceAssumption variance assumption for the t-tests
     * @param adjustmentMethod p-value adjustment method for the multiple comparisons
     */
    public PairwiseTTestStat(String testColumn, String groupByColumn, @Nullable List<String> groupValues, int maxPopulations, @Nullable String referenceGroup, VarianceAssumption varianceAssumption, PValueAdjustmentMethod adjustmentMethod) {
        super(testColumn, groupByColumn, groupValues, maxPopulations, referenceGroup, adjustmentMethod);
        this.varianceAssumption = varianceAssumption;
    }

    /** No-arg constructor; leaves {@link #varianceAssumption} at its default value. */
    public PairwiseTTestStat() {
    }

    /**
     * Writes the superclass test parameters, then the test identifier and the variance assumption.
     */
    @Override
    protected void fillTestParams(Row row, ColumnFactory cf) {
        super.fillTestParams(row, cf);
        row.put(cf.column("test"), PAIRWISE_T_TEST);
        row.put(cf.column("variance_assumption"), this.varianceAssumption.name());
    }

    /**
     * Builds one {@link PairwiseTTest} per requested alternative hypothesis, wrapped in a single
     * {@link MultiComputation}.
     *
     * @param payloadParams recipe payload; must be convertible to {@link PairwiseTTestRecipePayloadParams}
     * @return a multi-computation holding the per-alternative t-tests, in payload order
     */
    @Override
    protected MultiComputation getTestComputation(StatsTestRecipePayloadParams<?> payloadParams) {
        List<Computation> tests = new ArrayList<>();
        boolean hasReferenceGroup = this.referenceGroup != null;
        for (AlternativeHypothesis alternative : payloadParams.as(PairwiseTTestRecipePayloadParams.class).getAlternatives()) {
            tests.add(new PairwiseTTest(this.testColumn, this.getPopulationGrouping(), hasReferenceGroup, this.varianceAssumption, this.adjustmentMethod, alternative, payloadParams.confidenceLevel));
        }
        return new MultiComputation(tests);
    }

    /**
     * Converts the per-alternative t-test results into output rows.
     *
     * <p>For each alternative, warnings and errors from the test result are merged with those
     * raised while building box plots.
     * <ul>
     *   <li>If fewer than two populations are available, a single row is emitted that carries
     *       populations 0 and 1 and no test values.</li>
     *   <li>Otherwise, one row is emitted per {@code (i, j, flatIndex)} triple produced by
     *       {@code iteratePairs}. Test values are filled only when the result is available.</li>
     * </ul>
     *
     * @return the output rows, grouped by alternative in the order returned by the payload
     */
    @Override
    protected List<Row> extractTestRows(ComputationResult testResult, StatsTestRecipePayloadParams<?> payloadParams, List<PopulationBoxPlot> populationBoxPlots, @Nullable BoxPlotFragment mergedBoxPlot, Collection<String> boxPlotWarnings, Collection<String> boxPlotErrors, ColumnFactory cf, RowFactory rf) {
        List<StatsTestRecipePayloadParamsWithAlternativeHypothesis.AlternativeComputationResult> alternativeResults = payloadParams.as(PairwiseTTestRecipePayloadParams.class).buildResultsForAlternatives(testResult);
        double significanceLevel = payloadParams.significanceLevel();
        List<Row> rows = new ArrayList<>();
        for (StatsTestRecipePayloadParamsWithAlternativeHypothesis.AlternativeComputationResult alternativeResult : alternativeResults) {
            AlternativeHypothesis alternative = alternativeResult.alternative;
            ComputationResult tr = alternativeResult.result;
            // Test diagnostics first, then box-plot diagnostics; LinkedHashSet keeps that order
            // while dropping duplicates.
            LinkedHashSet<String> warnings = new LinkedHashSet<>(tr.collectWarnings());
            warnings.addAll(boxPlotWarnings);
            LinkedHashSet<String> errors = new LinkedHashSet<>(tr.collectErrors());
            errors.addAll(boxPlotErrors);

            int n = this.getNumberOfPopulations(tr, populationBoxPlots);
            if (n < 2) {
                // Not enough populations to form a pair: emit a single row without test values.
                Row row = this.rowBase(significanceLevel, warnings, errors, cf, rf);
                this.fillPopulations(populationBoxPlots, mergedBoxPlot, 0, 1, row, cf);
                putAlternativeColumns(row, alternative, cf);
                rows.add(row);
                continue;
            }

            // Availability does not change between pairs, so resolve it once per alternative.
            // null means "no test values to fill".
            AvailableResult available = tr.isAvailable() ? tr.asAvailable() : null;
            TriConsumer<Integer, Integer, Integer> addPairRow = (i, j, flatIndex) -> {
                Row row = this.rowBase(significanceLevel, warnings, errors, cf, rf);
                this.fillPopulations(populationBoxPlots, mergedBoxPlot, i, j, row, cf);
                putAlternativeColumns(row, alternative, cf);
                if (available != null) {
                    this.fillTestResult(available, significanceLevel, cf, row, i, j, flatIndex);
                }
                rows.add(row);
            };
            this.iteratePairs(n, addPairRow);
        }
        return rows;
    }

    /**
     * Writes the alternative hypothesis name and its human-readable explanation into {@code row}.
     *
     * @throws IllegalArgumentException if {@code alternative} has no known explanation
     */
    private static void putAlternativeColumns(Row row, AlternativeHypothesis alternative, ColumnFactory cf) {
        row.put(cf.column("alternative_hypothesis"), alternative.name());
        row.put(cf.column("alternative_hypothesis_explanation"), getAlternativeExplanation(alternative));
    }

    /**
     * Returns a human-readable description of {@code alternative} in terms of population means.
     *
     * @throws IllegalArgumentException if {@code alternative} is not one of the handled values
     */
    private static String getAlternativeExplanation(AlternativeHypothesis alternative) {
        switch (alternative) {
            case LOWER:
                return "population 2 mean < population 1 mean";
            case GREATER:
                return "population 2 mean > population 1 mean";
            case TWO_SIDED:
                return "population 2 mean \u2260 population 1 mean";
            default:
                throw new IllegalArgumentException("Unknown alternative type: " + alternative);
        }
    }

    /**
     * Fills the t-test values for one pair into {@code row}.
     *
     * <p>The values written are the t statistic, degrees of freedom, raw and adjusted p-values, and
     * a conclusion derived from the adjusted p-value.
     *
     * @param result available result, expected to be a {@link PairwiseTTest.PairwiseTTestResult}
     * @param significanceLevel threshold passed to {@code conclusion}
     * @param cf column factory
     * @param row row to fill
     * @param i index of the first population (unused here)
     * @param j index of the second population (unused here)
     * @param rowIndex index into the per-pair result arrays
     * @throws IllegalArgumentException if {@code rowIndex} is outside the result arrays
     */
    @Override
    protected void fillTestResult(AvailableResult result, double significanceLevel, ColumnFactory cf, Row row, int i, int j, int rowIndex) {
        PairwiseTTest.PairwiseTTestResult tr = result.as(PairwiseTTest.PairwiseTTestResult.class);
        // All four per-pair arrays are read at rowIndex, so validate against the shortest one.
        // This turns a would-be ArrayIndexOutOfBoundsException into a descriptive error.
        int resultCount = Math.min(
                Math.min(tr.statistics.length, tr.dofs.length),
                Math.min(tr.pvalues.length, tr.adjustedPvalues.length));
        if (rowIndex < 0 || rowIndex >= resultCount) {
            throw new IllegalArgumentException(String.format("Requested test index %d is out of range: %d test results are available", rowIndex, resultCount));
        }
        double pValue = tr.pvalues[rowIndex];
        double adjustedPValue = tr.adjustedPvalues[rowIndex];
        // The conclusion is based on the multiple-comparison-adjusted p-value, not the raw one.
        String conclusion = PairwiseTTestStat.conclusion(adjustedPValue, significanceLevel);
        row.put(cf.column("t_statistic"), tr.statistics[rowIndex]);
        row.put(cf.column("degrees_of_freedom"), tr.dofs[rowIndex]);
        row.put(cf.column("p_value"), pValue);
        row.put(cf.column("adjusted_p_value"), adjustedPValue);
        row.put(cf.column("conclusion"), conclusion);
    }
}

