/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.calcite.plan.rule;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.rel.LogicalDedup;
import org.opensearch.sql.calcite.plan.rule.ImmutablePPLSimplifyDedupRule;
import org.opensearch.sql.calcite.plan.rule.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.PlanUtils;

@Value.Enclosing
public class PPLSimplifyDedupRule
extends RelRule<Config> {
    public static final PPLSimplifyDedupRule DEDUP_SIMPLIFY_RULE = Config.DEFAULT.toRule();

    protected PPLSimplifyDedupRule(Config config) {
        super((RelRule.Config)config);
    }

    public void onMatch(RelOptRuleCall call) {
        LogicalProject finalProject = (LogicalProject)call.rel(0);
        LogicalFilter numOfDedupFilter = (LogicalFilter)call.rel(1);
        LogicalProject projectWithWindow = (LogicalProject)call.rel(2);
        LogicalFilter bucketNonNullFilter = (LogicalFilter)call.rel(3);
        this.apply(call, finalProject, numOfDedupFilter, projectWithWindow, bucketNonNullFilter);
    }

    protected void apply(RelOptRuleCall call, LogicalProject finalProject, LogicalFilter numOfDedupFilter, LogicalProject projectWithWindow, LogicalFilter bucketNonNullFilter) {
        List<RexWindow> windows = PlanUtils.getRexWindowFromProject(projectWithWindow);
        if (windows.size() != 1) {
            return;
        }
        ImmutableList dedupColumns = windows.get((int)0).partitionKeys;
        if (dedupColumns.stream().filter(rex -> rex.isA(SqlKind.INPUT_REF)).anyMatch(rex -> rex.getType().getSqlTypeName() == SqlTypeName.MAP || rex.getType().getSqlTypeName() == SqlTypeName.ARRAY)) {
            return;
        }
        RexNode condition = numOfDedupFilter.getCondition();
        if (!(condition instanceof RexCall)) {
            return;
        }
        List operands = ((RexCall)condition).getOperands();
        if (operands.isEmpty()) {
            return;
        }
        RexNode lastOperand = (RexNode)operands.get(operands.size() - 1);
        if (!(lastOperand instanceof RexLiteral)) {
            return;
        }
        RexLiteral literal = (RexLiteral)lastOperand;
        Integer dedupNumber = (Integer)literal.getValueAs(Integer.class);
        if (dedupNumber == null) {
            return;
        }
        RelBuilder relBuilder = call.builder();
        relBuilder.push(bucketNonNullFilter.getInput());
        List targetProjections = projectWithWindow.getNamedProjects().stream().filter(p -> !((RexNode)p.getKey()).isA(SqlKind.ROW_NUMBER)).collect(Collectors.toList());
        relBuilder.project((Iterable)targetProjections.stream().map(Pair::getKey).collect(Collectors.toList()), (Iterable)targetProjections.stream().map(Pair::getValue).collect(Collectors.toList()));
        LogicalDedup dedup = LogicalDedup.create(relBuilder.build(), (List<RexNode>)dedupColumns, dedupNumber, false, false);
        relBuilder.push((RelNode)dedup);
        relBuilder.project((Iterable)finalProject.getProjects(), (Iterable)finalProject.getRowType().getFieldNames());
        call.transformTo(relBuilder.build());
    }

    @Value.Immutable
    public static interface Config
    extends OpenSearchRuleConfig {
        public static final Config DEFAULT = ImmutablePPLSimplifyDedupRule.Config.builder().build().withOperandSupplier(b0 -> b0.operand(LogicalProject.class).predicate(Predicate.not(PlanUtils::containsRowNumberDedup)).oneInput(b1 -> b1.operand(LogicalFilter.class).predicate(Config::validDedupNumberChecker).oneInput(b2 -> b2.operand(LogicalProject.class).predicate(PlanUtils::containsRowNumberDedup).oneInput(b3 -> b3.operand(LogicalFilter.class).predicate(PlanUtils::mayBeFilterFromBucketNonNull).anyInputs()))));

        default public PPLSimplifyDedupRule toRule() {
            return new PPLSimplifyDedupRule(this);
        }

        private static boolean validDedupNumberChecker(LogicalFilter filter) {
            return filter.getCondition().isA(SqlKind.LESS_THAN_OR_EQUAL) && PlanUtils.containsRowNumberDedup((RelNode)filter);
        }

        private static boolean isNullOrLessThan(RexNode node) {
            if (node.isA(SqlKind.LESS_THAN_OR_EQUAL)) {
                return true;
            }
            if (!node.isA(SqlKind.OR)) {
                return false;
            }
            boolean hasLessThan = false;
            for (RexNode operand : ((RexCall)node).getOperands()) {
                if (operand.isA(SqlKind.LESS_THAN_OR_EQUAL)) {
                    if (hasLessThan) {
                        return false;
                    }
                    hasLessThan = true;
                    continue;
                }
                if (operand.isA(SqlKind.IS_NULL)) continue;
                return false;
            }
            return hasLessThan;
        }
    }
}

