/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.querykit.common;

import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;
import javax.annotation.Nullable;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannel;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.OutputChannels;
import org.apache.druid.frame.processor.manager.ProcessorManagers;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.msq.counters.CounterTracker;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
import org.apache.druid.msq.input.InputSlices;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.kernel.ProcessorsAndChannels;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.querykit.BaseFrameProcessorFactory;
import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.Equality;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;

/**
 * Factory for {@link SortMergeJoinFrameProcessor}.
 *
 * <p>Requires exactly two stage inputs: input {@link #LEFT} is the left side of the join and input {@link #RIGHT} is
 * the right side. One processor is created per partition found in the attached inputs. Each processor writes to its
 * own output channel for that partition.
 *
 * <p>Only join conditions that are always true, or that consist solely of equalities whose left-hand side is a plain
 * identifier, are accepted; see {@link #validateCondition(JoinConditionAnalysis)}.
 */
@JsonTypeName("sortMergeJoin")
public class SortMergeJoinFrameProcessorFactory
extends BaseFrameProcessorFactory {
    /** Position of the left-hand input, both in the per-partition input lists and in {@link #toKeyColumns}. */
    private static final int LEFT = 0;
    /** Position of the right-hand input, both in the per-partition input lists and in {@link #toKeyColumns}. */
    private static final int RIGHT = 1;

    private final String rightPrefix;
    private final JoinConditionAnalysis condition;
    private final JoinType joinType;

    /**
     * @param rightPrefix prefix passed through to each {@link SortMergeJoinFrameProcessor}; must not be null
     * @param condition   join condition; must not be null and must pass {@link #validateCondition}
     * @param joinType    join type; must not be null
     * @throws NullPointerException if any argument is null
     * @throws IAE                  if the condition is not supported by this join strategy
     */
    public SortMergeJoinFrameProcessorFactory(String rightPrefix, JoinConditionAnalysis condition, JoinType joinType) {
        this.rightPrefix = Preconditions.checkNotNull(rightPrefix, "rightPrefix");
        this.condition = validateCondition(Preconditions.checkNotNull(condition, "condition"));
        this.joinType = Preconditions.checkNotNull(joinType, "joinType");
    }

    /**
     * Jackson entry point. Normalizes {@code rightPrefix} with
     * {@link StringUtils#nullToEmptyNonDruidDataString(String)} and parses {@code condition} into a
     * {@link JoinConditionAnalysis} using that prefix.
     *
     * @throws NullPointerException if {@code condition} or {@code joinType} is null
     */
    @JsonCreator
    public static SortMergeJoinFrameProcessorFactory create(@JsonProperty(value="rightPrefix") String rightPrefix, @JsonProperty(value="condition") String condition, @JsonProperty(value="joinType") JoinType joinType, @JacksonInject ExprMacroTable macroTable) {
        final String normalizedPrefix = StringUtils.nullToEmptyNonDruidDataString(rightPrefix);
        final JoinConditionAnalysis analysis = JoinConditionAnalysis.forExpression(
                Preconditions.checkNotNull(condition, "condition"),
                normalizedPrefix,
                macroTable
        );
        return new SortMergeJoinFrameProcessorFactory(normalizedPrefix, analysis, joinType);
    }

    @JsonProperty
    public String getRightPrefix() {
        return rightPrefix;
    }

    /** Returns the original (unparsed) condition expression; this is what gets serialized. */
    @JsonProperty
    public String getCondition() {
        return condition.getOriginalExpression();
    }

    @JsonProperty
    public JoinType getJoinType() {
        return joinType;
    }

    /**
     * Builds one {@link SortMergeJoinFrameProcessor} per partition present in the two stage inputs.
     *
     * @throws ISE if there are not exactly two inputs, or if any input is not a {@link StageInputSlice}
     */
    @Override
    public ProcessorsAndChannels<Object, Long> makeProcessors(StageDefinition stageDefinition, int workerNumber, List<InputSlice> inputSlices, InputSliceReader inputSliceReader, @Nullable Object extra, OutputChannelFactory outputChannelFactory, FrameContext frameContext, int maxOutstandingProcessors, CounterTracker counters, Consumer<Throwable> warningPublisher, boolean removeNullBytes) throws IOException {
        if (inputSlices.size() != 2 || !inputSlices.stream().allMatch(slice -> slice instanceof StageInputSlice)) {
            throw new ISE("Expected two stage inputs");
        }

        final List<List<KeyColumn>> keyColumns = toKeyColumns(condition);
        final int[] requiredNonNullKeyParts = toRequiredNonNullKeyParts(condition);
        final Int2ObjectMap<List<ReadableInput>> inputsByPartition = validateInputFrameSignatures(
                InputSlices.attachAndCollectPartitions(inputSlices, inputSliceReader, counters, warningPublisher),
                keyColumns
        );

        if (inputsByPartition.isEmpty()) {
            return new ProcessorsAndChannels<Object, Long>(ProcessorManagers.none(), OutputChannels.none());
        }

        // Open every output channel eagerly. The AVL tree map keeps the channels ordered by partition number.
        final Int2ObjectMap<OutputChannel> outputChannels = new Int2ObjectAVLTreeMap<>();
        final IntIterator partitionIterator = inputsByPartition.keySet().iterator();
        while (partitionIterator.hasNext()) {
            final int partitionNumber = partitionIterator.nextInt();
            outputChannels.put(partitionNumber, outputChannelFactory.openChannel(partitionNumber));
        }

        // Iterables.transform returns a lazy view, so each processor is constructed only when it is iterated.
        final Iterable<FrameProcessor<Object>> processors = Iterables.transform(
                inputsByPartition.int2ObjectEntrySet(),
                entry -> makeProcessor(
                        entry,
                        outputChannels,
                        stageDefinition,
                        removeNullBytes,
                        keyColumns,
                        requiredNonNullKeyParts,
                        frameContext
                )
        );

        return new ProcessorsAndChannels<Object, Long>(
                ProcessorManagers.of(processors),
                OutputChannels.wrap(ImmutableList.copyOf(outputChannels.values()))
        );
    }

    @Override
    public boolean usesProcessingBuffers() {
        return false;
    }

    /**
     * Returns two lists of ascending key columns derived from the equi-conditions. Index {@link #LEFT} holds the
     * left-hand identifiers and index {@link #RIGHT} holds the right-hand columns, in condition order.
     *
     * @throws NullPointerException if any equi-condition's left-hand expression is not an identifier
     */
    public static List<List<KeyColumn>> toKeyColumns(JoinConditionAnalysis condition) {
        final List<List<KeyColumn>> retVal = new ArrayList<>();
        retVal.add(new ArrayList<>());
        retVal.add(new ArrayList<>());

        for (final Equality equiCondition : condition.getEquiConditions()) {
            final String leftColumn = Preconditions.checkNotNull(
                    equiCondition.getLeftExpr().getBindingIfIdentifier(),
                    "leftExpr#getBindingIfIdentifier"
            );
            retVal.get(LEFT).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING));
            retVal.get(RIGHT).add(new KeyColumn(equiCondition.getRightColumn(), KeyOrder.ASCENDING));
        }

        return retVal;
    }

    /**
     * Returns, in ascending order, the positions of equi-conditions for which {@link Equality#isIncludeNull()} is
     * false. These are the key parts that must be non-null to match.
     */
    public static int[] toRequiredNonNullKeyParts(JoinConditionAnalysis condition) {
        final List<Equality> equiConditions = condition.getEquiConditions();
        final IntArrayList retVal = new IntArrayList(equiConditions.size());

        for (int i = 0; i < equiConditions.size(); i++) {
            if (!equiConditions.get(i).isIncludeNull()) {
                retVal.add(i);
            }
        }

        return retVal.toIntArray();
    }

    /**
     * Checks that {@code condition} can be executed as a sort-merge join and returns it unchanged.
     *
     * @throws IAE if the condition is constant-false, has any non-equi part, or has an equality whose left-hand side
     *             is not a plain identifier
     */
    public static JoinConditionAnalysis validateCondition(JoinConditionAnalysis condition) {
        if (condition.isAlwaysTrue()) {
            return condition;
        }

        if (condition.isAlwaysFalse()) {
            throw new IAE("Cannot handle constant condition: %s", condition.getOriginalExpression());
        }

        if (!condition.getNonEquiConditions().isEmpty()) {
            throw new IAE("Cannot handle non-equijoin condition: %s", condition.getOriginalExpression());
        }

        if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier())) {
            throw new IAE(
                    "Cannot handle equality condition involving left-hand expression: %s",
                    condition.getOriginalExpression()
            );
        }

        return condition;
    }

    /**
     * Checks every input of every partition. Each input must have a channel, and the leading columns of its frame
     * signature must be that input's key columns, in order. Returns {@code inputsByPartition} unchanged.
     *
     * @throws IllegalStateException if an input has no channel, or if its signature does not start with its key
     *                               columns
     */
    private static Int2ObjectMap<List<ReadableInput>> validateInputFrameSignatures(
            final Int2ObjectMap<List<ReadableInput>> inputsByPartition,
            final List<List<KeyColumn>> keyColumns
    ) {
        for (final List<ReadableInput> readableInputs : inputsByPartition.values()) {
            for (int i = 0; i < readableInputs.size(); i++) {
                final ReadableInput readableInput = readableInputs.get(i);
                Preconditions.checkState(readableInput.hasChannel(), "readableInput[%s].hasChannel", i);

                final RowSignature signature = readableInput.getChannelFrameReader().signature();
                final List<KeyColumn> inputKeyColumns = keyColumns.get(i);
                for (int j = 0; j < inputKeyColumns.size(); j++) {
                    final String expectedName = inputKeyColumns.get(j).columnName();
                    // A signature shorter than the key would otherwise fail with an opaque IndexOutOfBoundsException.
                    final String actualName = j < signature.size() ? signature.getColumnName(j) : null;
                    Preconditions.checkState(
                            expectedName.equals(actualName),
                            "readableInput[%s] column[%s] has name[%s], but expected name[%s]",
                            i,
                            j,
                            actualName,
                            expectedName
                    );
                }
            }
        }

        return inputsByPartition;
    }

    /**
     * Creates the join processor for one partition. It reads the partition's left and right inputs and writes to
     * the output channel already opened for that partition number.
     */
    private FrameProcessor<Object> makeProcessor(
            final Int2ObjectMap.Entry<List<ReadableInput>> entry,
            final Int2ObjectMap<OutputChannel> outputChannels,
            final StageDefinition stageDefinition,
            final boolean removeNullBytes,
            final List<List<KeyColumn>> keyColumns,
            final int[] requiredNonNullKeyParts,
            final FrameContext frameContext
    ) {
        final List<ReadableInput> readableInputs = entry.getValue();
        final OutputChannel outputChannel = outputChannels.get(entry.getIntKey());

        return new SortMergeJoinFrameProcessor(
                readableInputs.get(LEFT),
                readableInputs.get(RIGHT),
                outputChannel.getWritableChannel(),
                stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator(), removeNullBytes),
                rightPrefix,
                keyColumns,
                requiredNonNullKeyParts,
                joinType,
                frameContext.memoryParameters().getSortMergeJoinMemory()
        );
    }
}

