/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.sql;

import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.hadoop.HadoopFlavorUtils;
import com.dataiku.dip.partitioning.PartitioningScheme;
import com.dataiku.dip.sql.DatePart;
import com.dataiku.dip.sql.DateRounding;
import com.dataiku.dip.sql.GenericSQLDialect;
import com.dataiku.dip.sql.HiveLikeSQLDialect;
import com.dataiku.dip.sql.SQLAggregateAbility;
import com.dataiku.dip.sql.SQLAggregateType;
import com.dataiku.dip.sql.SQLCapability;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.QueryAst;
import com.dataiku.dip.sql.queries.QueryUtils;
import com.dataiku.dip.utils.DKUDateUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang.StringUtils;

/**
 * SQL dialect that generates Hive-flavoured SQL.
 *
 * <p>On top of what {@link HiveLikeSQLDialect} provides, this class registers Hive-specific
 * operators (aggregations, safe division, date parsing/formatting), translates date format
 * parts into Hive pattern letters, and builds the statements used to write through a
 * materialized temporary table.
 */
public class HiveSQLDialect
extends HiveLikeSQLDialect {
    /** Characters that force a literal chunk of a date format to be quoted. Compiled once and shared. */
    private static final Pattern HIVE_FORMAT_TEXT_TO_ESCAPE = Pattern.compile("[a-zA-Z0-9\"'\\[\\]\\{\\}#]");

    /** Timezone used whenever the caller does not provide a usable one. */
    private static final String UTC_TIMEZONE_ID = "UTC";

    /**
     * Hive version the dialect targets; compared against 13 and 102 by the union capability checks.
     * NOTE(review): presumably encoded as major * 100 + minor (12 = 0.12, 102 = 1.2) - confirm.
     */
    int version = 12;

    /**
     * @return {@code true} when {@link #version} is at least 13
     */
    public boolean supportsTopLevelUnion() {
        return this.version >= 13;
    }

    /**
     * @return {@code true} when {@link #version} is at least 102
     */
    public boolean supportsUnion() {
        return this.version >= 102;
    }

    /**
     * @return the maximum identifier length, fixed at 128
     */
    @Override
    public int getIdentifiersMaxLength() {
        return 128;
    }

    /**
     * Registers the operators of the parent dialect, removes {@code REGEXP_SUBSTR}, and then adds
     * or replaces the Hive-specific ones. Registration order is the same as it has always been.
     */
    @Override
    public void initOperators() {
        super.initOperators();
        this.removeOperator(QueryUtils.OperatorType.REGEXP_SUBSTR);
        this.addGenericFunction(QueryUtils.OperatorType.STDDEV_SAMP, "STDDEV_SAMP", QueryUtils.Arity.UNARY);

        // percentile_approx(column, percentile)
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.PERCENTILE_APPROX_AGG, QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                double percentile = this.getParamAs(args[1], Double.class);
                return "percentile_approx(" + column + "," + percentile + ")";
            }
        });

        // Division guarded against a zero denominator (see getDivisionClause).
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.DIV, "/", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.TIMES.priority, false){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String numerator = this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.TIMES.priority);
                // The denominator is also compared with "= 0", hence the EQ priority for bracketing.
                String denominator = this.toSQLWithBracketsIfNeeded(args[1], GenericSQLDialect.SQLPriority.EQ.priority);
                return HiveSQLDialect.this.getDivisionClause(numerator, denominator);
            }
        });

        // Same as DIV, but the numerator is first cast to DOUBLE.
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.FLOAT_DIV, "/", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.TIMES.priority, false){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                QueryAst.Expr castedArg = new ExpressionBuilder.ExpressionBuilderFactory().expr((QueryAst.Expr)args[0]).cast((Object[])new Object[]{Type.DOUBLE}).expr;
                String numerator = this.toSQLWithBracketsIfNeeded(castedArg, GenericSQLDialect.SQLPriority.TIMES.priority);
                String denominator = this.toSQLWithBracketsIfNeeded(args[1], GenericSQLDialect.SQLPriority.EQ.priority);
                return HiveSQLDialect.this.getDivisionClause(numerator, denominator);
            }
        });

        // concat_ws(separator, collect_list|collect_set(CAST(column AS <string type>)))
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.AGG_CONCAT, "concat_ws", QueryUtils.Arity.TERNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateMinNumberOfParameters(args, 1);
                String column = this.toSQLNoBrackets(args[0]);
                // An absent or null separator means "concatenate with nothing in between".
                String separator = "''";
                boolean distinct = false;
                if (args.length > 1) {
                    QueryAst.ConstExpr separatorExpr = (QueryAst.ConstExpr)args[1];
                    if (separatorExpr != null && separatorExpr.value != null) {
                        separator = this.toSQLNoBrackets(separatorExpr);
                    }
                }
                if (args.length > 2) {
                    QueryAst.ConstExpr distinctExpr = (QueryAst.ConstExpr)args[2];
                    // Null-safe: a missing flag value means "not distinct" instead of an unboxing NPE.
                    distinct = distinctExpr != null && distinctExpr.value != null && (Boolean)distinctExpr.value;
                }
                String collector = distinct ? "collect_set" : "collect_list";
                return "concat_ws(" + separator + ", " + collector + "(" + HiveSQLDialect.this.hiveCastAsString(column) + "))";
            }
        });

        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.COLLECT_STRING_LIST, "collect_list", QueryUtils.Arity.UNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                return "collect_list(" + HiveSQLDialect.this.hiveCastAsString(this.toSQLNoBrackets(args[0])) + ")";
            }
        });

        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.COLLECT_STRING_SET, "collect_set", QueryUtils.Arity.UNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                return "collect_set(" + HiveSQLDialect.this.hiveCastAsString(this.toSQLNoBrackets(args[0])) + ")";
            }
        });

        // concat_ws(separator, array); a null separator is rendered as the empty string.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.ARRAY_TO_STRING, "concat_ws", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String array = this.toSQLNoBrackets(args[0]);
                QueryAst.ConstExpr separatorExpr = (QueryAst.ConstExpr)args[1];
                String separatorValue = separatorExpr == null ? null : (String)separatorExpr.value;
                if (separatorValue == null) {
                    return "concat_ws('', " + array + ")";
                }
                return "concat_ws(" + this.toSQLNoBrackets(separatorExpr) + ", " + array + ")";
            }
        });

        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.FIRST_VALUE, "FIRST_VALUE", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                return HiveSQLDialect.hiveValueFunctionCall("FIRST_VALUE", column, (QueryAst.ConstExpr)args[1]);
            }
        });

        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.LAST_VALUE, "LAST_VALUE", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                return HiveSQLDialect.hiveValueFunctionCall("LAST_VALUE", column, (QueryAst.ConstExpr)args[1]);
            }
        });

        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.REGEX_LIKE, QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String input = this.toSQLNoBrackets(args[0]);
                String regex = this.toSQLNoBrackets(args[1]);
                return "REGEXP_LIKE(" + input + ", " + regex + ")";
            }
        });

        // PARSE(input, type, format, [arg 3 unused here], [timezone]); temporal target types only.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.PARSE, QueryUtils.Arity.NARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateMinNumberOfParameters(args, 2);
                String input = this.toSQLNoBrackets(args[0]);
                if (args[0].outputType != null && args[0].outputType.dssType != Type.STRING) {
                    input = "CAST(" + input + " AS STRING)";
                }
                Type requestedType = this.getParamAs(args[1], Type.class);
                if (!requestedType.isTemporal()) {
                    throw new NotImplementedException("parse as not date");
                }
                this.validateMinNumberOfParameters(args, 3);
                String jodaFormat = this.getParamAs(args[2], String.class);
                String requestedTimezone = args.length > 4 ? this.getParamAs(args[4], String.class) : UTC_TIMEZONE_ID;
                // Same defaulting as FORMAT: a null/blank timezone must not end up as 'null' or '' in the SQL.
                String timezoneId = StringUtils.defaultIfBlank(requestedTimezone, UTC_TIMEZONE_ID);

                if (requestedType == Type.DATEONLY) {
                    return "CAST(TO_UTC_TIMESTAMP(" + HiveSQLDialect.this.hiveUnixTimestampCall(input, jodaFormat) + " * 1000, 'UTC') AS DATE)";
                }
                if (requestedType == Type.DATETIMENOTZ) {
                    return "TO_UTC_TIMESTAMP(" + HiveSQLDialect.this.hiveUnixTimestampCall(input, jodaFormat) + " * 1000, 'UTC')";
                }
                // The ISO-8601 check must come before the format translation, which is skipped for ISO strings.
                if (DKUDateUtils.isISO8601FormatString((String)jodaFormat)) {
                    String converted = "CAST(TRANSLATE(TRANSLATE(" + input + ", 'T', ' '), 'Z', '') AS TIMESTAMP)";
                    if (StringUtils.equals(timezoneId, UTC_TIMEZONE_ID)) {
                        return converted;
                    }
                    return "TO_UTC_TIMESTAMP(" + converted + ", '" + timezoneId + "')";
                }
                return "TO_UTC_TIMESTAMP(" + HiveSQLDialect.this.hiveUnixTimestampCall(input, jodaFormat) + " * 1000, '" + timezoneId + "')";
            }
        });

        // TRY_PARSE generates exactly what PARSE generates for temporal types.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.TRY_PARSE, QueryUtils.Arity.NARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateMinNumberOfParameters(args, 2);
                Type requestedType = this.getParamAs(args[1], Type.class);
                if (!requestedType.isTemporal()) {
                    throw new NotImplementedException("parse as not date");
                }
                return HiveSQLDialect.this.getOperator(QueryUtils.OperatorType.PARSE).apply(args);
            }
        });

        // FORMAT(input, type, format, [arg 3 unused here], [timezone]); temporal source types only.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.FORMAT, QueryUtils.Arity.NARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateMinNumberOfParameters(args, 2);
                String input = this.toSQLNoBrackets(args[0]);
                Type requestedType = this.getParamAs(args[1], Type.class);
                if (!requestedType.isTemporal()) {
                    throw new NotImplementedException("format as not date");
                }
                this.validateMinNumberOfParameters(args, 3);
                String jodaFormat = this.getParamAs(args[2], String.class);
                String requestedTimezone = args.length > 4 ? this.getParamAs(args[4], String.class) : UTC_TIMEZONE_ID;
                String timezoneId = StringUtils.defaultIfBlank(requestedTimezone, UTC_TIMEZONE_ID);
                // Timestamps are shifted into the requested zone before being formatted.
                if (requestedType.isTimestamp() && !StringUtils.equals(UTC_TIMEZONE_ID, timezoneId)) {
                    input = "from_utc_timestamp(" + input + ", '" + timezoneId + "')";
                }
                String sqlFormat = HiveSQLDialect.this.toDateFormat(jodaFormat, false);
                return "DATE_FORMAT(" + input + ",'" + sqlFormat + "')";
            }
        });
    }

    /** Wraps an SQL expression in a cast to the dialect's string type. */
    private String hiveCastAsString(String column) {
        return "CAST(" + column + " AS " + this.typeNameForCastAsString() + ")";
    }

    /** Builds {@code UNIX_TIMESTAMP(input,'<format>')} with the format translated for parsing. */
    private String hiveUnixTimestampCall(String input, String jodaFormat) {
        return "UNIX_TIMESTAMP(" + input + ",'" + this.toDateFormat(jodaFormat, true) + "')";
    }

    /** Builds a FIRST_VALUE/LAST_VALUE call, appending {@code , true} when nulls must be ignored. */
    private static String hiveValueFunctionCall(String functionName, String column, QueryAst.ConstExpr ignoreNullsExpr) {
        boolean ignoreNulls = ignoreNullsExpr != null && ignoreNullsExpr.value != null && (Boolean)ignoreNullsExpr.value;
        return functionName + "(" + column + (ignoreNulls ? ", true)" : ")");
    }

    /** Truncates a date expression by formatting it with a fixed pattern and casting back to timestamp. */
    private static String hiveTruncateViaFormat(String inputDateExpression, String pattern) {
        return "cast(date_format(" + inputDateExpression + ", '" + pattern + "') as timestamp)";
    }

    /** Returns the quoted names of the target dataset's columns, skipping the given names. */
    private List<String> hiveQuotedColumnsExcluding(SQLDialect.UpsertSpec spec, Set<String> excludedNames) {
        List<String> quoted = new ArrayList<>();
        for (SchemaColumn sc : spec.targetDataset.getSchema().getColumns()) {
            if (excludedNames.contains(sc.getName())) continue;
            quoted.add(this.quoteIdentifier(sc.getName()));
        }
        return quoted;
    }

    /**
     * Builds a division that yields NULL instead of dividing by zero.
     *
     * @param numerator SQL text of the numerator
     * @param denominator SQL text of the denominator; it is emitted twice in the result
     * @return {@code numerator / (CASE WHEN denominator = 0 THEN NULL ELSE denominator END)}
     */
    @Override
    public String getDivisionClause(String numerator, String denominator) {
        return numerator + " / (CASE WHEN " + denominator + " = 0 THEN NULL ELSE " + denominator + " END)";
    }

    @Override
    public boolean supportsInDatabaseCharts() {
        return true;
    }

    @Override
    public boolean supportsNullsOrdering() {
        return true;
    }

    /**
     * Adds the CONCAT, CONCAT_DISTINCT, FIRST_NOTNULL and LAST_NOTNULL abilities to the map
     * returned by the parent dialect.
     *
     * @return the parent's map, with these four entries put into it
     */
    @Override
    public Map<SQLAggregateType, SQLAggregateAbility> getAggregationAbilities() {
        Map<SQLAggregateType, SQLAggregateAbility> abilities = super.getAggregationAbilities();
        abilities.put(SQLAggregateType.CONCAT, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.CONCAT_DISTINCT, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.FIRST_NOTNULL, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.LAST_NOTNULL, new SQLAggregateAbility(true, true, true, true));
        return abilities;
    }

    /**
     * Builds the day-of-week expression; the pattern letter used depends on the Hadoop flavor
     * ({@code 'e'} from CDH 7.1.9 on, {@code 'u'} otherwise).
     */
    @Override
    protected String dayOfWeekExpression(String expr) {
        if (HadoopFlavorUtils.isCDH7AtLeast(7, 1, 9)) {
            return "(cast(date_format(" + expr + ", 'e') as int) + 5) % 7";
        }
        return "cast(date_format(" + expr + ", 'u') as int) - 1";
    }

    /**
     * Handles the epoch-based date parts with {@code unix_timestamp}; every other part is
     * delegated to the parent dialect.
     */
    @Override
    public String datePartExpression(String expr, DatePart part) {
        switch (part) {
            case SECOND_FROM_EPOCH:
                return "unix_timestamp(" + expr + ")";
            case MILLIS_FROM_EPOCH:
                // unix_timestamp is in seconds, so this is a multiple of 1000.
                return "(unix_timestamp(" + expr + ") * 1000)";
            default:
                return super.datePartExpression(expr, part);
        }
    }

    /**
     * Truncates a date expression to the given rounding by formatting it with a pattern whose
     * lower-order fields are constants. Roundings not listed here go to the parent dialect.
     */
    @Override
    public String dateTrunc(String inputDateExpression, DateRounding rounding) {
        String pattern;
        switch (rounding) {
            case DAY:
                pattern = "yyyy-MM-dd 00:00:00";
                break;
            case HOUR:
                pattern = "yyyy-MM-dd HH:00:00";
                break;
            case MINUTE:
                pattern = "yyyy-MM-dd HH:mm:00";
                break;
            case SECOND:
                pattern = "yyyy-MM-dd HH:mm:ss";
                break;
            case MONTH:
                pattern = "yyyy-MM-01";
                break;
            case YEAR:
                pattern = "yyyy-01-01";
                break;
            default:
                return super.dateTrunc(inputDateExpression, rounding);
        }
        return HiveSQLDialect.hiveTruncateViaFormat(inputDateExpression, pattern);
    }

    /**
     * Translates one part of a parsed date format into the pattern text used in generated SQL.
     *
     * @param part the format part to translate
     * @param forParsing whether the pattern is meant for parsing; timezone parts are rejected then
     * @param hasIsoDatePart not used by this implementation
     * @return the pattern fragment; unknown part types are returned as their raw text
     * @throws IllegalArgumentException for a timezone part when {@code forParsing} is true
     */
    @Override
    public String toDateFormatPart(DKUDateUtils.FormatPatternPart part, boolean forParsing, boolean hasIsoDatePart) {
        switch (part.type) {
            case ERA:
                return "GG";
            case YEAR:
            case YEAROFERA:
                return part.shortened ? "yy" : "yyyy";
            case WEEK:
                return "ww";
            case WEEKYEAR:
                return part.shortened ? "YY" : "YYYY";
            case MONTH:
                if (part.numeric) {
                    return part.length == 1 ? "M" : "MM";
                }
                return part.shortened ? "MMM" : "MMMM";
            case DAY:
                return part.length == 1 ? "d" : "dd";
            case DAYOFYEAR:
                return "DD";
            case DAYOFWEEK:
                if (part.numeric) {
                    // NOTE(review): in java.text.SimpleDateFormat 'F' is day-of-week-in-month, while
                    // dayOfWeekExpression relies on 'u'/'e' - confirm this letter is the intended one.
                    return "F";
                }
                return part.shortened ? "EEE" : "EEEE";
            case HOUR:
                return part.length == 1 ? "H" : "HH";
            case HALFDAY:
                return "a";
            case HOUROFHALFDAY:
                return part.length == 1 ? "K" : "KK";
            case CLOCKHOUR:
                return part.length == 1 ? "k" : "kk";
            case CLOCKHOUROFHALFDAY:
                return part.length == 1 ? "h" : "hh";
            case MINUTE:
                return part.length == 1 ? "m" : "mm";
            case SECOND:
                return part.length == 1 ? "s" : "ss";
            case MILLISECOND:
                return "SSS";
            case TIMEZONE:
                if (forParsing) {
                    throw new IllegalArgumentException("Parsing with timezone format is not supported");
                }
                return part.numeric ? "X" : "z";
            case TEXT:
                if (HIVE_FORMAT_TEXT_TO_ESCAPE.matcher(part.text).find()) {
                    // Quotes are backslash-escaped because the pattern is embedded in a quoted SQL literal.
                    return "\\'" + part.text.replace("'", "\\'\\'") + "\\'";
                }
                return part.text;
            default:
                return part.text;
        }
    }

    /**
     * Reports whether a single format part is usable; timezone and millisecond parts are
     * refused when parsing, everything else is accepted.
     */
    @Override
    public SQLCapability canFormatDatePart(DKUDateUtils.FormatPatternPart part, boolean forParsing) {
        if (forParsing) {
            if (part.type == DKUDateUtils.FormatPatternPartType.TIMEZONE) {
                return SQLCapability.nok("Cannot parse timezones");
            }
            if (part.type == DKUDateUtils.FormatPatternPartType.MILLISECOND) {
                return SQLCapability.nok("Cannot parse milliseconds");
            }
        }
        return SQLCapability.ok();
    }

    /**
     * Accepts ISO-8601 format strings outright; any other format is checked by the parent dialect.
     */
    @Override
    public SQLCapability canFormatDate(String jodaFormat, boolean forParsing) {
        if (DKUDateUtils.isISO8601FormatString((String)jodaFormat)) {
            return SQLCapability.ok();
        }
        return super.canFormatDate(jodaFormat, forParsing);
    }

    @Override
    public String getId() {
        return "Hive";
    }

    /**
     * Returns the writer that produces the statements for loading a target table through a
     * temporary table: create the temp table, copy it over the target, drop it.
     */
    @Override
    public SQLDialect.MaterializedTemporaryTableWriter getMaterializedTemporaryTableWriter() {
        return new SQLDialect.MaterializedTemporaryTableWriter(){

            /**
             * Creates an empty temp table with the target's columns, leaving out every
             * partitioning dimension.
             */
            @Override
            public String generateCreateTemp(SQLDialect.UpsertSpec spec) {
                Set<String> partitioningColumns = new HashSet<>();
                PartitioningScheme scheme = spec.targetDataset.getPartitioningSchema();
                if (scheme != null) {
                    partitioningColumns.addAll(scheme.getDimensionNames());
                }
                List<String> columns = HiveSQLDialect.this.hiveQuotedColumnsExcluding(spec, partitioningColumns);
                return String.format("CREATE TABLE %s AS SELECT %s FROM %s LIMIT 0", HiveSQLDialect.this.getQuotedTableFullName(spec.temp), String.join(", ", columns), HiveSQLDialect.this.getQuotedTableFullName(spec.target));
            }

            /** No truncate statement is generated; this always returns {@code null}. */
            @Override
            public String generateTruncateReal(SQLDialect.UpsertSpec spec, PartitioningScheme partitioningScheme) {
                return null;
            }

            /**
             * Overwrites the target (or one partition of it, addressed with {@code $DKU_DST_<dimension>}
             * placeholders) with the content of the temp table.
             */
            @Override
            public String generateCopy(SQLDialect.UpsertSpec spec) {
                PartitioningScheme scheme = spec.targetDataset.getPartitioningSchema();
                if (scheme == null || !scheme.isPartitioned()) {
                    return String.format("INSERT OVERWRITE TABLE %s SELECT * FROM %s", HiveSQLDialect.this.getQuotedTableFullName(spec.target), HiveSQLDialect.this.getQuotedTableFullName(spec.temp));
                }
                List<String> partitionClauses = new ArrayList<>();
                for (String dimension : scheme.getDimensionNames()) {
                    partitionClauses.add(String.format("%s='$DKU_DST_%s'", HiveSQLDialect.this.quoteIdentifier(dimension), dimension));
                }
                // Looked up once as a set instead of scanning the dimension names for every column.
                Set<String> dimensionNames = new HashSet<>(scheme.getDimensionNames());
                List<String> columns = HiveSQLDialect.this.hiveQuotedColumnsExcluding(spec, dimensionNames);
                return String.format("INSERT OVERWRITE TABLE %s PARTITION (%s) SELECT %s FROM %s", HiveSQLDialect.this.getQuotedTableFullName(spec.target), String.join(", ", partitionClauses), String.join(", ", columns), HiveSQLDialect.this.getQuotedTableFullName(spec.temp));
            }

            @Override
            public String generateDropTemp(SQLDialect.UpsertSpec spec) {
                return String.format("DROP TABLE IF EXISTS %s", HiveSQLDialect.this.getQuotedTableFullName(spec.temp));
            }
        };
    }
}

