
Commit

alsotestwith
ericm-db committed Nov 22, 2024
1 parent d45ea2a commit feaca20
Showing 8 changed files with 56 additions and 72 deletions.
@@ -141,25 +141,6 @@ trait AlsoTestWithChangelogCheckpointingEnabled
}
}

-  def testWithEncodingTypes(
-      testName: String,
-      testTags: Tag*)
-      (testBody: => Any): Unit = {
-    Seq("unsaferow", "avro").foreach { encoding =>
-      super.test(testName + s" (encoding = $encoding)", testTags: _*) {
-        // in case tests have any code that needs to execute before every test
-        // super.beforeEach()
-        withSQLConf(
-          SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key ->
-            encoding) {
-          testBody
-        }
-        // in case tests have any code that needs to execute after every test
-        // super.afterEach()
-      }
-    }
-  }

  def testWithColumnFamilies(
      testName: String,
      testMode: TestMode,
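Note: the AlsoTestWithEncodingTypes trait that the suites below now mix in is defined outside the hunks shown on this page. Based on the helper removed above, here is a minimal sketch of what it presumably looks like; the SQLTestUtils base, the imports, and the implicit Position parameter are assumptions, mirroring how Spark test traits usually override test:

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils

trait AlsoTestWithEncodingTypes extends SQLTestUtils {
  // Re-register every test once per state store encoding format, so each
  // suite no longer needs the per-suite testWithEncodingTypes helper.
  override protected def test(testName: String, testTags: Tag*)(testBody: => Any)
      (implicit pos: Position): Unit = {
    Seq("unsaferow", "avro").foreach { encoding =>
      super.test(testName + s" (encoding = $encoding)", testTags: _*) {
        withSQLConf(
          SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) {
          testBody
        }
      }
    }
  }
}

With such a mixin, a plain test(...) in any suite that extends the trait runs once per encoding, which is why the calls in the files below drop testWithEncodingTypes in favor of test.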
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf

case class InputRow(key: String, action: String, value: String)
@@ -127,10 +127,11 @@ class ToggleSaveAndEmitProcessor
}

class TransformWithListStateSuite extends StreamTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes {
import testImplicits._

testWithEncodingTypes("test appending null value in list state throw exception") {
test("test appending null value in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -150,7 +151,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test putting null value in list state throw exception") {
test("test putting null value in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -170,7 +171,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test putting null list in list state throw exception") {
test("test putting null list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -190,7 +191,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test appending null list in list state throw exception") {
test("test appending null list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -210,7 +211,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test putting empty list in list state throw exception") {
test("test putting empty list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -230,7 +231,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test appending empty list in list state throw exception") {
test("test appending empty list in list state throw exception") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -250,7 +251,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test list state correctness") {
test("test list state correctness") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -307,7 +308,7 @@ class TransformWithListStateSuite
}
}

testWithEncodingTypes("test ValueState And ListState in Processor") {
test("test ValueState And ListState in Processor") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -105,7 +105,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest {

override def getStateTTLMetricName: String = "numListStateWithTTLVars"

testWithEncodingTypes("verify iterator works with expired values in beginning of list") {
test("verify iterator works with expired values in beginning of list") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -195,7 +195,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest {
// ascending order of TTL by stopping the query, setting the new TTL, and restarting
// the query to check that the expired elements in the middle or end of the list
// are not returned.
testWithEncodingTypes("verify iterator works with expired values in middle of list") {
test("verify iterator works with expired values in middle of list") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -343,7 +343,7 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest {
}
}

testWithEncodingTypes("verify iterator works with expired values in end of list") {
test("verify iterator works with expired values in end of list") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf

case class InputMapRow(key: String, action: String, value: (String, String))
@@ -81,7 +81,8 @@ class TestMapStateProcessor
* operators such as transformWithState.
*/
class TransformWithMapStateSuite extends StreamTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes {
import testImplicits._

private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = {
@@ -110,7 +111,7 @@ class TransformWithMapStateSuite extends StreamTest
}
}

testWithEncodingTypes("Test retrieving value with non-existing user key") {
test("Test retrieving value with non-existing user key") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -129,12 +130,12 @@ class TransformWithMapStateSuite extends StreamTest
}

Seq("getValue", "containsKey", "updateValue", "removeKey").foreach { mapImplFunc =>
testWithEncodingTypes(s"Test $mapImplFunc with null user key") {
test(s"Test $mapImplFunc with null user key") {
testMapStateWithNullUserKey(InputMapRow("k1", mapImplFunc, (null, "")))
}
}

testWithEncodingTypes("Test put value with null value") {
test("Test put value with null value") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {

@@ -158,7 +159,7 @@ class TransformWithMapStateSuite extends StreamTest
}
}

testWithEncodingTypes("Test map state correctness") {
test("Test map state correctness") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName) {
val inputData = MemoryStream[InputMapRow]
@@ -219,7 +220,7 @@ class TransformWithMapStateSuite extends StreamTest
}
}

testWithEncodingTypes("transformWithMapState - batch should succeed") {
test("transformWithMapState - batch should succeed") {
val inputData = Seq(
InputMapRow("k1", "updateValue", ("v1", "10")),
InputMapRow("k1", "getValue", ("v1", "")))
@@ -182,7 +182,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest {

override def getStateTTLMetricName: String = "numMapStateWithTTLVars"

testWithEncodingTypes("validate state is evicted with multiple user keys") {
test("validate state is evicted with multiple user keys") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -224,7 +224,7 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest {
}
}

testWithEncodingTypes("verify iterator doesn't return expired keys") {
test("verify iterator doesn't return expired keys") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
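With the mixin applied, each test(...) in these suites presumably registers one case per encoding format, named the same way the removed helper named them, e.g. "test map state correctness (encoding = unsaferow)" and "test map state correctness (encoding = avro)".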
