diff --git a/src/main/scala/kinesis/mock/api/ListShardsRequest.scala b/src/main/scala/kinesis/mock/api/ListShardsRequest.scala index 91312ab4..e4a17534 100644 --- a/src/main/scala/kinesis/mock/api/ListShardsRequest.scala +++ b/src/main/scala/kinesis/mock/api/ListShardsRequest.scala @@ -64,7 +64,7 @@ final case class ListShardsRequest( ListShardsRequest .createNextToken(streamName, shards.last.shardId.shardId) ) - ListShardsResponse(nextToken, shards) + ListShardsResponse(nextToken, shards.map(ShardSummary.fromShard)) }) } case (_, None, _, _, Some(sName)) => @@ -150,7 +150,7 @@ final case class ListShardsRequest( ListShardsRequest .createNextToken(sName, shards.last.shardId.shardId) ) - ListShardsResponse(nextToken, shards) + ListShardsResponse(nextToken, shards.map(ShardSummary.fromShard)) }) ) case (_, None, _, _, None) => diff --git a/src/main/scala/kinesis/mock/api/ListShardsResponse.scala b/src/main/scala/kinesis/mock/api/ListShardsResponse.scala index 59981575..07fe11ec 100644 --- a/src/main/scala/kinesis/mock/api/ListShardsResponse.scala +++ b/src/main/scala/kinesis/mock/api/ListShardsResponse.scala @@ -8,7 +8,7 @@ import kinesis.mock.models._ final case class ListShardsResponse( nextToken: Option[String], - shards: List[Shard] + shards: List[ShardSummary] ) object ListShardsResponse { @@ -18,7 +18,7 @@ object ListShardsResponse { x => for { nextToken <- x.downField("NextToken").as[Option[String]] - shards <- x.downField("Shards").as[List[Shard]] + shards <- x.downField("Shards").as[List[ShardSummary]] } yield ListShardsResponse(nextToken, shards) implicit val listShardsResponseEq: Eq[ListShardsResponse] = (x, y) => x.nextToken == y.nextToken && x.shards === y.shards diff --git a/src/main/scala/kinesis/mock/models/ShardSummary.scala b/src/main/scala/kinesis/mock/models/ShardSummary.scala index 65cefbf5..4b3a5bf2 100644 --- a/src/main/scala/kinesis/mock/models/ShardSummary.scala +++ b/src/main/scala/kinesis/mock/models/ShardSummary.scala @@ -9,7 +9,9 @@ final case class ShardSummary( parentShardId: Option[String], sequenceNumberRange: SequenceNumberRange, shardId: String -) +) { + val isOpen: Boolean = sequenceNumberRange.endingSequenceNumber.isEmpty +} object ShardSummary { def fromShard(shard: Shard): ShardSummary = ShardSummary( diff --git a/src/main/scala/kinesis/mock/validations/CommonValidations.scala b/src/main/scala/kinesis/mock/validations/CommonValidations.scala index 48f9039d..6fd9b643 100644 --- a/src/main/scala/kinesis/mock/validations/CommonValidations.scala +++ b/src/main/scala/kinesis/mock/validations/CommonValidations.scala @@ -316,36 +316,40 @@ object CommonValidations { explicitHashKey: Option[String], stream: StreamData ): ValidatedNel[KinesisMockException, (Shard, List[KinesisRecord])] = { - Try( - Md5Utils.computeMD5Hash( - explicitHashKey - .getOrElse(partitionKey) - .getBytes(StandardCharsets.US_ASCII) - ) - ).toValidated - .leftMap(e => - NonEmptyList.one( - InvalidArgumentException( - s"Could not compute MD5 hash, ${e.getMessage}" - ) - ) - ) - .andThen { hashBytes => - val hashInt = BigInt.apply(1, hashBytes) - - stream.shards - .collectFirst { - case (shard, data) - if hashInt >= shard.hashKeyRange.startingHashKey && hashInt <= shard.hashKeyRange.endingHashKey => - (shard, data) - } match { - case None => - InvalidArgumentException( - "Could not find shard for partitionKey" - ).invalidNel - case Some(x) => Valid(x) + (explicitHashKey match { + case Some(ehk) => + val hash = BigInt(ehk) + if (hash < Shard.minHashKey || hash > Shard.maxHashKey) { + InvalidArgumentException("ExplicitHashKey is not valid").invalidNel + } else { + hash.validNel } + case None => + Try( + Md5Utils.computeMD5Hash(partitionKey.getBytes(StandardCharsets.UTF_8)) + ).toValidated.bimap( + e => + NonEmptyList.one( + InvalidArgumentException( + s"Could not compute MD5 hash, ${e.getMessage}" + ) + ), + x => BigInt(1, x) + ) + }).andThen { hashInt => + stream.shards + .collectFirst { + case (shard, data) + if shard.isOpen && hashInt >= shard.hashKeyRange.startingHashKey && hashInt <= shard.hashKeyRange.endingHashKey => + (shard, data) + } match { + case None => + InvalidArgumentException( + "Could not find shard for partitionKey" + ).invalidNel + case Some(x) => Valid(x) } + } } def validateExplicitHashKey( diff --git a/src/test/scala/kinesis/mock/api/ListShardsTests.scala b/src/test/scala/kinesis/mock/api/ListShardsTests.scala index 07732c84..06cb7547 100644 --- a/src/test/scala/kinesis/mock/api/ListShardsTests.scala +++ b/src/test/scala/kinesis/mock/api/ListShardsTests.scala @@ -25,7 +25,7 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => streams.streams.get(streamName).exists { s => - s.shards.keys.toList == response.shards + s.shards.keys.toList.map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\nres: $res" }) @@ -50,11 +50,15 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && paginatedRes.isValid && res.exists { response => streams.streams.get(streamName).exists { s => - s.shards.keys.toList.take(50) == response.shards + s.shards.keys.toList + .take(50) + .map(ShardSummary.fromShard) == response.shards } } && paginatedRes.exists { response => streams.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(50) == response.shards + s.shards.keys.toList + .takeRight(50) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"resCount: ${res.map(_.shards.length)}\n" + @@ -86,7 +90,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => streams.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(89) == response.shards + s.shards.keys.toList + .takeRight(89) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\nres: $res" }) @@ -130,7 +136,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => updated.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(95) == response.shards + s.shards.keys.toList + .takeRight(95) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"res: ${res.map(_.shards.length)}\n" + @@ -176,7 +184,7 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => updated.streams.get(streamName).exists { s => - s.shards.keys.toList == response.shards + s.shards.keys.toList.map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"res: ${res.map(_.shards.length)}\n" + @@ -225,7 +233,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => updated.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(95) == response.shards + s.shards.keys.toList + .takeRight(95) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"res: ${res.map(_.shards.length)}\n" + @@ -258,7 +268,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => streams.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(95) == response.shards + s.shards.keys.toList + .takeRight(95) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"resLen: ${res.map(_.shards.length)}\n" + @@ -327,7 +339,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => updated.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(95) == response.shards + s.shards.keys.toList + .takeRight(95) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"res: ${res.map(_.shards.length)}\n" + @@ -395,7 +409,9 @@ class ListShardsTests extends munit.ScalaCheckSuite { (res.isValid && res.exists { response => updated.streams.get(streamName).exists { s => - s.shards.keys.toList.takeRight(95) == response.shards + s.shards.keys.toList + .takeRight(95) + .map(ShardSummary.fromShard) == response.shards } }) :| s"req: $req\n" + s"res: ${res.map(_.shards.length)}\n" + diff --git a/src/test/scala/kinesis/mock/cache/DescribeStreamTests.scala b/src/test/scala/kinesis/mock/cache/DescribeStreamTests.scala index ef4b39ae..7e10d2c2 100644 --- a/src/test/scala/kinesis/mock/cache/DescribeStreamTests.scala +++ b/src/test/scala/kinesis/mock/cache/DescribeStreamTests.scala @@ -48,7 +48,7 @@ class DescribeStreamTests context ) .rethrow - .map(x => x.shards.map(ShardSummary.fromShard)) + .map(x => x.shards) expected = StreamDescription( Some(EncryptionType.NONE), List(ShardLevelMetrics(List.empty)), diff --git a/src/test/scala/kinesis/mock/cache/GetRecordsTests.scala b/src/test/scala/kinesis/mock/cache/GetRecordsTests.scala index f3be9107..bb000f8f 100644 --- a/src/test/scala/kinesis/mock/cache/GetRecordsTests.scala +++ b/src/test/scala/kinesis/mock/cache/GetRecordsTests.scala @@ -52,7 +52,7 @@ class GetRecordsTests shardIterator <- cache .getShardIterator( GetShardIteratorRequest( - shard.shardId.shardId, + shard.shardId, ShardIteratorType.TRIM_HORIZON, None, streamName, diff --git a/src/test/scala/kinesis/mock/cache/GetShardIteratorTests.scala b/src/test/scala/kinesis/mock/cache/GetShardIteratorTests.scala index 3d5bb56b..b5ac255b 100644 --- a/src/test/scala/kinesis/mock/cache/GetShardIteratorTests.scala +++ b/src/test/scala/kinesis/mock/cache/GetShardIteratorTests.scala @@ -42,7 +42,7 @@ class GetShardIteratorTests res <- cache .getShardIterator( GetShardIteratorRequest( - shard.shardId.shardId, + shard.shardId, ShardIteratorType.TRIM_HORIZON, None, streamName, diff --git a/src/test/scala/kinesis/mock/cache/PutRecordTests.scala b/src/test/scala/kinesis/mock/cache/PutRecordTests.scala index 1dae306d..b185d811 100644 --- a/src/test/scala/kinesis/mock/cache/PutRecordTests.scala +++ b/src/test/scala/kinesis/mock/cache/PutRecordTests.scala @@ -52,7 +52,7 @@ class PutRecordTests shardIterator <- cache .getShardIterator( GetShardIteratorRequest( - shard.shardId.shardId, + shard.shardId, ShardIteratorType.TRIM_HORIZON, None, streamName, diff --git a/src/test/scala/kinesis/mock/cache/PutRecordsTests.scala b/src/test/scala/kinesis/mock/cache/PutRecordsTests.scala index 3f1983b6..2e5e9f97 100644 --- a/src/test/scala/kinesis/mock/cache/PutRecordsTests.scala +++ b/src/test/scala/kinesis/mock/cache/PutRecordsTests.scala @@ -52,7 +52,7 @@ class PutRecordsTests shardIterator <- cache .getShardIterator( GetShardIteratorRequest( - shard.shardId.shardId, + shard.shardId, ShardIteratorType.TRIM_HORIZON, None, streamName, diff --git a/src/test/scala/kinesis/mock/cache/SplitShardTests.scala b/src/test/scala/kinesis/mock/cache/SplitShardTests.scala index b8f95d5e..06039bae 100644 --- a/src/test/scala/kinesis/mock/cache/SplitShardTests.scala +++ b/src/test/scala/kinesis/mock/cache/SplitShardTests.scala @@ -48,7 +48,7 @@ class SplitShardTests .splitShard( SplitShardRequest( (shardToSplit.hashKeyRange.endingHashKey / BigInt(2)).toString, - shardToSplit.shardId.shardId, + shardToSplit.shardId, streamName ), context @@ -68,7 +68,7 @@ class SplitShardTests checkStream2.streamDescriptionSummary.streamStatus == StreamStatus.ACTIVE && checkShards.shards.count(!_.isOpen) == 1 && checkShards.shards.count(shard => - shard.parentShardId.contains(shardToSplit.shardId.shardId) + shard.parentShardId.contains(shardToSplit.shardId) ) == 2 && checkShards.shards.length == 7, s"${checkShards.shards.mkString("\n\t")}\n" + s"$checkStream1\n" + diff --git a/src/test/scala/kinesis/mock/instances/arbitrary.scala b/src/test/scala/kinesis/mock/instances/arbitrary.scala index bf422afc..091a8af3 100644 --- a/src/test/scala/kinesis/mock/instances/arbitrary.scala +++ b/src/test/scala/kinesis/mock/instances/arbitrary.scala @@ -181,6 +181,9 @@ object arbitrary { shard ) + def shardSummaryGen(shardIndex: Int): Gen[ShardSummary] = + shardGen(shardIndex).map(ShardSummary.fromShard) + implicit val shardArbitrary: Arbitrary[Shard] = Arbitrary( Gen.choose(100, 1000).flatMap(index => shardGen(index)) ) @@ -610,8 +613,8 @@ object arbitrary { implicit val listShardsResponseArb: Arbitrary[ListShardsResponse] = Arbitrary( for { nextToken <- Gen.option(nextTokenGen(None)) - shards <- Gen.sequence[List[Shard], Shard]( - List.range(0, 100).map(x => shardGen(x)) + shards <- Gen.sequence[List[ShardSummary], ShardSummary]( + List.range(0, 100).map(x => shardSummaryGen(x)) ) } yield ListShardsResponse(nextToken, shards) ) @@ -687,7 +690,8 @@ object arbitrary { ) ) - val explicitHashKeyGen: Gen[String] = RegexpGen.from("0|([1-9]\\d{0,38})") + val explicitHashKeyGen: Gen[String] = + Gen.choose(Shard.minHashKey, Shard.maxHashKey).map(_.toString) val partitionKeyGen: Gen[String] = Gen.choose(1, 256).flatMap(size => Gen.stringOfN(size, Gen.alphaNumChar))