/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.physical.stream;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalChangelogNormalize;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.rules.physical.stream.ImmutablePushCalcPastChangelogNormalizeRule;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
import org.immutables.value.Value;

@Internal
@Value.Enclosing
public class PushCalcPastChangelogNormalizeRule
extends RelRule<Config> {
    public static final RelOptRule INSTANCE = new PushCalcPastChangelogNormalizeRule(Config.DEFAULT);

    public PushCalcPastChangelogNormalizeRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        StreamPhysicalCalc calc = (StreamPhysicalCalc)call.rel(0);
        StreamPhysicalChangelogNormalize changelogNormalize = (StreamPhysicalChangelogNormalize)call.rel(1);
        Set<Integer> primaryKeyIndices = IntStream.of(changelogNormalize.uniqueKeys()).boxed().collect(Collectors.toSet());
        ArrayList<RexNode> primaryKeyPredicates = new ArrayList<RexNode>();
        ArrayList<RexNode> otherPredicates = new ArrayList<RexNode>();
        RexProgram program = calc.getProgram();
        if (program.getCondition() != null) {
            RexNode condition = RexUtil.toCnf(call.builder().getRexBuilder(), program.expandLocalRef(program.getCondition()));
            this.partitionPrimaryKeyPredicates(RelOptUtil.conjunctions(condition), primaryKeyIndices, primaryKeyPredicates, otherPredicates);
        }
        if (changelogNormalize.filterCondition() != null) {
            otherPredicates.add(changelogNormalize.filterCondition());
        }
        int[] usedInputFields = this.extractUsedInputFields(calc, changelogNormalize, primaryKeyIndices);
        StreamPhysicalChangelogNormalize newChangelogNormalize = this.pushCalcThroughChangelogNormalize(call, primaryKeyPredicates, otherPredicates, usedInputFields);
        this.transformWithRemainingPredicates(call, newChangelogNormalize, usedInputFields);
    }

    private int[] extractUsedInputFields(StreamPhysicalCalc calc, StreamPhysicalChangelogNormalize changelogNormalize, Set<Integer> primaryKeyIndices) {
        RexProgram program = calc.getProgram();
        List<RexNode> projectsAndCondition = program.getProjectList().stream().map(program::expandLocalRef).collect(Collectors.toList());
        if (program.getCondition() != null) {
            projectsAndCondition.add(program.expandLocalRef(program.getCondition()));
        }
        if (changelogNormalize.filterCondition() != null) {
            projectsAndCondition.add(changelogNormalize.filterCondition());
        }
        Set projectedFields = Arrays.stream(RexNodeExtractor.extractRefInputFields(projectsAndCondition)).boxed().collect(Collectors.toSet());
        projectedFields.addAll(primaryKeyIndices);
        return projectedFields.stream().sorted().mapToInt(Integer::intValue).toArray();
    }

    private void partitionPrimaryKeyPredicates(List<RexNode> predicates, Set<Integer> primaryKeyIndices, List<RexNode> primaryKeyPredicates, List<RexNode> remainingPredicates) {
        for (RexNode predicate : predicates) {
            int[] inputRefs = RexNodeExtractor.extractRefInputFields(Collections.singletonList(predicate));
            if (Arrays.stream(inputRefs).allMatch(primaryKeyIndices::contains)) {
                primaryKeyPredicates.add(predicate);
                continue;
            }
            remainingPredicates.add(predicate);
        }
    }

    private StreamPhysicalChangelogNormalize pushCalcThroughChangelogNormalize(RelOptRuleCall call, List<RexNode> primaryKeyPredicates, List<RexNode> otherPredicates, int[] usedInputFields) {
        StreamPhysicalChangelogNormalize changelogNormalize = (StreamPhysicalChangelogNormalize)call.rel(1);
        StreamPhysicalExchange exchange = (StreamPhysicalExchange)call.rel(2);
        Set primaryKeyIndices = IntStream.of(changelogNormalize.uniqueKeys()).boxed().collect(Collectors.toSet());
        if (primaryKeyPredicates.isEmpty() && usedInputFields.length == changelogNormalize.getRowType().getFieldCount()) {
            if (otherPredicates.isEmpty()) {
                return changelogNormalize;
            }
            RexNode condition = call.builder().and(otherPredicates);
            return (StreamPhysicalChangelogNormalize)changelogNormalize.copy(changelogNormalize.getTraitSet(), exchange, changelogNormalize.uniqueKeys(), condition.isAlwaysTrue() ? null : condition);
        }
        StreamPhysicalCalc pushedCalc = this.projectUsedFieldsWithConditions(call.builder(), exchange.getInput(), primaryKeyPredicates, usedInputFields);
        Map<Integer, Integer> inputRefMapping = this.buildFieldsMapping(usedInputFields);
        List newPrimaryKeyIndices = primaryKeyIndices.stream().map(inputRefMapping::get).collect(Collectors.toList());
        List shiftedPredicates = otherPredicates.stream().map(p -> this.adjustInputRef((RexNode)p, inputRefMapping)).collect(Collectors.toList());
        RexNode condition = call.builder().and(shiftedPredicates);
        FlinkRelDistribution newDistribution = FlinkRelDistribution.hash(newPrimaryKeyIndices, true);
        RelTraitSet newTraitSet = exchange.getTraitSet().replace(newDistribution);
        StreamPhysicalExchange newExchange = exchange.copy(newTraitSet, pushedCalc, newDistribution);
        return (StreamPhysicalChangelogNormalize)changelogNormalize.copy(changelogNormalize.getTraitSet(), newExchange, newPrimaryKeyIndices.stream().mapToInt(Integer::intValue).toArray(), condition.isAlwaysTrue() ? null : condition);
    }

    private StreamPhysicalCalc projectUsedFieldsWithConditions(RelBuilder relBuilder, RelNode input, List<RexNode> conditions, int[] usedFields) {
        RelDataType inputRowType = input.getRowType();
        List<String> inputFieldNames = inputRowType.getFieldNames();
        RexProgramBuilder programBuilder = new RexProgramBuilder(inputRowType, relBuilder.getRexBuilder());
        for (int fieldIndex : usedFields) {
            programBuilder.addProject(programBuilder.makeInputRef(fieldIndex), inputFieldNames.get(fieldIndex));
        }
        RexNode condition = relBuilder.and(conditions);
        if (!condition.isAlwaysTrue()) {
            programBuilder.addCondition(condition);
        }
        RexProgram newProgram = programBuilder.getProgram();
        return new StreamPhysicalCalc(input.getCluster(), input.getTraitSet(), input, newProgram, newProgram.getOutputRowType());
    }

    private void transformWithRemainingPredicates(RelOptRuleCall call, StreamPhysicalChangelogNormalize changelogNormalize, int[] usedInputFields) {
        StreamPhysicalCalc calc = (StreamPhysicalCalc)call.rel(0);
        RelBuilder relBuilder = call.builder();
        RexProgramBuilder programBuilder = new RexProgramBuilder(changelogNormalize.getRowType(), relBuilder.getRexBuilder());
        Map<Integer, Integer> inputRefMapping = this.buildFieldsMapping(usedInputFields);
        for (Pair<RexLocalRef, String> ref : calc.getProgram().getNamedProjects()) {
            RexNode shiftedProject = this.adjustInputRef(calc.getProgram().expandLocalRef((RexLocalRef)ref.left), inputRefMapping);
            programBuilder.addProject(shiftedProject, (String)ref.right);
        }
        RexProgram newProgram = programBuilder.getProgram();
        if (newProgram.isTrivial()) {
            call.transformTo(changelogNormalize);
        } else {
            StreamPhysicalCalc newProjectedCalc = new StreamPhysicalCalc(changelogNormalize.getCluster(), changelogNormalize.getTraitSet(), (RelNode)changelogNormalize, newProgram, newProgram.getOutputRowType());
            call.transformTo(newProjectedCalc);
        }
    }

    private RexNode adjustInputRef(RexNode expr, final Map<Integer, Integer> mapping) {
        return expr.accept(new RexShuttle(){

            @Override
            public RexNode visitInputRef(RexInputRef inputRef) {
                Integer newIndex = (Integer)mapping.get(inputRef.getIndex());
                return new RexInputRef(newIndex, inputRef.getType());
            }
        });
    }

    private Map<Integer, Integer> buildFieldsMapping(int[] projectedInputRefs) {
        HashMap<Integer, Integer> fieldsOldToNewIndexMapping = new HashMap<Integer, Integer>();
        for (int i = 0; i < projectedInputRefs.length; ++i) {
            fieldsOldToNewIndexMapping.put(projectedInputRefs[i], i);
        }
        return fieldsOldToNewIndexMapping;
    }

    @Value.Immutable(singleton=false)
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutablePushCalcPastChangelogNormalizeRule.Config.builder().build().onMatch();

        @Override
        default public RelOptRule toRule() {
            return new PushCalcPastChangelogNormalizeRule(this);
        }

        default public Config onMatch() {
            RelRule.OperandTransform exchangeTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalExchange.class).anyInputs();
            RelRule.OperandTransform changelogNormalizeTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalChangelogNormalize.class).oneInput(exchangeTransform);
            RelRule.OperandTransform calcTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalCalc.class).oneInput(changelogNormalizeTransform);
            return this.withOperandSupplier(calcTransform).as(Config.class);
        }
    }
}

