From 2c0d45afc697ffafd6e735e6100e7c20ae1ea1fc Mon Sep 17 00:00:00 2001 From: milin-k Date: Tue, 17 Jul 2018 15:06:40 +0530 Subject: [PATCH] bucketFinder - fixed issue with inclusive boundary --- .../mrpowers/spark/daria/sql/functions.scala | 4 +-- .../spark/daria/sql/FunctionsTest.scala | 35 +++++++++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala index fe61e254..d71fa1c5 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala @@ -361,11 +361,11 @@ object functions { lit(s">=${res._1}") ) .when( - inclusiveBoundriesCol === true && col.geq(res._1) && col.leq(res._2), + inclusiveBoundriesCol === true && col.between(res._1, res._2), lit(s"${res._1}-${res._2}") ) .when( - inclusiveBoundriesCol === false && col.between(res._1, res._2), + inclusiveBoundriesCol === false && col.gt(res._1) && col.lt(res._2), lit(s"${res._1}-${res._2}") ) } diff --git a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala index 295bdfec..fc9da17f 100644 --- a/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala @@ -491,7 +491,8 @@ object FunctionsTest (20, 30), (30, 60), (70, null) - ) + ), + inclusiveBoundries = true ) ) @@ -581,8 +582,36 @@ object FunctionsTest } - } + "works with a highly customized use case" - { - } + val df = spark.createDF( + List( + (0, "<1"), + (10, "1-11"), + (11, ">=11") + ), + List( + ("some_num", IntegerType, true), + ("expected", StringType, true) + ) + ).withColumn( + "bucket", + functions.bucketFinder( + col("some_num"), + Array( + (null, 1), + (1, 11), + (11, null) + ), + inclusiveBoundries = false, + hightestBoundGte = true + ) + ) + + assertColumnEquality(df, "expected", "bucket") + } + + } + } }