/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.endpoints.predict;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithEncryptedFields;
import com.dataiku.dip.connections.SQLConnectionProvider;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.datasets.sql.AbstractSQLDatasetHandler;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.controllers.NotFoundException;
import com.dataiku.dip.sql.SQLUtils;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.QueryAst;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.lambda.ServiceGenContext;
import com.dataiku.lambda.dataload.DataLoadManager;
import com.dataiku.lambda.dataload.LoadedDataMapping;
import com.dataiku.lambda.endpoints.predictcommon.PipelineMessage;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.lambda.model.serverconfig.DatasetResource;
import com.dataiku.lambda.model.serverconfig.PredictionFeaturesLeftJoinMapping;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Prediction pipeline step that enriches each row of a {@link PipelineMessage} table with values
 * looked up in a SQL table, issuing one {@code SELECT ... LIMIT 2} per row.
 *
 * <p>Lookup keys ({@code mapping.on}) and the columns to copy ({@code mapping.columnsMapping})
 * come from a {@link PredictionFeaturesLeftJoinMapping}. The lookup table is either a remapped
 * copy on the {@code ___dku_bundled} connection or the table of the referenced DSS SQL dataset.
 * A single connection is cached between {@link #process(PipelineMessage)} calls. It is checked
 * with {@code SELECT 1} before each reuse.
 *
 * <p>NOTE(review): the cached connection is not guarded by any lock. This presumably assumes one
 * instance is never used by concurrent callers — confirm.
 */
public class SQLLeftJoinEnrichStep {
    @Autowired
    private DataLoadManager loadManager;
    @Autowired
    private PasswordEncryptionService cryptoService;
    private final PredictionFeaturesLeftJoinMapping mapping;
    private final DatasetResource resource;
    private SQLConnectionProvider.SQLConnectionData connData;
    private String remappedTable;
    /** Connection reused across {@link #process} calls; {@code null} before first use and after {@link #destroy()}. */
    private SQLConnectionProvider.SQLConnectionWrapper currentConn;
    private SQLUtils.SQLTable tableDesc;
    private static final DKULogger logger = DKULogger.getLogger("dku.lambda.predict.join");

    /**
     * Returns a usable connection.
     *
     * <p>The cached connection is returned if it still answers {@code SELECT 1}. Otherwise the
     * dead connection is closed (best effort) and a new one is opened with auto-commit enabled,
     * then cached.
     */
    private SQLConnectionProvider.SQLConnectionWrapper acquireConnection() throws SQLException, DKUSecurityException, InterruptedException {
        if (this.currentConn != null) {
            if (this.isAlive(this.currentConn)) {
                return this.currentConn;
            }
            // Release whatever the broken connection still holds before replacing it.
            closeQuietly(this.currentConn);
            this.currentConn = null;
        }
        this.currentConn = SQLConnectionProvider.newConnection(this.connData, this.getUser(), null);
        SQLUtils.unsafeSetAutoCommit(this.currentConn, true);
        return this.currentConn;
    }

    /**
     * Probes {@code conn} with a dialect-specific {@code SELECT 1}.
     *
     * @return {@code true} if the probe succeeded; {@code false} (after logging a warning) on any failure
     */
    private boolean isAlive(SQLConnectionProvider.SQLConnectionWrapper conn) {
        logger.info("Verifying connection");
        try {
            ExpressionBuilder.ExpressionBuilderFactory ef = new ExpressionBuilder.ExpressionBuilderFactory();
            SelectQueryBuilder qb = new SelectQueryBuilder();
            qb.select(ef.cst((Object) 1));
            String select1 = qb.toSQL(this.connData.getDialect());
            try (Statement st = conn.createStatement()) {
                st.execute(select1);
            }
            logger.info("Connection OK");
            return true;
        }
        catch (Exception e) {
            logger.warn("Connection is not alive anymore, dropping it", e);
            return false;
        }
    }

    /** Closes {@code conn}. Failures are logged rather than propagated, since the connection is being discarded anyway. */
    private static void closeQuietly(SQLConnectionProvider.SQLConnectionWrapper conn) {
        try {
            conn.close();
        }
        catch (SQLException | RuntimeException e) {
            logger.warn("Failed to close SQL connection", e);
        }
    }

    public SQLLeftJoinEnrichStep(PredictionFeaturesLeftJoinMapping mapping, DatasetResource resource) {
        this.mapping = mapping;
        this.resource = resource;
    }

    /**
     * Autowires this step, then resolves the lookup table and the connection used to query it.
     *
     * <p>For {@code BUNDLED_TOCONNECTION} resources, the table name comes from the
     * {@link DataLoadManager} mapping for this service and generation. It is qualified with the
     * catalog and schema of the {@code ___dku_bundled} connection's naming rule. Other resources
     * use the catalog, schema and table from the DSS dataset's SQL config. Encrypted connection
     * fields are decrypted at the end.
     */
    public void init(ServiceGenContext ctx) throws Exception {
        SpringUtils.getInstance().autowire(this);
        AuthCtx user = this.getUser();
        if (this.resource.packagingType == DatasetResource.PackagingType.BUNDLED_TOCONNECTION) {
            LoadedDataMapping.DataGen dg = this.loadManager.getMapped(ctx.getServiceId(), this.resource.resourceId, ctx.getGenerationId());
            this.remappedTable = dg.tableLike;
            logger.info("Using remapped table " + this.remappedTable);
            this.connData = SQLConnectionProvider.getConnectionData_NT(user, null, "___dku_bundled");
            AbstractSQLConnection.SQLManagedDatasetNamingRule namingRule = this.connData.getConnection().getParams().namingRule;
            this.tableDesc = new SQLUtils.SQLTable(namingRule.catalog, namingRule.schemaName, this.remappedTable, true);
        } else {
            AbstractSQLDatasetHandler.AbstractSQLConfig config = (AbstractSQLDatasetHandler.AbstractSQLConfig) this.resource.dssDataset.getParams();
            this.remappedTable = config.table;
            this.connData = SQLConnectionProvider.getConnectionData_NT(user, this.resource.dssDataset);
            this.tableDesc = new SQLUtils.SQLTable(config.catalog, config.schema, this.remappedTable, true);
        }
        if (this.connData.getConnection() instanceof ConnectionWithEncryptedFields) {
            ((ConnectionWithEncryptedFields) this.connData.getConnection()).decryptFields(this.cryptoService);
        }
    }

    /** Auth context for opening connections: the dev-server test user on a dev lambda server, otherwise "none". */
    private AuthCtx getUser() {
        return ApplicationConfigurator.isDevLambdaServer() ? DSSAuthCtx.forUserTestsWithLambdaDevServer() : DSSAuthCtx.newNone();
    }

    /** Closes and forgets the cached connection, if any. Close failures are logged, not thrown. */
    public void destroy() {
        SQLConnectionProvider.SQLConnectionWrapper conn = this.currentConn;
        this.currentConn = null;
        if (conn != null) {
            closeQuietly(conn);
        }
    }

    /**
     * Builds the lookup query for one row.
     *
     * <p>The query selects every key of {@code mapping.columnsMapping} from the lookup table. It
     * filters by equality on each {@code mapping.on} column against the row's value, and is
     * limited to 2 rows so that multiple matches can be detected.
     *
     * @throws MissingLookupKeyException if the row holds {@code null} for any lookup key
     */
    private String generateOneQuery(MemTable mt, int rowIdx, MemRow row) throws MissingLookupKeyException {
        QueryAst.TableLike tl = SelectQueryBuilder.table(this.tableDesc, this.tableDesc.getTable());
        SelectQueryBuilder select = new SelectQueryBuilder();
        select.from(tl);
        // Two rows are enough to distinguish "exactly one match" from "several matches".
        select.limit(Long.valueOf(2L));
        ExpressionBuilder.ExpressionBuilderFactory ef = new ExpressionBuilder.ExpressionBuilderFactory();
        for (PredictionFeaturesLeftJoinMapping.KeyEltMatching keyElt : this.mapping.on) {
            ExpressionBuilder colCond = ef.col(keyElt.resourceLookupCol);
            colCond = this.applyTimezoneConversion(keyElt.resourceLookupCol, this.resource.dssDataset, colCond);
            // If no input column is configured, use the same name as the lookup column. This is
            // resolved locally so the shared mapping is not mutated on every request.
            String inputKey = StringUtils.isEmpty(keyElt.inputLookupKey) ? keyElt.resourceLookupCol : keyElt.inputLookupKey;
            String value = row.get((Column) mt.column(inputKey));
            logger.info("Lookup " + inputKey + " -> " + value);
            if (value == null) {
                throw new MissingLookupKeyException();
            }
            select.where(new ExpressionBuilder[]{colCond.eq((Object) value)});
        }
        for (String tableCol : this.mapping.columnsMapping.keySet()) {
            ExpressionBuilder col = ef.col(tableCol);
            col = this.applyTimezoneConversion(tableCol, this.resource.dssDataset, col);
            select.select(col, tableCol);
        }
        return select.toSQL(this.connData.getDialect());
    }

    /**
     * Wraps {@code col} in a conversion from the dataset's assumed DB timezone when needed.
     *
     * <p>The conversion applies when the schema column is a DATE read from a timestamp without
     * timezone ({@code timestampNoTzAsDate}) and the dataset is a non-managed SQL dataset.
     * Otherwise {@code col} is returned unchanged.
     *
     * @throws IllegalArgumentException (via {@code ErrorContext.iaef}) if the dataset has no schema
     */
    private ExpressionBuilder applyTimezoneConversion(String columnName, SerializedDataset source, ExpressionBuilder col) {
        Schema schema = source.getSchema();
        if (schema == null) {
            throw ErrorContext.iaef("Dataset '%s' has no schema", source.name);
        }
        SchemaColumn sc = schema.getColumn(columnName);
        if (sc != null && sc.getType() == Type.DATE && sc.timestampNoTzAsDate && !source.managed && source.getParams() instanceof AbstractSQLDatasetHandler.AbstractSQLConfig) {
            String assumedDBTzForUnknownTz = ((AbstractSQLDatasetHandler.AbstractSQLConfig) source.getParamsAs(AbstractSQLDatasetHandler.AbstractSQLConfig.class)).getAssumedDBTzForUnknownTz();
            col = col.convertFromTz(assumedDBTzForUnknownTz);
        }
        return col;
    }

    /**
     * Runs the lookup for one row and copies the first match into it.
     *
     * <p>Each {@code columnsMapping} key (a lookup-table column) is written to the mapped input
     * column, and its JDBC type is recorded in {@code message.columnsSQLTypes}. Three outcomes are
     * handled according to {@code notFoundBehavior}, {@code multimatchBehavior} and
     * {@code missingLookupKeyBehavior}: no match, a second match, and a null lookup key. A DROP_ROW
     * outcome is recorded in {@code message.prePredictIgnoreReasons} at {@code rowIdx}.
     *
     * @throws MissingLookupKeyException if a lookup key is null and the configured behavior is ERROR
     * @throws RuntimeException if there is no match or more than one match and the configured behavior is ERROR
     */
    private void enrichOneRow(PipelineMessage message, SQLConnectionProvider.SQLConnectionWrapper conn, int rowIdx, MemRow row) throws SQLException, MissingLookupKeyException, NotFoundException {
        try {
            String sqlQuery = this.generateOneQuery(message.table, rowIdx, row);
            logger.trace(() -> "Enrich with query: " + sqlQuery);
            try (Statement st = conn.createStatement()) {
                boolean ret = st.execute(sqlQuery);
                assert (ret);
                try (ResultSet rs = st.getResultSet()) {
                    ResultSetMetaData rsmd = rs.getMetaData();
                    if (!rs.next()) {
                        switch (this.mapping.notFoundBehavior) {
                            case DROP_ROW: {
                                message.prePredictIgnoreReasons.set(rowIdx, PredictionResponse.IgnoreReason.NO_LOOKUP_MATCH);
                                return;
                            }
                            case ERROR: {
                                throw new RuntimeException("Lookup value not found for " + sqlQuery);
                            }
                            case IGNORE: {
                                logger.info("No result ... ignoring");
                                return;
                            }
                            default: {
                                // Without this, the code below would read columns from an exhausted ResultSet.
                                throw new IllegalStateException("Unsupported notFoundBehavior: " + this.mapping.notFoundBehavior);
                            }
                        }
                    }
                    for (String tableCol : this.mapping.columnsMapping.keySet()) {
                        MemColumn mc = message.table.column((String) this.mapping.columnsMapping.get(tableCol));
                        String val = rs.getString(tableCol);
                        row.put((Column) mc, val);
                        int index = rs.findColumn(tableCol);
                        message.columnsSQLTypes.put(mc.getName(), rsmd.getColumnType(index));
                    }
                    if (rs.next()) {
                        switch (this.mapping.multimatchBehavior) {
                            case DROP_ROW: {
                                message.prePredictIgnoreReasons.set(rowIdx, PredictionResponse.IgnoreReason.MULTIPLE_LOOKUP_MATCHES);
                                return;
                            }
                            case ERROR: {
                                throw new RuntimeException("Multiple matches for " + sqlQuery);
                            }
                            case KEEP_FIRST: {
                                logger.info("Multiple results, keeping first ...");
                                return;
                            }
                        }
                    }
                }
            }
        }
        catch (MissingLookupKeyException e) {
            switch (this.mapping.missingLookupKeyBehavior) {
                case DROP_ROW: {
                    message.prePredictIgnoreReasons.set(rowIdx, PredictionResponse.IgnoreReason.MISSING_LOOKUP_KEY);
                    break;
                }
                case ERROR: {
                    throw e;
                }
            }
        }
    }

    /**
     * Enriches every row of {@code message.table} in place.
     *
     * <p>All rows are looked up on a single connection, which may be the cached one.
     */
    public void process(PipelineMessage message) throws Exception {
        SQLConnectionProvider.SQLConnectionWrapper conn = this.acquireConnection();
        int rowIdx = 0;
        for (MemRow row : message.table.rows) {
            this.enrichOneRow(message, conn, rowIdx, row);
            ++rowIdx;
        }
        logger.trace(() -> "MemTable after enrich:\n" + message.table.dumpToString());
    }

    /** Signals that a row has a null value for one of the configured lookup keys. */
    static class MissingLookupKeyException
    extends Exception {
        private static final long serialVersionUID = 1L;

        MissingLookupKeyException() {
        }
    }
}

