Skip to content

Commit

Permalink
Merge pull request #3 from reibitto/snapshot/dryRun
Browse files Browse the repository at this point in the history
Prepare for new release
  • Loading branch information
reibitto authored Mar 27, 2024
2 parents 008fcf0 + a078d07 commit 52ea166
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
branches:
- master
- main
- snapshot/
- snapshot/*
tags: ["*"]
jobs:
publish:
Expand Down
28 changes: 21 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ bit easier for you.
Add the following to `project/plugins.sbt`:

```scala
addSbtPlugin("com.github.reibitto" % "sbt-test-shards" % "0.1.0")
addSbtPlugin("com.github.reibitto" % "sbt-test-shards" % "0.2.0")
```

## Configuration
Expand Down Expand Up @@ -72,19 +72,33 @@ shardingAlgorithm := ShardingAlgorithm.Balance(
)
```

As you can see, filling this out manually would be tedious. Ideally you'd want to derive
this data structure from a test report. If that's not an option, you could also get away
with only including your slowest test suites in this list and leave the rest to the fallback
sharding algorithm.
As you can see, filling this out manually would be tedious and would require constant maintenance
as you add/remove tests (particularly if the tests are expensive). sbt automatically generates
test report XML files (in a JUnit-compatible format) when tests are run, and sbt-test-shards can consume
these reports so that you don't have to maintain this list by hand. Example usage:

Eventually this plugin will be able to consume test reports itself so that you won't have to
worry about it at all.
```scala
shardingAlgorithm := ShardingAlgorithm.Balance.fromJUnitReports(
Seq(Paths.get(s"path-to-report-files")), // these will usually be located in the `target` folders
shardsInfo = ShardingInfo(testShardCount.value)
)
```

For test reports to exist, you first have to run `sbt test` on your entire project. These files also
won't exist in your CI environment unless you cache or store them somewhere.
I'd recommend storing them remotely, pulling them down in CI before running the tests, and, upon
successful CI completion, publishing the newly generated test reports back to keep them up to date.
The remote store can be anything, such as S3, or even an artifact that bundles the reports as resources
and is published to a private Maven repo.

### Additional configuration

If you're debugging and want to see logs in CI of which suites are set to run and which
are skipped, you can use `testShardDebug := true`.

You can also run `testDryRun` to see how each suite will be distributed without actually
running all the tests and waiting for them to complete.

## CI Configuration

### GitHub Actions
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/sbttestshards/ShardResult.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package sbttestshards

/** Outcome of a sharding decision for a single suite.
 *
 * @param testShard
 *   index of the shard the suite was assigned to, or `None` when it was not
 *   assigned to any shard
 */
final case class ShardResult(
  testShard: Option[Int]
)
75 changes: 55 additions & 20 deletions src/main/scala/sbttestshards/ShardingAlgorithm.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sbttestshards

import sbttestshards.parsers.FullTestReport
import sbttestshards.parsers.JUnitReportParser

import java.nio.charset.StandardCharsets
Expand All @@ -10,8 +11,19 @@ import scala.util.hashing.MurmurHash3
// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
trait ShardingAlgorithm {

/** Determines whether the specified spec will run on this shard or not. */
def shouldRun(specName: String, shardContext: ShardContext): Boolean
/** Prior test report that can be used by some sharding algorithms to optimize
* balancing tests across different shards.
*/
def priorReport: Option[FullTestReport]

/** Returns the result of whether the specified suite will run on this shard
* or not.
*/
def check(suiteName: String, shardContext: ShardContext): ShardResult

/** Determines whether the specified suite will run on this shard or not. */
def shouldRun(suiteName: String, shardContext: ShardContext): Boolean =
check(suiteName, shardContext).testShard.contains(shardContext.testShard)
}

object ShardingAlgorithm {
Expand All @@ -21,24 +33,37 @@ object ShardingAlgorithm {
*/
final case object SuiteName extends ShardingAlgorithm {

override def shouldRun(specName: String, shardContext: ShardContext): Boolean =
MurmurHash3
.bytesHash(specName.getBytes(StandardCharsets.UTF_8))
.abs % shardContext.testShardCount == shardContext.testShard
/** Assigns the suite to a shard derived from a MurmurHash3 hash of its UTF-8
 * encoded name, taken modulo `shardContext.testShardCount`.
 *
 * The hash is widened to `Long` before taking the absolute value because
 * `Int.MinValue.abs` is still negative, which would otherwise produce a
 * negative shard index that no shard ever matches. For every other hash value
 * the resulting shard is the same as with a plain `Int` `abs`.
 */
def check(suiteName: String, shardContext: ShardContext): ShardResult = {
  val hash = MurmurHash3.bytesHash(suiteName.getBytes(StandardCharsets.UTF_8))
  val testShard = (hash.toLong.abs % shardContext.testShardCount).toInt

  ShardResult(Some(testShard))
}

def priorReport: Option[FullTestReport] = None
}

/** Will always mark the test to run on this shard. Useful for debugging or
* for fallback algorithms.
*/
final case object Always extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true

// Echoes the current shard back as the assignment, so `shouldRun` holds on every shard.
def check(suiteName: String, shardContext: ShardContext): ShardResult = {
  val currentShard = shardContext.testShard

  ShardResult(testShard = Some(currentShard))
}

def priorReport: Option[FullTestReport] = None
}

/** Will never mark the test to run on this shard. Useful for debugging or for
* fallback algorithms.
*/
final case object Never extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false

// No shard is ever assigned, so `shouldRun` is false whatever the shard context is.
def check(suiteName: String, shardContext: ShardContext): ShardResult = {
  val unassigned: Option[Int] = None

  ShardResult(testShard = unassigned)
}

def priorReport: Option[FullTestReport] = None
}

object Balance {
Expand All @@ -47,28 +72,33 @@ object ShardingAlgorithm {
reportDirectories: Seq[Path],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
): Balance =
): Balance = {
val priorReport = JUnitReportParser.parseDirectoriesRecursively(reportDirectories)

ShardingAlgorithm.Balance(
JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r =>
SpecInfo(r.name, Some(Duration.ofMillis((r.timeTaken * 1000).toLong)))
priorReport.testReports.map { r =>
SuiteInfo(r.name, Some(Duration.ofMillis((r.timeTaken * 1000).toLong)))
},
shardsInfo,
fallbackShardingAlgorithm
fallbackShardingAlgorithm,
Some(priorReport)
)
}
}

/** Attempts to balance the shards by execution time so that no one shard
* takes significantly longer to complete than another.
*/
final case class Balance(
specs: Seq[SpecInfo],
suites: Seq[SuiteInfo],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName,
priorReport: Option[FullTestReport] = None
) extends ShardingAlgorithm {

// TODO: Median might be better here?
private val averageTime: Option[Duration] = {
val allTimeTaken = specs.flatMap(_.timeTaken)
val allTimeTaken = suites.flatMap(_.timeTaken)

allTimeTaken.reduceOption(_.plus(_)).map { d =>
if (allTimeTaken.isEmpty) Duration.ZERO
Expand All @@ -80,7 +110,7 @@ object ShardingAlgorithm {
// is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be
// possible here.
def distributeEvenly: Map[TestSuiteInfoSimple, Int] = {
val allTests = specs
val allTests = suites
.map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO))))
.sortBy(_.timeTaken)(Orderings.duration.reverse)

Expand Down Expand Up @@ -109,10 +139,15 @@ object ShardingAlgorithm {
k.name -> v
}

def shouldRun(specName: String, shardContext: ShardContext): Boolean =
bucketMap.get(specName) match {
case Some(bucketIndex) => bucketIndex == shardContext.testShard
case None => fallbackShardingAlgorithm.shouldRun(specName, shardContext)
/** Looks the suite up in `bucketMap` and reports the bucket index found there
 * as its shard. Suites that are absent from `bucketMap` are delegated to
 * `fallbackShardingAlgorithm` after a warning is logged.
 */
def check(suiteName: String, shardContext: ShardContext): ShardResult =
  bucketMap
    .get(suiteName)
    .map(assignedBucket => ShardResult(Some(assignedBucket)))
    .getOrElse {
      // Only evaluated when the suite has no bucket.
      shardContext.logger.warn(s"Using fallback algorithm for $suiteName")

      fallbackShardingAlgorithm.check(suiteName, shardContext)
    }
}

Expand Down
5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/SpecBucketItem.scala

This file was deleted.

5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/SpecInfo.scala

This file was deleted.

5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SuiteBucketItem.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

/** A suite name paired with a duration.
 *
 * @param name
 *   name of the suite
 * @param timeTaken
 *   time attributed to the suite
 */
final case class SuiteBucketItem(
  name: String,
  timeTaken: Duration
)
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SuiteInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

/** Timing information about a test suite, as consumed by
 * `ShardingAlgorithm.Balance`.
 *
 * @param name
 *   name of the suite
 * @param timeTaken
 *   how long the suite took, if known; `Balance` substitutes the average of
 *   the known durations (or zero) when this is `None`
 */
final case class SuiteInfo(
  name: String,
  timeTaken: Option[Duration]
)
53 changes: 44 additions & 9 deletions src/main/scala/sbttestshards/TestShardsPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sbttestshards

import sbt.*
import sbt.Keys.*
import sbttestshards.parsers.FullTestReport

object TestShardsPlugin extends AutoPlugin {

Expand All @@ -10,17 +11,13 @@ object TestShardsPlugin extends AutoPlugin {
val testShardCount = settingKey[Int]("testShardCount")
val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm")
val testShardDebug = settingKey[Boolean]("testShardDebug")
val testDryRun = inputKey[Unit]("testDryRun")
}

import autoImport.*

override def trigger = allRequirements

def stringConfig(key: String, default: String): String = {
val propertyKey = key.replace('_', '.').toLowerCase
sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
}

override lazy val projectSettings: Seq[Def.Setting[?]] =
Seq(
testShard := stringConfig("TEST_SHARD", "0").toInt,
Expand All @@ -30,18 +27,56 @@ object TestShardsPlugin extends AutoPlugin {
Test / testOptions += {
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)

Tests.Filter { specName =>
val isInShard = shardingAlgorithm.value.shouldRun(specName, shardContext)
Tests.Filter { suiteName =>
val isInShard = shardingAlgorithm.value.shouldRun(suiteName, shardContext)

if (testShardDebug.value)
if (isInShard)
sLog.value.info(s"`$specName` set to run on this shard (#${testShard.value}).")
sLog.value.info(s"`$suiteName` set to run on this shard (#${testShard.value}).")
else
sLog.value.warn(s"`$specName` skipped because it will run on another shard.")
sLog.value.warn(s"`$suiteName` skipped because it will run on another shard.")

isInShard
}
},
// Logs how the suites listed in the sharding algorithm's prior report would be
// distributed across shards, along with the expected time per shard.
testDryRun := {
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
val logger = shardContext.logger
val algorithm = shardingAlgorithm.value
// Algorithms without a prior report contribute no suites to the listing below.
val priorReport = algorithm.priorReport.getOrElse(FullTestReport.empty)
val sbtSuiteNames = (Test / definedTestNames).value.toSet
// Suites sbt knows about that have no entry in the prior report.
val missingSuiteNames = sbtSuiteNames diff priorReport.testReports.map(_.name).toSet

// Ask the algorithm where each reported suite goes, drop suites that were not
// assigned to any shard, and group the remainder by shard index.
val results = priorReport.testReports.map { suiteReport =>
val shardResult = algorithm.check(suiteReport.name, shardContext)

shardResult.testShard -> suiteReport
}.collect { case (Some(shard), report) => shard -> report }
.groupBy(_._1)

// Per shard (ascending index): the summed `timeTaken` rounded to 3 decimal
// places, followed by one line per suite with its own `timeTaken`.
results.toSeq.sortBy(_._1).foreach { case (k, v) =>
val totalTime = BigDecimal(v.map(_._2.timeTaken).sum).setScale(3, BigDecimal.RoundingMode.HALF_UP)

logger.info(s"[${moduleName.value}] Shard $k expected to take $totalTime s")

v.map(_._2).foreach { suiteReport =>
logger.info(s"* ${suiteReport.name} = ${suiteReport.timeTaken} s")
}
}

// NOTE(review): suites missing from the prior report are only listed here;
// they are not part of the per-shard breakdown above.
if (missingSuiteNames.nonEmpty) {
logger.warn(s"Detected ${missingSuiteNames.size} suites that don't have a test report")

missingSuiteNames.foreach { s =>
logger.warn(s"- $s")
}
}
}
)

/** Resolves a configuration value, preferring a JVM system property over an
 * environment variable.
 *
 * The system property name is derived from `key` by replacing `_` with `.`
 * and lower-casing it (e.g. `TEST_SHARD` becomes `test.shard`); the
 * environment variable is looked up under `key` unchanged.
 *
 * @param key
 *   environment-variable style name of the setting
 * @param default
 *   value returned when neither the system property nor the environment
 *   variable is defined
 * @return
 *   the system property if set, otherwise the environment variable if set,
 *   otherwise `default`
 */
private def stringConfig(key: String, default: String): String = {
  // Locale.ROOT keeps the derived property name independent of the JVM's
  // default locale (under a Turkish locale `I` would lower-case to a dotless `ı`).
  val propertyKey = key.replace('_', '.').toLowerCase(java.util.Locale.ROOT)
  sys.props.get(propertyKey).orElse(sys.env.get(key)).getOrElse(default)
}

}
24 changes: 14 additions & 10 deletions src/main/scala/sbttestshards/parsers/JUnitReportParser.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package sbttestshards.parsers

import java.nio.file.{Files, Path, Paths}
import java.nio.file.Files
import java.nio.file.Path
import scala.jdk.CollectionConverters.*
import scala.xml.XML

final case class FullTestReport(testReports: Seq[SpecTestReport]) {
def specCount: Int = testReports.length
final case class FullTestReport(testReports: Seq[SuiteReport]) {
def suiteCount: Int = testReports.length

def testCount: Int = testReports.map(_.testCount).sum

Expand All @@ -18,7 +19,11 @@ final case class FullTestReport(testReports: Seq[SpecTestReport]) {
def ++(other: FullTestReport): FullTestReport = FullTestReport(testReports ++ other.testReports)
}

final case class SpecTestReport(
/** Factory helpers for [[FullTestReport]]. */
object FullTestReport {

  /** A report that contains no suite reports. */
  def empty: FullTestReport =
    FullTestReport(testReports = Seq.empty)
}

final case class SuiteReport(
name: String,
testCount: Int,
errorCount: Int,
Expand Down Expand Up @@ -66,10 +71,10 @@ object JUnitReportParser {
}
)

def parseReport(reportFile: Path): SpecTestReport = {
def parseReport(reportFile: Path): SuiteReport = {
val xml = XML.loadFile(reportFile.toFile)

val specName = xml \@ "name"
val suiteName = xml \@ "name"
val testCount = (xml \@ "tests").toInt
val errorCount = (xml \@ "errors").toInt
val failureCount = (xml \@ "failures").toInt
Expand All @@ -83,15 +88,14 @@ object JUnitReportParser {
val testName = (node \@ "name").trim

Some(testName)
} else {
} else
None
}
}.collect { case Some(testName) =>
testName
}

SpecTestReport(
name = specName,
SuiteReport(
name = suiteName,
testCount = testCount,
errorCount = errorCount,
failureCount = failureCount,
Expand Down
Loading

0 comments on commit 52ea166

Please sign in to comment.