From 51c8e0deb81f5125e115d42ef25bd3ca6f84d713 Mon Sep 17 00:00:00 2001
From: zeotuan
Date: Sat, 12 Oct 2024 14:34:26 +1100
Subject: [PATCH 1/2] add randn with mean and variance

Add quinn.randn, which returns a column of i.i.d. samples from a normal
distribution with a caller-supplied mean and variance, and export it from
the package root.

---
 quinn/__init__.py  |  5 +----
 quinn/math.py      | 22 ++++++++++++++++++++++
 tests/test_math.py | 20 ++++++++++++++++++++
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/quinn/__init__.py b/quinn/__init__.py
index 2d442c7..0138ad2 100644
--- a/quinn/__init__.py
+++ b/quinn/__init__.py
@@ -52,10 +52,7 @@
     week_end_date,
     week_start_date,
 )
-from quinn.math import (
-    rand_laplace,
-    rand_range,
-)
+from quinn.math import rand_laplace, rand_range, randn
 from quinn.schema_helpers import print_schema_as_code
 from quinn.split_columns import split_col
 from quinn.transformations import (
diff --git a/quinn/math.py b/quinn/math.py
index 61f2b63..ff6410b 100644
--- a/quinn/math.py
+++ b/quinn/math.py
@@ -70,6 +70,28 @@ def rand_range(
     return minimum + (maximum - minimum) * u
 
 
+def randn(
+    mean: Union[float, Column],
+    variance: Union[float, Column],
+    seed: Optional[int] = None,
+) -> Column:
+    """Generate a column with independent and identically distributed (i.i.d.) samples from
+    a normal distribution with the given `mean` and `variance`.
+
+    :param mean: Mean of the normal distribution of the random numbers
+    :param variance: variance (not standard deviation) of the normal distribution; must be non-negative
+    :param seed: random seed value (optional, default None)
+    :returns: column with random numbers
+    """
+    if not isinstance(mean, Column):
+        mean = F.lit(mean)
+
+    if not isinstance(variance, Column):
+        variance = F.lit(variance)
+
+    return F.rand(seed) * F.sqrt(variance) + mean
+
+
 def div_or_else(
     cola: Column,
     colb: Column,
diff --git a/tests/test_math.py b/tests/test_math.py
index 87c814a..cdde6cd 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -38,3 +38,23 @@ def test_rand_range():
     uniform_max = stats["max"]
 
     assert lower_bound <= uniform_min <= uniform_max <= upper_bound
+
+
+def test_randn():
+    mean = 1.0
+    variance = 2.0
+    stats = (
+        spark.range(100_000)  # enough rows that sampling noise stays well under 0.1
+        .select(quinn.randn(mean, variance).alias("rand_normal"))
+        .agg(
+            F.mean("rand_normal").alias("agg_mean"),
+            F.variance("rand_normal").alias("agg_variance"),
+        )
+        .first()
+    )
+
+    agg_mean = stats["agg_mean"]
+    agg_variance = stats["agg_variance"]
+
+    assert abs(agg_mean - mean) <= 0.1
+    assert abs(agg_variance - variance) <= 0.1

From e3432235792ceb482eca63a7f123bc8136694804 Mon Sep 17 00:00:00 2001
From: zeotuan
Date: Sat, 12 Oct 2024 14:42:20 +1100
Subject: [PATCH 2/2] use correct random function

F.rand draws from the uniform distribution on [0, 1), so scaling and
shifting it does not yield normally distributed values. Scale and shift
draws from F.randn (standard normal) instead.

---
 quinn/math.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/quinn/math.py b/quinn/math.py
index ff6410b..980b62c 100644
--- a/quinn/math.py
+++ b/quinn/math.py
@@ -89,7 +89,7 @@ def randn(
     if not isinstance(variance, Column):
         variance = F.lit(variance)
 
-    return F.rand(seed) * F.sqrt(variance) + mean
+    return F.randn(seed) * F.sqrt(variance) + mean
 
 
 def div_or_else(