/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.cost.FilterStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SemiJoinStatsCalculator;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule for a {@link FilterNode} whose source is a {@link SemiJoinNode}, optionally with a single
 * identity {@link ProjectNode} in between.
 *
 * <p>The rule only applies when the filter predicate contains exactly one conjunct referencing the
 * semi-join output symbol, either as a plain symbol reference or wrapped in {@code NOT}. The estimate
 * is then computed as a semi-join (or anti-join, when negated) of the semi-join inputs. The remaining
 * conjuncts are applied on top of that estimate through the {@link FilterStatsCalculator}.
 */
public class SimpleFilterProjectSemiJoinStatsRule
        extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Fallback selectivity applied to the semi-join estimate when the remaining predicate cannot be
     * estimated, i.e. when the filter stats come back with an unknown output row count.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final Metadata metadata;
    private final FilterStatsCalculator filterStatsCalculator;

    public SimpleFilterProjectSemiJoinStatsRule(Metadata metadata, StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator) {
        super(normalizer);
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates stats for {@code node} when its source is a semi-join (possibly behind an identity
     * projection). Returns empty when the plan shape or predicate does not match, so other rules can
     * take over.
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        return findSemiJoinSource(node, lookup)
                .flatMap(semiJoinNode -> this.calculate(node, semiJoinNode, sourceStats, session, types));
    }

    /**
     * Resolves the semi-join feeding {@code filterNode}, looking through at most one identity projection.
     * Returns empty for any other source shape, including a non-identity projection.
     */
    private static Optional<SemiJoinNode> findSemiJoinSource(FilterNode filterNode, Lookup lookup) {
        PlanNode source = lookup.resolve(filterNode.getSource());
        if (source instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) source;
            if (!projectNode.isIdentity()) {
                return Optional.empty();
            }
            source = lookup.resolve(projectNode.getSource());
        }
        if (source instanceof SemiJoinNode) {
            return Optional.of((SemiJoinNode) source);
        }
        return Optional.empty();
    }

    /**
     * Computes the semi-join/anti-join estimate and then applies the remaining filter conjuncts.
     *
     * @return empty when the predicate does not reference the semi-join output exactly once;
     *         {@link PlanNodeStatsEstimate#unknown()} when the semi-join row count is unknown;
     *         otherwise the filtered estimate. If that estimate's row count is unknown, the semi-join
     *         estimate scaled by {@link #UNKNOWN_FILTER_COEFFICIENT} is returned instead.
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider types) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource());
        PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
        Symbol filteringSourceJoinSymbol = semiJoinNode.getFilteringSourceJoinSymbol();
        Symbol sourceJoinSymbol = semiJoinNode.getSourceJoinSymbol();

        Optional<SemiJoinOutputFilter> semiJoinOutputFilter = this.extractSemiJoinOutputFilter(filterNode.getPredicate(), semiJoinNode.getSemiJoinOutput());
        if (semiJoinOutputFilter.isEmpty()) {
            return Optional.empty();
        }
        SemiJoinOutputFilter outputFilter = semiJoinOutputFilter.get();

        // "NOT semiJoinOutput" keeps the rows without a match, so it is estimated as an anti-join.
        PlanNodeStatsEstimate semiJoinStats = outputFilter.isNegated()
                ? SemiJoinStatsCalculator.computeAntiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol)
                : SemiJoinStatsCalculator.computeSemiJoin(sourceStats, filteringSourceStats, sourceJoinSymbol, filteringSourceJoinSymbol);
        if (semiJoinStats.isOutputRowCountUnknown()) {
            return Optional.of(PlanNodeStatsEstimate.unknown());
        }

        PlanNodeStatsEstimate filteredStats = this.filterStatsCalculator.filterStats(semiJoinStats, outputFilter.getRemainingPredicate(), session, types);
        if (filteredStats.isOutputRowCountUnknown()) {
            // The remaining predicate could not be estimated; apply a fixed fallback selectivity instead.
            return Optional.of(semiJoinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
        }
        return Optional.of(filteredStats);
    }

    /**
     * Splits {@code predicate} into the single conjunct referencing {@code semiJoinOutput} and the
     * remaining conjuncts.
     *
     * @return empty unless exactly one conjunct is {@code semiJoinOutput} or {@code NOT semiJoinOutput}
     */
    private Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression predicate, Symbol semiJoinOutput) {
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(predicate);
        List<Expression> semiJoinOutputReferences = conjuncts.stream()
                .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
                .collect(ImmutableList.toImmutableList());
        if (semiJoinOutputReferences.size() != 1) {
            return Optional.empty();
        }
        Expression semiJoinOutputReference = Iterables.getOnlyElement(semiJoinOutputReferences);

        // Identity comparison: drop exactly the conjunct instance that was matched above.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(conjunct -> conjunct != semiJoinOutputReference)
                .collect(ImmutableList.toImmutableList());
        Expression remainingPredicate = ExpressionUtils.combineConjuncts(this.metadata, remainingConjuncts);

        boolean negated = semiJoinOutputReference instanceof NotExpression;
        return Optional.of(new SemiJoinOutputFilter(negated, remainingPredicate));
    }

    /**
     * Returns whether {@code conjunct} is the semi-join output symbol itself or {@code NOT} of it.
     */
    private static boolean isSemiJoinOutputReference(Expression conjunct, Symbol semiJoinOutput) {
        SymbolReference semiJoinOutputSymbolReference = semiJoinOutput.toSymbolReference();
        if (conjunct.equals(semiJoinOutputSymbolReference)) {
            return true;
        }
        return conjunct instanceof NotExpression
                && ((NotExpression) conjunct).getValue().equals(semiJoinOutputSymbolReference);
    }

    /**
     * Result of splitting a filter predicate: whether the semi-join output reference was negated,
     * plus the conjunction of all other conjuncts.
     */
    private static final class SemiJoinOutputFilter {
        private final boolean negated;
        private final Expression remainingPredicate;

        public SemiJoinOutputFilter(boolean negated, Expression remainingPredicate) {
            this.negated = negated;
            this.remainingPredicate = Objects.requireNonNull(remainingPredicate, "remainingPredicate cannot be null");
        }

        public boolean isNegated() {
            return this.negated;
        }

        public Expression getRemainingPredicate() {
            return this.remainingPredicate;
        }
    }
}

