diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java index 979739ae83f..d53a2b81c2e 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java @@ -338,8 +338,90 @@ public static void setLeftRightBitmaps( double res = 0; if (dSize > 0) { - double expo = numSel * Math.log(1.0 - 1.0 / dSize); - res = (1.0 - Math.exp(expo)) * dSize; + // We calculate the result in 3 cases: + // 1) If N and n are not very large values, we directly calculate + // N * [1 - exp(n * ln(1 - 1 / N))] + // 2) If N or n are very large values, but n / N is not close to 0, + // we compute the result by expanding the exponent n * ln(1 - 1 / N) + // as a Taylor series. + // 3) If N or n are very large values, and n / N is close to 0, + // we expand the whole formula as a Taylor series. + // We separate the 3 cases by trying to expand the exponent as a Taylor series. + // If the required number of terms for convergence is too big, it is case 1. + // If the required number of terms is too small, it is case 3. Otherwise, + // it is case 2. + + // To expand the exponent as a Taylor series, we have: + // n * ln(1 - 1 / N) + // = n * (- 1 / N - 1 / (2 * N ^ 2) - 1 / (3 * N ^ 3) - ...) + // + // To get an approximate result, we truncate the series after the first m terms. + // This leads to an error of: + // n * (1 / [(m + 1) * N ^ (m + 1)] + 1 / [(m + 2) * N ^ (m + 2)] + ...) + // < n / (m + 1) / [N ^ m * (N - 1)] < n / (N ^ m) + // + // To get an accurate result, the error should be smaller than 1e-16, because a + // double carries only about 16 significant decimal digits. Therefore, we need + // n / (N ^ m) < 1e-16, or N ^ m / n > 1e16. + // + // The smallest value of m that satisfies the above formula is the + // number of terms required for convergence. 
+ + final int maxTerms = 32; + final int minTerms = 1; + + int numTerms = 1; + if (dSize > 10) { + double dPower = dSize; + // Stop once the term count exceeds maxTerms (case 1 applies anyway), so + // the loop always terminates, even for a negative numSel. + while (numTerms <= maxTerms && dPower / numSel <= 1e16) { + dPower = dPower * dPower; + numTerms *= 2; + } + } else { + // for small dSize, force case 1 + numTerms = maxTerms + 1; + } + + if (numTerms > maxTerms) { + // case 1 + double expo = numSel * Math.log(1.0 - 1.0 / dSize); + res = (1.0 - Math.exp(expo)) * dSize; + } else if (numTerms > minTerms) { + // case 2 + double expo = 0; + double dPower = dSize; + for (int i = 1; i <= numTerms; i++) { + expo -= numSel / dPower / i; + dPower *= dSize; + } + // expm1 avoids the cancellation in 1 - exp(expo) when expo is close to 0 + res = -Math.expm1(expo) * dSize; + } else { + // numTerms <= minTerms + // case 3 + + // Since the ratio n / N is close to 0, we can expand the exponent as + // n * ln(1 - 1 / N) ≈ n * (- 1 / N) = - n / N. + // So N * [1 - (1 - 1 / N) ^ n ] = N * [1 - exp(n * ln(1 - 1 / N))] + // ≈ N * [1 - exp(- n / N)]. + // Since exp(x) = 1 + x + x ^ 2 / 2 + x ^ 3 / 6 + ..., we have + // N * [1 - exp(- n / N)] = N * [n / N - n ^ 2 / (2 * N ^ 2) + n ^ 3 / (6 * N ^ 3) - ...] + // = n - n ^ 2 / (2 * N) + n ^ 3 / (6 * N ^ 2) - ... + // Since n / N is close to 0, and due to the factorial in the denominator, + // the terms of the above series shrink to 0 very quickly. 
+ res = 0; + numTerms = 5; + // Derive each term from the previous one rather than from separate powers + // of n and N: those powers can both overflow to Infinity, which makes NaN. + final double ratio = numSel / dSize; + double term = numSel; + for (int i = 1; i < numTerms; i++) { + res += term; + term *= -ratio / (i + 1); + } + } } // fix the boundary cases diff --git a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java index dcfa6ecb82f..8d1a8016d40 100644 --- a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java +++ b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java @@ -24,9 +24,13 @@ import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.apache.calcite.rel.metadata.RelMdUtil.numDistinctVals; +import static org.apache.calcite.test.Matchers.within; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.not; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; /** * Test cases for {@link RelMdUtil}. 
@@ -42,29 +46,56 @@ final RelMetadataFixture sql(String sql) { return fixture().withSql(sql); } + private static final double EPSILON = 1e-5; + @Test void testNumDistinctVals() { // the first element must be distinct, the second one has half chance of being distinct - assertEquals(1.5, RelMdUtil.numDistinctVals(2.0, 2.0), 1e-5); + assertThat(numDistinctVals(2.0, 2.0), within(1.5, EPSILON)); // when no selection is made, we get no distinct value double domainSize = 100; - assertEquals(0, RelMdUtil.numDistinctVals(domainSize, 0.0), 1e-5); + assertThat(numDistinctVals(domainSize, 0.0), within(0, EPSILON)); // when we perform one selection, we always have 1 distinct value, // regardless of the domain size for (double dSize = 1; dSize < 100; dSize += 1) { - assertEquals(1.0, RelMdUtil.numDistinctVals(dSize, 1.0), 1e-5); + assertThat(numDistinctVals(dSize, 1.0), within(1.0, EPSILON)); } // when we select n objects from a set with n values // we get no more than n distinct values for (double dSize = 1; dSize < 100; dSize += 1) { - assertTrue(RelMdUtil.numDistinctVals(dSize, dSize) <= dSize); + assertThat(numDistinctVals(dSize, dSize), lessThanOrEqualTo(dSize)); } // when the number of selections is large enough // we get all distinct values, w.h.p. 
- assertEquals(domainSize, RelMdUtil.numDistinctVals(domainSize, domainSize * 100), 1e-5); + assertThat(numDistinctVals(domainSize, domainSize * 100), within(domainSize, EPSILON)); + + assertThat(numDistinctVals(100.0, 2.0), within(1.99, EPSILON)); + assertThat(numDistinctVals(1000.0, 2.0), within(1.999, EPSILON)); + assertThat(numDistinctVals(10000.0, 2.0), within(1.9999, EPSILON)); + } + + @Test void testNumDistinctValsWithLargeDomain() { + double[] domainSizes = {1e18, 1e20}; + double[] numSels = {1e2, 1e4, 1e6, 1e8, 1e10, 1e12}; + double res; + for (double domainSize : domainSizes) { + for (double numSel : numSels) { + res = numDistinctVals(domainSize, numSel); + // compare with a double literal: a boxed Double never equals the Integer 0 + assertThat(res, not(0.0)); + // due to the possible duplicate selections, the distinct values + // must be smaller than or equal to the number of selections + assertThat(res, lessThanOrEqualTo(numSel)); + } + res = numDistinctVals(domainSize, 1.0); + assertThat(res, within(1.0, EPSILON)); + + res = numDistinctVals(domainSize, 2.0); + assertThat(res, within(2.0, EPSILON)); + } } @Test void testDynamicParameterInLimitOffset() {