/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexSetSemanticsTableCall;
import org.apache.flink.table.planner.functions.sql.SqlWindowTableFunction;
import org.apache.flink.table.planner.plan.logical.SessionWindowSpec;
import org.apache.flink.table.planner.plan.logical.TimeAttributeWindowingStrategy;
import org.apache.flink.table.planner.plan.utils.WindowUtil;
import org.apache.flink.table.types.logical.LogicalType;

/**
 * Planner rule that prunes the input of a window table-valued function ({@link
 * LogicalTableFunctionScan} whose call is a window TVF) when the {@link LogicalProject} on top of it
 * only needs some of the input fields.
 *
 * <p>The fields kept below the window TVF are: the scan-input fields referenced by the project, the
 * window's time attribute, the partition keys of a session window, and the partition/order keys of
 * a set-semantics table call. A project over those fields is inserted below the scan, the window
 * call is rewritten against the new field positions, and the original project is re-created on top
 * of the new scan. The rule does nothing when no input field can be pruned.
 */
public class ProjectWindowTableFunctionTransposeRule extends RelOptRule {

    public static final ProjectWindowTableFunctionTransposeRule INSTANCE =
            new ProjectWindowTableFunctionTransposeRule();

    public ProjectWindowTableFunctionTransposeRule() {
        super(
                operand(LogicalProject.class, operand(LogicalTableFunctionScan.class, any())),
                "ProjectWindowTableFunctionTransposeRule");
    }

    /** Only fires when the scan below the project is a window table function call. */
    @Override
    public boolean matches(RelOptRuleCall call) {
        LogicalTableFunctionScan scan = call.rel(1);
        return WindowUtil.isWindowTableFunctionCall(scan.getCall());
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        LogicalProject project = call.rel(0);
        LogicalTableFunctionScan scan = call.rel(1);
        RelNode scanInput = scan.getInput(0);
        RexCall windowCall = (RexCall) scan.getCall();
        TimeAttributeWindowingStrategy windowingStrategy =
                WindowUtil.convertToWindowingStrategy(windowCall, scanInput);

        // 1. Determine which scan-input fields must survive below the window TVF.
        int scanInputFieldCount = scanInput.getRowType().getFieldCount();
        ImmutableBitSet toPushFields =
                getFieldsToPush(project, windowCall, windowingStrategy, scanInputFieldCount);
        if (toPushFields.cardinality() == scanInputFieldCount) {
            // Nothing can be pruned; firing would only produce an equivalent plan.
            return;
        }

        // 2. Create the pruned input of the window table function scan.
        RelBuilder relBuilder = call.builder();
        RelNode newScanInput = createInnerProject(relBuilder, scanInput, toPushFields);

        // 3. Map each field position of the old scan output to its position in the new scan
        //    output; used to rewrite both the window call and the top project.
        Map<Integer, Integer> mapping =
                getFieldMapping(
                        scan.getRowType().getFieldCount(), scanInputFieldCount, toPushFields);

        // 4. Create the new window table function scan on top of the pruned input.
        LogicalTableFunctionScan newScan =
                createNewTableFunctionScan(
                        relBuilder,
                        scan,
                        windowingStrategy.getTimeAttributeType(),
                        newScanInput,
                        mapping);

        // 5. Re-create the original project on top of the new scan.
        RelNode topProject = createTopProject(relBuilder, project, newScan, mapping);
        call.transformTo(topProject);
    }

    /**
     * Computes the scan-input fields that must be kept below the window table function.
     *
     * @param project the project above the window scan
     * @param windowCall the window table function call of the scan
     * @param windowingStrategy windowing strategy derived from {@code windowCall}
     * @param scanInputFieldCount number of fields of the scan's input
     * @return bit set of scan-input field indices to keep
     */
    private ImmutableBitSet getFieldsToPush(
            LogicalProject project,
            RexCall windowCall,
            TimeAttributeWindowingStrategy windowingStrategy,
            int scanInputFieldCount) {
        ImmutableBitSet projectFields = RelOptUtil.InputFinder.bits(project.getProjects(), null);
        ImmutableBitSet toPushFields =
                ImmutableBitSet.range(0, scanInputFieldCount)
                        .intersect(projectFields)
                        .set(windowingStrategy.getTimeAttributeIndex());
        // Partition keys of a session window TVF are needed by the window itself.
        if (windowingStrategy.getWindow() instanceof SessionWindowSpec) {
            SessionWindowSpec sessionWindowSpec = (SessionWindowSpec) windowingStrategy.getWindow();
            toPushFields =
                    toPushFields.union(
                            ImmutableBitSet.of(sessionWindowSpec.getPartitionKeyIndices()));
        }
        // The keys of a set-semantics call are remapped in rewriteWindowCall; a key that is not
        // kept here would be remapped to -1 and yield an invalid call, so keep them all.
        if (windowCall instanceof RexSetSemanticsTableCall) {
            RexSetSemanticsTableCall setSemanticsCall = (RexSetSemanticsTableCall) windowCall;
            toPushFields =
                    toPushFields
                            .union(ImmutableBitSet.of(setSemanticsCall.getPartitionKeys()))
                            .union(ImmutableBitSet.of(setSemanticsCall.getOrderKeys()));
        }
        return toPushFields;
    }

    /**
     * Maps every field position of the old scan output to its position in the new scan output.
     *
     * <p>Scan-input positions map to their rank within {@code toPushFields} ({@code -1} when the
     * field is not kept). The trailing fields appended by the window TVF are shifted left by the
     * number of pruned input fields.
     */
    private Map<Integer, Integer> getFieldMapping(
            int scanFieldCount, int scanInputFieldCount, ImmutableBitSet toPushFields) {
        int toPushFieldCount = toPushFields.cardinality();
        Map<Integer, Integer> mapping = new HashMap<>();
        for (int idx = 0; idx < scanFieldCount; idx++) {
            int newPosition =
                    idx < scanInputFieldCount
                            ? toPushFields.indexOf(idx)
                            : toPushFieldCount + idx - scanInputFieldCount;
            mapping.put(idx, newPosition);
        }
        return mapping;
    }

    /** Creates a project over {@code scanInput} that keeps only the fields in {@code toPushFields}. */
    private RelNode createInnerProject(
            RelBuilder relBuilder, RelNode scanInput, ImmutableBitSet toPushFields) {
        relBuilder.push(scanInput);
        List<RexNode> newProjects =
                toPushFields.toList().stream()
                        .map(relBuilder::field)
                        .collect(Collectors.toList());
        return relBuilder.project(newProjects).build();
    }

    /**
     * Creates a copy of {@code oldScan} over {@code newInput}, with the window call rewritten
     * through {@code mapping} and the output row type re-inferred from the new input.
     */
    private LogicalTableFunctionScan createNewTableFunctionScan(
            RelBuilder relBuilder,
            LogicalTableFunctionScan oldScan,
            LogicalType timeAttributeType,
            RelNode newInput,
            Map<Integer, Integer> mapping) {
        relBuilder.push(newInput);
        RexNode newCall = rewriteWindowCall((RexCall) oldScan.getCall(), mapping, relBuilder);
        // Pop the input again so the builder stack stays balanced for later use.
        relBuilder.build();
        RelOptCluster cluster = oldScan.getCluster();
        FlinkTypeFactory typeFactory = (FlinkTypeFactory) cluster.getTypeFactory();
        RelDataType newScanOutputType =
                SqlWindowTableFunction.inferRowType(
                        typeFactory,
                        newInput.getRowType(),
                        typeFactory.createFieldTypeFromLogicalType(timeAttributeType));
        return LogicalTableFunctionScan.create(
                cluster,
                new ArrayList<>(Collections.singleton(newInput)),
                newCall,
                oldScan.getElementType(),
                newScanOutputType,
                oldScan.getColumnMappings());
    }

    /**
     * Rewrites the input references of {@code windowCall} (operands and, for set-semantics calls,
     * partition/order keys) through {@code mapping}.
     */
    private RexNode rewriteWindowCall(
            RexCall windowCall, Map<Integer, Integer> mapping, RelBuilder relBuilder) {
        List<RexNode> newOperands = new ArrayList<>();
        for (RexNode oldOperand : windowCall.getOperands()) {
            newOperands.add(adjustInputRef(oldOperand, mapping));
        }
        if (windowCall instanceof RexSetSemanticsTableCall) {
            RexSetSemanticsTableCall originalCall = (RexSetSemanticsTableCall) windowCall;
            int[] newPartitionKeys = remapKeys(originalCall.getPartitionKeys(), mapping);
            int[] newOrderKeys = remapKeys(originalCall.getOrderKeys(), mapping);
            return originalCall.copy(newOperands, newPartitionKeys, newOrderKeys);
        }
        return relBuilder.call(windowCall.getOperator(), newOperands);
    }

    /** Translates each key index through {@code mapping}. */
    private static int[] remapKeys(int[] keys, Map<Integer, Integer> mapping) {
        return Arrays.stream(keys).map(mapping::get).toArray();
    }

    /** Re-creates {@code oldProject} (expressions and field names) on top of {@code newInput}. */
    private RelNode createTopProject(
            RelBuilder relBuilder,
            LogicalProject oldProject,
            LogicalTableFunctionScan newInput,
            Map<Integer, Integer> mapping) {
        List<Pair<RexNode, String>> newTopProjects =
                oldProject.getNamedProjects().stream()
                        .map(named -> Pair.of(adjustInputRef(named.left, mapping), named.right))
                        .collect(Collectors.toList());
        return relBuilder
                .push(newInput)
                .project(Pair.left(newTopProjects), Pair.right(newTopProjects))
                .build();
    }

    /** Returns a copy of {@code expr} whose input references are renumbered via {@code mapping}. */
    private RexNode adjustInputRef(RexNode expr, Map<Integer, Integer> mapping) {
        return expr.accept(
                new RexShuttle() {
                    @Override
                    public RexNode visitInputRef(RexInputRef inputRef) {
                        Integer newIndex = mapping.get(inputRef.getIndex());
                        return new RexInputRef(newIndex, inputRef.getType());
                    }
                });
    }
}

