Skip to content

Commit

Permalink
Refactor partitionByKey to get clearer error
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Nov 21, 2024
1 parent 49f43b0 commit 5c1c7e5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,8 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
* @group collection
*/
def partitionByKey[U](partitionKeys: Set[U])(f: T => U): Map[U, SCollection[T]] = {
val partitionKeysIndexed = partitionKeys.toIndexedSeq

partitionKeysIndexed
.zip(partition(partitionKeys.size, (t: T) => partitionKeysIndexed.indexOf(f(t))))
.toMap
val partitions = partitionKeys.zipWithIndex.toMap
partitionKeys.zip(partition(partitionKeys.size, x => partitions(f(x)))).toMap
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import scala.jdk.CollectionConverters._
import com.spotify.scio.coders.{Beam, Coder, MaterializedCoder}
import com.spotify.scio.options.ScioOptions
import com.spotify.scio.schemas.Schema
import org.apache.beam.sdk.Pipeline.PipelineExecutionException
import org.apache.beam.sdk.coders.{NullableCoder, StringUtf8Coder}

import java.nio.charset.StandardCharsets
Expand Down Expand Up @@ -249,6 +250,16 @@ class SCollectionTest extends PipelineSpec {
m("b") should containInAnyOrder(Seq("b4", "b5"))
m("c") should containInAnyOrder(Seq("c6"))
}

val e = the[PipelineExecutionException] thrownBy {
runWithContext { sc =>
sc
.parallelize(Seq("x"))
.partitionByKey(Set("a", "b", "c"))(_.substring(0, 1))
}
}
e.getCause shouldBe a[NoSuchElementException]
e.getCause.getMessage shouldBe "key not found: x"
}

it should "support hashPartition() based on Object.hashCode()" in {
Expand Down

0 comments on commit 5c1c7e5

Please sign in to comment.