Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] implement fast fisher exact test pvalue #14663

Draft
wants to merge 1 commit into
base: ps-08-14-add_rand_hyper_and_rand_multi_hyper_methods
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,8 @@ def exp(x) -> Float64Expression:
return _func("exp", tfloat64, x)


@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32)
def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32, _pvalue_only=bool)
def fisher_exact_test(c1, c2, c3, c4, _pvalue_only=False) -> StructExpression:
"""Calculates the p-value, odds ratio, and 95% confidence interval using
Fisher's exact test for a 2x2 table.

Expand Down Expand Up @@ -1138,8 +1138,11 @@ def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
`ci_95_lower (:py:data:`.tfloat64`), and `ci_95_upper`
(:py:data:`.tfloat64`).
"""
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)
if _pvalue_only:
return struct(p_value=_func("fisher_exact_test_pvalue_only", tfloat64, c1, c2, c3, c4))
else:
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)


@typecheck(x=expr_oneof(expr_float32, expr_float64, expr_ndarray(expr_float64)))
Expand Down
28 changes: 25 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -398,16 +398,15 @@ object MathFunctions extends RegistryFunctions {
fetStruct.virtualType,
(_, _, _, _, _) => fetStruct.sType,
) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) =>
val res = cb.newLocal[Array[Double]](
"fisher_exact_test_res",
val res = cb.memoize[Array[Double]](
Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]](
statsPackageClass,
"fisherExactTest",
a.value,
b.value,
c.value,
d.value,
),
)
)

fetStruct.constructFromFields(
Expand All @@ -423,6 +422,29 @@ object MathFunctions extends RegistryFunctions {
)
}

// FIXME: delete when PruneDeadField can optimize fisher_exact_test when only
// the pvalue is used from the result struct
//
// Registers the IR function "fisher_exact_test_pvalue_only": four TInt32
// arguments in, a single TFloat64 out. The generated code invokes
// `fisherExactTestPValueOnly` on the stats package object and stores the
// result once so it is not recomputed.
registerSCode4(
  "fisher_exact_test_pvalue_only",
  TInt32,
  TInt32,
  TInt32,
  TInt32,
  TFloat64,
  (_, _, _, _, _) => SFloat64,
) {
  case (
        _,
        cb,
        _,
        cell1: SInt32Value,
        cell2: SInt32Value,
        cell3: SInt32Value,
        cell4: SInt32Value,
        _,
      ) =>
    // Build the call first, then memoize it and wrap it as a primitive SValue.
    val pValueCall = Code.invokeScalaObject4[Int, Int, Int, Int, Double](
      statsPackageClass,
      "fisherExactTestPValueOnly",
      cell1.value,
      cell2.value,
      cell3.value,
      cell4.value,
    )
    primitive(cb.memoize[Double](pValueCall))
}

registerSCode4(
"chi_squared_test",
TInt32,
Expand Down
108 changes: 62 additions & 46 deletions hail/src/main/scala/is/hail/stats/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package is.hail

import is.hail.types.physical.{PCanonicalStruct, PFloat64}
import is.hail.utils._

import net.sourceforge.jdistlib.{Beta, ChiSquare, NonCentralChiSquare, Normal, Poisson}
import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind}
import org.apache.commons.math3.distribution.HypergeometricDistribution

import scala.annotation.tailrec

package object stats {

def uniroot(fn: Double => Double, min: Double, max: Double, tolerance: Double = 1.220703e-4)
Expand Down Expand Up @@ -162,45 +163,28 @@ package object stats {
)

def fisherExactTest(a: Int, b: Int, c: Int, d: Int): Array[Double] =
fisherExactTest(a, b, c, d, 1.0, 0.95, "two.sided")

def fisherExactTest(
a: Int,
b: Int,
c: Int,
d: Int,
oddsRatio: Double = 1d,
confidenceLevel: Double = 0.95,
alternative: String = "two.sided",
): Array[Double] = {
fisherExactTest(a, b, c, d, 0.95)

def fisherExactTest(a: Int, b: Int, c: Int, d: Int, confidenceLevel: Double): Array[Double] = {
if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0))
fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")

if (confidenceLevel < 0d || confidenceLevel > 1d)
fatal("Confidence level must be between 0 and 1")

if (oddsRatio < 0d)
fatal("Odds ratio must be non-negative")

if (alternative != "greater" && alternative != "less" && alternative != "two.sided")
fatal("Did not recognize test type string. Use one of greater, less, two.sided")

val popSize = a + b + c + d
val numSuccessPopulation = a + c
val sampleSize = a + b
val nGood = a + c
val nSample = a + b
val numSuccessSample = a

if (
!(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize)
)
if (!(popSize > 0 && nSample > 0 && nSample < popSize && nGood > 0 && nGood < popSize))
return Array(Double.NaN, Double.NaN, Double.NaN, Double.NaN)

val low = math.max(0, (a + b) - (b + d))
val high = math.min(a + b, a + c)
val support = (low to high).toArray

val hgd = new HypergeometricDistribution(null, popSize, numSuccessPopulation, sampleSize)
val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample)
val epsilon = 2.220446e-16

def dhyper(k: Int, logProb: Boolean): Double =
Expand Down Expand Up @@ -320,36 +304,68 @@ package object stats {
}
}

val pvalue: Double = (alternative: @unchecked) match {
case "less" => pnhyper(numSuccessSample, oddsRatio)
case "greater" => pnhyper(numSuccessSample, oddsRatio, upper_tail = true)
case "two.sided" =>
if (oddsRatio == 0)
if (low == numSuccessSample) 1d else 0d
else if (oddsRatio == Double.PositiveInfinity)
if (high == numSuccessSample) 1d else 0d
else {
val relErr = 1d + 1e-7
val d = dnhyper(oddsRatio)
d.filter(_ <= d(numSuccessSample - low) * relErr).sum
}
}

assert(pvalue >= 0d && pvalue <= 1.000000000002)
val pvalue = fisherExactTestPValueOnly(a, b, c, d)

val oddsRatioEstimate = mle(numSuccessSample)

val confInterval = alternative match {
case "less" => (0d, ncpUpper(numSuccessSample, 1 - confidenceLevel))
case "greater" => (ncpLower(numSuccessSample, 1 - confidenceLevel), Double.PositiveInfinity)
case "two.sided" =>
val alpha = (1 - confidenceLevel) / 2d
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
val confInterval = {
val alpha = (1 - confidenceLevel) / 2d
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
}

Array(pvalue, oddsRatioEstimate, confInterval._1, confInterval._2)
}

/** Two-sided p-value of Fisher's exact test for the 2x2 table `[[a, b], [c, d]]`, computed
  * without the odds-ratio estimate or confidence interval.
  *
  * The p-value is the total hypergeometric probability of all tables with the same margins
  * whose probability is at most that of the observed table (up to a relative tolerance of
  * 1e-14). Because the hypergeometric pmf rises up to its mode and falls after it, the cut-off
  * on the side opposite the observation is found by binary search instead of enumerating the
  * whole support.
  *
  * A table with an empty row or column has a one-point distribution, so the result is 1.0; the
  * all-zero table is treated the same way.
  *
  * @param a count in row 1, column 1 (the observed number of sample successes)
  * @param b count in row 1, column 2
  * @param c count in row 2, column 1
  * @param d count in row 2, column 2
  * @return the two-sided p-value
  */
def fisherExactTestPValueOnly(a: Int, b: Int, c: Int, d: Int): Double = {
  // Same validation (and message) as fisherExactTest, so both entry points reject bad input
  // identically instead of surfacing an exception from the distribution constructor.
  if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0))
    fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")

  // Sum in Long: four non-negative Ints can wrap around when added as Int.
  val total = a.toLong + b + c + d
  if (total > Int.MaxValue)
    fatal(
      s"fisher_exact_test: sum of arguments must be at most ${Int.MaxValue}, got $a, $b, $c, $d"
    )

  if (total == 0L) {
    // HypergeometricDistribution requires a strictly positive population size. An empty table
    // is degenerate like any other table with a zero margin, for which the result is 1.0.
    1.0
  } else {
    val popSize = total.toInt
    val nGood = a + c
    val nSample = a + b
    val numSuccessSample = a

    val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample)
    val pmf: Int => Double = k => hgd.probability(k)

    // Returns i in [start, end] such that f([start, i)) is <= threshold, and f([i, end)) is
    // > threshold. Requires f to be non-decreasing on [start, end).
    @tailrec def upperBoundIncreasing(f: Int => Double, threshold: Double, start: Int, end: Int)
      : Int =
      if (start >= end) start
      else {
        val mid = (start + end) >>> 1
        if (f(mid) <= threshold) upperBoundIncreasing(f, threshold, mid + 1, end)
        else upperBoundIncreasing(f, threshold, start, mid)
      }

    // Returns i in [start, end] such that f([start, i)) is > threshold, and f([i, end)) is
    // <= threshold. Requires f to be non-increasing on [start, end).
    @tailrec def lowerBoundDecreasing(f: Int => Double, threshold: Double, start: Int, end: Int)
      : Int =
      if (start >= end) start
      else {
        val mid = (start + end) >>> 1
        if (f(mid) > threshold) lowerBoundDecreasing(f, threshold, mid + 1, end)
        else lowerBoundDecreasing(f, threshold, start, mid)
      }

    // Relative slack used when deciding whether two table probabilities are "equal".
    val epsilon = 1e-14
    val gamma = 1 + epsilon

    // Mode of the hypergeometric distribution: floor((n + 1)(K + 1) / (N + 2)).
    val mode = ((nSample + 1.0) * (nGood + 1.0) / (popSize + 2.0)).toInt
    val pexact = pmf(numSuccessSample)
    val pmode = pmf(mode)

    val pvalue =
      if (math.abs(pexact - pmode) / math.max(pexact, pmode) <= epsilon) {
        // The observed table is (one of) the most likely: every table counts.
        1.0
      } else if (numSuccessSample < mode) {
        // Observed value is left of the mode: take the whole lower tail, then find where the
        // decreasing right side drops to the observed probability.
        val plower = hgd.cumulativeProbability(numSuccessSample)
        val bound = lowerBoundDecreasing(pmf, pexact * gamma, mode + 1, nSample + 1)
        // upperCumulativeProbability(x) is P(X >= x).
        plower + hgd.upperCumulativeProbability(bound)
      } else {
        // Observed value is right of the mode: take the whole upper tail, then find where the
        // increasing left side first exceeds the observed probability.
        val pupper = hgd.upperCumulativeProbability(numSuccessSample)
        val bound = upperBoundIncreasing(pmf, pexact * gamma, 0, mode)
        pupper + hgd.cumulativeProbability(bound - 1)
      }

    // Allow a little floating-point overshoot above 1 from summing the two tails.
    assert(pvalue >= 0d && pvalue <= 1.000000000002)

    pvalue
  }
}

/** Normal density at `x` for mean `mu` and standard deviation `sigma`, delegating to jdistlib's
  * `Normal.density`. `logP` is forwarded unchanged; presumably it selects the log-density (as
  * in R's `dnorm`) -- confirm against jdistlib.
  */
def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double = {
  val density = Normal.density(x, mu, sigma, logP)
  density
}

Expand Down
59 changes: 54 additions & 5 deletions hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package is.hail.stats

import is.hail.HailSuite

import is.hail.utils.D_==
import org.testng.annotations.Test

class FisherExactTestSuite extends HailSuite {
Expand All @@ -14,9 +14,58 @@ class FisherExactTestSuite extends HailSuite {

val result = fisherExactTest(a, b, c, d)

assert(math.abs(result(0) - 0.2828) < 1e-4)
assert(math.abs(result(1) - 0.4754059) < 1e-4)
assert(math.abs(result(2) - 0.122593) < 1e-4)
assert(math.abs(result(3) - 1.597972) < 1e-4)
assert(D_==(result(0), 0.2828, 1e-4))
assert(D_==(result(1), 0.4754059, 1e-4))
assert(D_==(result(2), 0.122593, 1e-4))
assert(D_==(result(3), 1.597972, 1e-4))
}

/** The first element of the `fisherExactTest` result (the p-value) for the table
  * `[[10, 5], [90, 95]]` must match 0.2828 within the `D_==` tolerance 1e-4.
  */
@Test def testPvalue2(): Unit = {
  val stats = fisherExactTest(10, 5, 90, 95)
  val pValue = stats(0)

  assert(D_==(pValue, 0.2828, 1e-4))
}

/** Checks `fisherExactTestPValueOnly` against reference p-values.
  *
  * Cases are table-driven so that a failure reports which 2x2 table produced the wrong value.
  * Each table is listed once.
  */
@Test def test_basic(): Unit = {
  // test cases taken from scipy/stats/tests/test_stats.py

  // (table, expected p-value, tolerance passed to D_==)
  val explicitToleranceCases = Seq(
    ((14500, 20000, 30000, 40000), 0.01106, 1e-3),
    ((100, 2, 1000, 5), 0.1301, 1e-3),
    ((2, 7, 8, 2), 0.0230141, 1e-5),
    ((5, 1, 10, 10), 0.1973244, 1e-6),
    ((5, 15, 20, 20), 0.0958044, 1e-6),
    ((5, 16, 20, 25), 0.1725862, 1e-5),
    ((10, 5, 10, 1), 0.1973244, 1e-6),
    ((5, 0, 1, 4), 0.04761904, 1e-6),
  )
  for (((a, b, c, d), expected, tolerance) <- explicitToleranceCases) {
    val p = fisherExactTestPValueOnly(a, b, c, d)
    assert(
      D_==(p, expected, tolerance),
      s"fisherExactTestPValueOnly($a, $b, $c, $d) = $p, expected $expected (tolerance $tolerance)",
    )
  }

  // (table, expected p-value), compared with D_=='s default tolerance
  val defaultToleranceCases = Seq(
    ((0, 2, 6, 4), 0.4545454545),
    ((6, 37, 108, 200), 0.005092697748126),
    ((22, 0, 0, 102), 7.175066786244549e-25),
    ((94, 48, 3577, 16988), 2.069356340993818e-37),
  )
  for (((a, b, c, d), expected) <- defaultToleranceCases) {
    val p = fisherExactTestPValueOnly(a, b, c, d)
    assert(
      D_==(p, expected),
      s"fisherExactTestPValueOnly($a, $b, $c, $d) = $p, expected $expected",
    )
  }

  // Very large counts: the p-value must be vanishingly small rather than NaN or an error.
  val tiny = fisherExactTestPValueOnly(5829225, 5692693, 5760959, 5760959)
  assert(tiny <= 1e-170, s"expected a p-value <= 1e-170, got $tiny")

  // Tables whose p-value must be exactly 1.0: the observed table is the most likely one, or a
  // row/column of the table is empty.
  val exactlyOneCases = Seq(
    (0, 1, 3, 2),
    (0, 0, 5, 10),
    (5, 10, 0, 0),
    (0, 5, 0, 10),
    (5, 0, 10, 0),
  )
  for ((a, b, c, d) <- exactlyOneCases) {
    val p = fisherExactTestPValueOnly(a, b, c, d)
    assert(p == 1.0, s"fisherExactTestPValueOnly($a, $b, $c, $d) = $p, expected exactly 1.0")
  }
}
}