
Commit a652038
Update the bucketFinder function so users can customize whether the lowest bound uses lt or lte and whether the highest bound uses gt or gte.
MrPowers committed Jun 22, 2018
1 parent beb1a90 commit a652038
Showing 2 changed files with 151 additions and 115 deletions.
71 changes: 34 additions & 37 deletions src/main/scala/com/github/mrpowers/spark/daria/sql/functions.scala
@@ -330,47 +330,44 @@ object functions {
   def bucketFinder(
       col: Column,
       buckets: Array[(Any, Any)],
-      inclusiveBoundries: Boolean = false
+      inclusiveBoundries: Boolean = false,
+      lowestBoundLte: Boolean = false,
+      hightestBoundGte: Boolean = false
   ): Column = {
 
-    val b = if (inclusiveBoundries) {
-      buckets.map { res: (Any, Any) =>
-        when(
-          col.isNull,
-          lit(null)
-        )
-          .when(
-            lit(res._1).isNull && lit(res._2).isNotNull && col < lit(res._2),
-            lit(s"<${res._2}")
-          )
-          .when(
-            lit(res._1).isNotNull && lit(res._2).isNull && col > lit(res._1),
-            lit(s">${res._1}")
-          )
-          .when(
-            col.geq(res._1) && col.leq(res._2),
-            lit(s"${res._1}-${res._2}")
-          )
-      }
-    } else {
-      buckets.map { res: (Any, Any) =>
-        when(
-          col.isNull,
-          lit(null)
-        )
-          .when(
-            lit(res._1).isNull && lit(res._2).isNotNull && col < lit(res._2),
-            lit(s"<${res._2}")
-          )
-          .when(
-            lit(res._1).isNotNull && lit(res._2).isNull && col > lit(res._1),
-            lit(s">${res._1}")
-          )
-          .when(
-            col.between(res._1, res._2),
-            lit(s"${res._1}-${res._2}")
-          )
-      }
-    }
+    val inclusiveBoundriesCol = lit(inclusiveBoundries)
+    val lowerBoundLteCol = lit(lowestBoundLte)
+    val upperBoundGteCol = lit(hightestBoundGte)
+
+    val b = buckets.map { res: (Any, Any) =>
+      when(
+        col.isNull,
+        lit(null)
+      )
+        .when(
+          lowerBoundLteCol === false && lit(res._1).isNull && lit(res._2).isNotNull && col < lit(res._2),
+          lit(s"<${res._2}")
+        )
+        .when(
+          lowerBoundLteCol === true && lit(res._1).isNull && lit(res._2).isNotNull && col <= lit(res._2),
+          lit(s"<=${res._2}")
+        )
+        .when(
+          upperBoundGteCol === false && lit(res._1).isNotNull && lit(res._2).isNull && col > lit(res._1),
+          lit(s">${res._1}")
+        )
+        .when(
+          upperBoundGteCol === true && lit(res._1).isNotNull && lit(res._2).isNull && col >= lit(res._1),
+          lit(s">=${res._1}")
+        )
+        .when(
+          inclusiveBoundriesCol === true && col.geq(res._1) && col.leq(res._2),
+          lit(s"${res._1}-${res._2}")
+        )
+        .when(
+          inclusiveBoundriesCol === false && col.between(res._1, res._2),
+          lit(s"${res._1}-${res._2}")
+        )
+    }
 
     coalesce(b: _*)
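As a quick illustration of the updated API, here is a minimal usage sketch written in the style of the tests below. It assumes the same test helpers the suite uses (a spark session with spark-daria's createDF and assertColumnEquality); the age column and the bucket boundaries are invented for illustration:

val sketchDF = spark.createDF(
  List(
    // hypothetical data for illustration
    (5, "<=10"),   // open lower bucket, labeled "<=10" because lowestBoundLte = true
    (30, "10-60"), // ordinary bucket, inclusive because inclusiveBoundries = true
    (80, ">=60"),  // open upper bucket, labeled ">=60" because hightestBoundGte = true
    (null, null)   // nulls stay null
  ),
  List(
    ("age", IntegerType, true),
    ("expected", StringType, true)
  )
).withColumn(
  "age_bucket",
  functions.bucketFinder(
    col("age"),
    Array(
      (null, 10), // no lower bound
      (10, 60),
      (60, null)  // no upper bound
    ),
    inclusiveBoundries = true,
    lowestBoundLte = true,
    hightestBoundGte = true
  )
)

assertColumnEquality(sketchDF, "expected", "age_bucket")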
195 changes: 117 additions & 78 deletions src/test/scala/com/github/mrpowers/spark/daria/sql/FunctionsTest.scala
@@ -455,90 +455,129 @@ object FunctionsTest

The bucketFinder tests after this commit. The first two tests are carried over; the new "works with a highly customized use case" test exercises the lowestBoundLte and hightestBoundGte parameters:

  'bucketFinder - {

    "finds what bucket a column value belongs in" - {

      val df = spark.createDF(
        List(
          // works for standard use cases
          (24, "20-30"),
          (45, "30-60"),
          // works with range boundries
          (10, "10-20"),
          (20, "10-20"),
          // works with less than / greater than
          (3, "<10"),
          (99, ">70"),
          // works for numbers that don't fall in any buckets
          (65, null),
          // works with null
          (null, null)
        ),
        List(
          ("some_num", IntegerType, true),
          ("expected", StringType, true)
        )
      ).withColumn(
        "bucket",
        functions.bucketFinder(
          col("some_num"),
          Array(
            (null, 10),
            (10, 20),
            (20, 30),
            (30, 60),
            (70, null)
          )
        )
      )

      assertColumnEquality(df, "expected", "bucket")

    }

    "can use inclusive bucket ranges" - {

      val df = spark.createDF(
        List(
          // works for standard use cases
          (15, "10-20"),
          // works with range boundries
          (10, "10-20"),
          (20, "10-20"),
          (50, "41-50"),
          (40, "31-40"),
          // works with less than / greater than
          (9, "<10"),
          (72, ">70"),
          // works for numbers that don't fall in any bucket
          (65, null),
          // works with null
          (null, null)
        ),
        List(
          ("some_num", IntegerType, true),
          ("expected", StringType, true)
        )
      ).withColumn(
        "bucket",
        functions.bucketFinder(
          col("some_num"),
          Array(
            (null, 10),
            (10, 20),
            (21, 30),
            (31, 40),
            (41, 50),
            (70, null)
          ),
          inclusiveBoundries = true
        )
      )

      assertColumnEquality(df, "expected", "bucket")

    }

    "works with a highly customized use case" - {

      val df = spark.createDF(
        List(
          (0, "<1"),
          (1, "1-1"),
          (2, "2-4"),
          (3, "2-4"),
          (4, "2-4"),
          (10, "5-74"),
          (75, ">=75"),
          (90, ">=75"),
          (null, null)
        ),
        List(
          ("some_num", IntegerType, true),
          ("expected", StringType, true)
        )
      ).withColumn(
        "bucket",
        functions.bucketFinder(
          col("some_num"),
          Array(
            (null, 1),
            (1, 1),
            (2, 4),
            (5, 74),
            (75, null)
          ),
          inclusiveBoundries = true,
          lowestBoundLte = false,
          hightestBoundGte = true
        )
      )

      assertColumnEquality(df, "expected", "bucket")

    }

  }

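One behavior worth calling out, because it is easy to miss in the implementation: each bucket compiles into its own when chain, and coalesce(b: _*) returns the first non-null label, so a value that sits on a shared boundary gets the label of the earlier bucket in the array. The first test above already relies on this: 20 satisfies both (10, 20) and (20, 30) and is labeled "10-20". A minimal sketch isolating that case, using the same hypothetical helpers as the tests:

val overlapDF = spark.createDF(
  List(
    (20, "10-20") // matches (10, 20) first, so (20, 30) never labels it
  ),
  List(
    ("some_num", IntegerType, true),
    ("expected", StringType, true)
  )
).withColumn(
  "bucket",
  functions.bucketFinder(
    col("some_num"),
    Array(
      (10, 20),
      (20, 30)
    )
  )
)

assertColumnEquality(overlapDF, "expected", "bucket")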
