Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6432] Infinite loop for JoinPushTransitivePredicatesRule #3819

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 137 additions & 123 deletions core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,12 @@ public RelOptPredicateList getPredicates(Project project,
final List<RexNode> projectPullUpPredicates = new ArrayList<>();

ImmutableBitSet.Builder columnsMappedBuilder = ImmutableBitSet.builder();
Mapping m =
Mappings.create(MappingType.PARTIAL_FUNCTION,
input.getRowType().getFieldCount(),
project.getRowType().getFieldCount());
Map<Integer, BitSet> equivalence = new HashMap<>();

for (Ord<RexNode> expr : Ord.zip(project.getProjects())) {
if (expr.e instanceof RexInputRef) {
int sIdx = ((RexInputRef) expr.e).getIndex();
m.set(sIdx, expr.i);
equivalence.computeIfAbsent(sIdx, k -> new BitSet()).set(expr.i);
columnsMappedBuilder.set(sIdx);
} else if (RexUtil.isConstant(expr.e)) {
// Project can also generate constants (including NULL). We need to
Expand All @@ -216,8 +213,19 @@ public RelOptPredicateList getPredicates(Project project,
for (RexNode r : inputInfo.pulledUpPredicates) {
RexNode r2 = projectPredicate(rexBuilder, input, r, columnsMapped);
if (!r2.isAlwaysTrue()) {
r2 = r2.accept(new RexPermuteInputsShuttle(m, input));
projectPullUpPredicates.add(r2);
ImmutableBitSet fields = RelOptUtil.InputFinder.bits(r2);
if (fields.cardinality() == 0) {
projectPullUpPredicates.add(r2);
continue;
}
ExprsItr exprsItr =
new ExprsItr(fields, equivalence, input.getRowType().getFieldCount(),
project.getRowType().getFieldCount());
while (exprsItr.hasNext()) {
Mapping m = exprsItr.next();
RexNode r3 = r2.accept(new RexPermuteInputsShuttle(m, input));
projectPullUpPredicates.add(r3);
}
}
}
return RelOptPredicateList.of(rexBuilder, projectPullUpPredicates);
Expand Down Expand Up @@ -874,7 +882,9 @@ Iterable<Mapping> mappings(final RexNode predicate) {
if (fields.cardinality() == 0) {
return Collections.emptyList();
}
return () -> new ExprsItr(fields);
return () -> new ExprsItr(fields, equivalence,
nSysFields + nFieldsLeft + nFieldsRight,
nSysFields + nFieldsLeft + nFieldsRight);
}

private static boolean checkTarget(ImmutableBitSet inferringFields,
Expand Down Expand Up @@ -918,139 +928,143 @@ protected EquivalenceFinder() {
}
}

/**
* Given an expression returns all the possible substitutions.
*
* <p>For example, for an expression 'a + b + c' and the following
* equivalences: <pre>
* a : {a, b}
* b : {a, b}
* c : {c, e}
* </pre>
*
* <p>The following Mappings will be returned:
* <pre>
* {a &rarr; a, b &rarr; a, c &rarr; c}
* {a &rarr; a, b &rarr; a, c &rarr; e}
* {a &rarr; a, b &rarr; b, c &rarr; c}
* {a &rarr; a, b &rarr; b, c &rarr; e}
* {a &rarr; b, b &rarr; a, c &rarr; c}
* {a &rarr; b, b &rarr; a, c &rarr; e}
* {a &rarr; b, b &rarr; b, c &rarr; c}
* {a &rarr; b, b &rarr; b, c &rarr; e}
* </pre>
*
* <p>which imply the following inferences:
* <pre>
* a + a + c
* a + a + e
* a + b + c
* a + b + e
* b + a + c
* b + a + e
* b + b + c
* b + b + e
* </pre>
*/
class ExprsItr implements Iterator<Mapping> {
final int[] columns;
final BitSet[] columnSets;
final int[] iterationIdx;
@Nullable Mapping nextMapping;
boolean firstCall;

@SuppressWarnings("JdkObsolete")
ExprsItr(ImmutableBitSet fields) {
nextMapping = null;
columns = new int[fields.cardinality()];
columnSets = new BitSet[fields.cardinality()];
iterationIdx = new int[fields.cardinality()];
for (int j = 0, i = fields.nextSetBit(0); i >= 0; i = fields
.nextSetBit(i + 1), j++) {
columns[j] = i;
int fieldIndex = i;
columnSets[j] =
requireNonNull(equivalence.get(i),
() -> "equivalence.get(i) is null for " + fieldIndex
+ ", " + equivalence);
iterationIdx[j] = 0;
}
firstCall = true;
private static int pos(RexNode expr) {
if (expr instanceof RexInputRef) {
return ((RexInputRef) expr).getIndex();
}
return -1;
}

@Override public boolean hasNext() {
if (firstCall) {
initializeMapping();
firstCall = false;
} else {
computeNextMapping(iterationIdx.length - 1);
private static boolean isAlwaysTrue(RexNode predicate) {
if (predicate instanceof RexCall) {
RexCall c = (RexCall) predicate;
if (c.getOperator().getKind() == SqlKind.EQUALS) {
int lPos = pos(c.getOperands().get(0));
int rPos = pos(c.getOperands().get(1));
return lPos != -1 && lPos == rPos;
}
return nextMapping != null;
}
return predicate.isAlwaysTrue();
}
}

@Override public Mapping next() {
if (nextMapping == null) {
throw new NoSuchElementException();
}
return nextMapping;
}
/**
* Given an expression returns all the possible substitutions.
*
* <p>For example, for an expression 'a + b + c' and the following
* equivalences: <pre>
* a : {a, b}
* b : {a, b}
* c : {c, e}
* </pre>
*
* <p>The following Mappings will be returned:
* <pre>
* {a &rarr; a, b &rarr; a, c &rarr; c}
* {a &rarr; a, b &rarr; a, c &rarr; e}
* {a &rarr; a, b &rarr; b, c &rarr; c}
* {a &rarr; a, b &rarr; b, c &rarr; e}
* {a &rarr; b, b &rarr; a, c &rarr; c}
* {a &rarr; b, b &rarr; a, c &rarr; e}
* {a &rarr; b, b &rarr; b, c &rarr; c}
* {a &rarr; b, b &rarr; b, c &rarr; e}
* </pre>
*
* <p>which imply the following inferences:
* <pre>
* a + a + c
* a + a + e
* a + b + c
* a + b + e
* b + a + c
* b + a + e
* b + b + c
* b + b + e
* </pre>
*/
static class ExprsItr implements Iterator<Mapping> {
final int[] columns;
final BitSet[] columnSets;
final int[] iterationIdx;
@Nullable Mapping nextMapping;
boolean firstCall;
int sourceCount;
int targetCount;

@Override public void remove() {
throw new UnsupportedOperationException();
@SuppressWarnings("JdkObsolete")
ExprsItr(ImmutableBitSet fields, Map<Integer, BitSet> equivalence,
int sourceCount, int targetCount) {
nextMapping = null;
columns = new int[fields.cardinality()];
columnSets = new BitSet[fields.cardinality()];
iterationIdx = new int[fields.cardinality()];
for (int j = 0, i = fields.nextSetBit(0); i >= 0; i = fields
.nextSetBit(i + 1), j++) {
columns[j] = i;
int fieldIndex = i;
columnSets[j] =
requireNonNull(equivalence.get(i),
() -> "equivalence.get(i) is null for " + fieldIndex
+ ", " + equivalence);
iterationIdx[j] = 0;
}
firstCall = true;
this.sourceCount = sourceCount;
this.targetCount = targetCount;
}

private void computeNextMapping(int level) {
int t = columnSets[level].nextSetBit(iterationIdx[level]);
if (t < 0) {
if (level == 0) {
nextMapping = null;
} else {
int tmp = columnSets[level].nextSetBit(0);
requireNonNull(nextMapping, "nextMapping").set(columns[level], tmp);
iterationIdx[level] = tmp + 1;
computeNextMapping(level - 1);
}
} else {
requireNonNull(nextMapping, "nextMapping").set(columns[level], t);
iterationIdx[level] = t + 1;
}
@Override public boolean hasNext() {
if (firstCall) {
initializeMapping();
firstCall = false;
} else {
computeNextMapping(iterationIdx.length - 1);
}
return nextMapping != null;
}

private void initializeMapping() {
nextMapping =
Mappings.create(MappingType.PARTIAL_FUNCTION,
nSysFields + nFieldsLeft + nFieldsRight,
nSysFields + nFieldsLeft + nFieldsRight);
for (int i = 0; i < columnSets.length; i++) {
BitSet c = columnSets[i];
int t = c.nextSetBit(iterationIdx[i]);
if (t < 0) {
nextMapping = null;
return;
}
nextMapping.set(columns[i], t);
iterationIdx[i] = t + 1;
}
@Override public Mapping next() {
if (nextMapping == null) {
throw new NoSuchElementException();
}
return nextMapping;
}

private static int pos(RexNode expr) {
if (expr instanceof RexInputRef) {
return ((RexInputRef) expr).getIndex();
@Override public void remove() {
throw new UnsupportedOperationException();
}

private void computeNextMapping(int level) {
int t = columnSets[level].nextSetBit(iterationIdx[level]);
if (t < 0) {
if (level == 0) {
nextMapping = null;
} else {
int tmp = columnSets[level].nextSetBit(0);
requireNonNull(nextMapping, "nextMapping").set(columns[level], tmp);
iterationIdx[level] = tmp + 1;
computeNextMapping(level - 1);
}
} else {
requireNonNull(nextMapping, "nextMapping").set(columns[level], t);
iterationIdx[level] = t + 1;
}
return -1;
}

private static boolean isAlwaysTrue(RexNode predicate) {
if (predicate instanceof RexCall) {
RexCall c = (RexCall) predicate;
if (c.getOperator().getKind() == SqlKind.EQUALS) {
int lPos = pos(c.getOperands().get(0));
int rPos = pos(c.getOperands().get(1));
return lPos != -1 && lPos == rPos;
private void initializeMapping() {
nextMapping =
Mappings.create(MappingType.PARTIAL_FUNCTION,
sourceCount, targetCount);
for (int i = 0; i < columnSets.length; i++) {
BitSet c = columnSets[i];
int t = c.nextSetBit(iterationIdx[i]);
if (t < 0) {
nextMapping = null;
return;
}
nextMapping.set(columns[i], t);
iterationIdx[i] = t + 1;
}
return predicate.isAlwaysTrue();
}
}
}
20 changes: 19 additions & 1 deletion core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,24 @@ private static boolean skipItem(RexNode expr) {
.check();
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6432">
* [CALCITE-6432] Infinite loop for JoinPushTransitivePredicatesRule
* when there are multiple project expressions reference the same input field</a>. */
@Test void testProjectPredicatePull() {
final String sql = "select e.ename, d.dname\n"
+ "from (select ename, deptno from emp where deptno = 10) e\n"
+ "join (select name dname, deptno, * from dept) d\n"
+ "on e.deptno = d.deptno";
final HepProgram program = new HepProgramBuilder()
.addRuleCollection(
ImmutableList.of(CoreRules.FILTER_PROJECT_TRANSPOSE,
CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES))
.build();
sql(sql).withProgram(program).check();
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-5971">[CALCITE-5971]
Expand Down Expand Up @@ -6114,7 +6132,7 @@ private HepProgram getTransitiveProgram() {
.withRule(CoreRules.FILTER_INTO_JOIN,
CoreRules.JOIN_CONDITION_PUSH,
CoreRules.JOIN_PUSH_TRANSITIVE_PREDICATES)
.check();
.checkUnchanged();
}

/** Test case for
Expand Down
Loading
Loading